首页 最新 热门 推荐

  • 首页
  • 最新
  • 热门
  • 推荐

如何使用C#实现Padim算法的训练和推理

  • 25-02-19 03:41
  • 4549
  • 12746
blog.csdn.net

目录

说明

项目背景

算法实现

预处理模块——图像预处理

主要模块——训练:Resnet层信息提取

主要模块——信息处理,计算Anomaly Map

主要模块——评估

主要模块——评估:门限值的确定

主要模块——推理

写在最后

项目下载链接


说明

作者:来瓶霸王防脱发

项目地址:

https://github.com/IntptrMax/PadimSharp

原文地址:

https://blog.csdn.net/qq_30270773/article/details/143029865

项目背景

缺陷检测(Anomaly Detection)算法是一个区分正常类别与异常类别的二分类问题,但在工业场景中大多数数据都为良品,不良数据难以获取,更难枚举,所以训练一个全监督的模型是不切实际的。因此,异常检测模型通常以单类别学习的模式。Padim算法是一种十分优秀的缺陷检测算法,直接上图可以看一下这个算法的效果。

良品图片

图片

不良品图片

图片

检测效果

图片

C#是一种十分受欢迎的编程语言,这种编程语言在工业场景下使用也是十分广泛的。在一些AI领域,会在Python下将模型转化为onnx形式,通过onnxruntime加载使用,进行推理。但是在onnx形式下进行训练十分困难。很多C#开发者不太熟悉Python环境,或者某些条件下希望在纯粹的C#环境下进行深度学习的训练和使用。这个还是有一定的困难的。

目前搜索了Github和CSDN排名靠前的几十条数据,还没有Padim算法在除Python平台下的训练+推理的相关项目或资料。本文就是在C#平台实现了Padim的训练+推理过程,应该在相关领域也算是独一份了。

算法实现

Padim算法的“训练”过程其实并没有涉及到真正的训练,而是使用Resnet18算法提取关键信息加以处理,在推理时再次使用,因此“训练”过程速度非常快,这也是这个算法的优点之一。Padim算法的具体实现还请参考相关论文:PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization

https://arxiv.org/abs/2011.08785

如果论文看起来困难,还有一些大佬对该算法在Python平台下的解读,也可以参考:PaDiM 原理与代码解析

https://blog.csdn.net/ooooocj/article/details/127601035

预处理模块——图像预处理

图像预处理使用的方法比较常规,使用了缩放等方式,此处并没有使用LetterBox,也可以达到预期效果:

  1. var transformers = torchvision.transforms.Compose([
  2.     torchvision.transforms.Resize(resizeHeight,resizeWidth),
  3. torchvision.transforms.CenterCrop(cropHeight,cropWidth),
  4. torchvision.transforms.Normalize(means, stdevs)]);

主要模块——训练:Resnet层信息提取

使用Resnet模型进行推理,并提取Layer1、Layer2、Layer3层的信息,并进行了拼接(EmbeddingConcat)。注意:这里提取时使用了钩子,钩子在使用时会有资源释放,因此这里使用了比较迂回的方式记录结果

实现代码如下:

  1. public List<(string, Tensor)> Forward(Tensor input)
  2. {
  3.  List<(string, Tensor)> outputs = new List<(string, Tensor)>();
  4.  List<TempTensor> tempTensors = new List<TempTensor>();
  5.  foreach (var named_module in model.named_children())
  6.  {
  7.   string name = named_module.name;
  8.   if (name == "layer1" || name == "layer2" || name == "layer3")
  9.   {
  10.    ((Sequential)named_module.module).register_forward_hook((Module, input, output) =>
  11.    {
  12.     tempTensors.Add(new TempTensor
  13.     {
  14.      Data = output.data<float>().ToArray(),
  15.      Name = name,
  16.      Shape = output.shape,
  17.     });
  18.     return null;
  19.    });
  20.   }
  21.  }
  22.  model.forward(input);
  23.  var layer1output = tempTensors.Find(a => a.Name == "layer1");
  24.  var layer2output = tempTensors.Find(a => a.Name == "layer2");
  25.  var layer3output = tempTensors.Find(a => a.Name == "layer3");
  26.  Tensor l1 = torch.tensor(layer1output.Data, layer1output.Shape, device: input.device);
  27.  Tensor l2 = torch.tensor(layer2output.Data, layer2output.Shape, device: input.device);
  28.  Tensor l3 = torch.tensor(layer3output.Data, layer3output.Shape, device: input.device);
  29.  outputs.Add(new("layer1", l1));
  30.  outputs.Add(new("layer2", l2));
  31.  outputs.Add(new("layer3", l3));
  32.  GC.Collect();
  33.  return outputs;
  34. }
  1. private Tensor EmbeddingConcat(Tensor[] features)
  2. {
  3.  var embeddings = features[0];
  4.  for (int i = 1; i < features.Length; i++)
  5.  {
  6.   var layerEmbedding = features[i];
  7.   layerEmbedding = torch.nn.functional.interpolate(layerEmbedding, size: [embeddings.shape[2], embeddings.shape[2]], mode: InterpolationMode.Nearest);
  8.   embeddings = torch.cat([embeddings, layerEmbedding], 1);
  9.  }
  10.  return embeddings;
  11. }

主要模块——信息处理,计算Anomaly Map

这一块主要对信息进行处理,获取矩阵的mean和cov(协方差矩阵),代码如下:

  1. public Tensor ComputeAnomalyMapInternal(Tensor embedding, Tensor mean, Tensor covariance)
  2. {
  3.  var scoreMap = ComputeDistance(embedding, mean, covariance);
  4.  var upSampledScoreMap = UpSample(scoreMap);
  5.  var smoothedAnomalyMap = SmoothAnomalyMap(upSampledScoreMap);
  6.  return smoothedAnomalyMap;
  7. }
  8. public Tensor ComputeAnomalyMap(List<(string, Tensor)> outputs, Tensor mean, Tensor covariance, Tensor idx)
  9. {
  10.  Tensor embedding = GetEmbedding(outputs);
  11.  var embeddingVectors = torch.index_select(embedding, 1, idx);
  12.  return ComputeAnomalyMapInternal(embeddingVectors, mean, covariance);
  13. }

主要模块——评估

与训练过程开始部分相似,也是获取图像的Embeddings,然后利用之前获取的Cov和mean计算马氏距离,以此评估图像的异常情况。马氏距离的计算方法如下:

  1. private Tensor ComputeDistance(Tensor embedding, Tensor mean, Tensor covariance)
  2. {
  3.  long batch = embedding.shape[0];
  4.  long channel = embedding.shape[1];
  5.  long height = embedding.shape[2];
  6.  long width = embedding.shape[3];
  7.  Tensor inv_covariance = covariance.permute(2, 0, 1).inverse();
  8.  var embedding_reshaped = embedding.reshape(batch, channel, height * width);
  9.  var delta = (embedding_reshaped - mean).permute(2, 0, 1);
  10.  var distances = (torch.matmul(delta, inv_covariance) * delta).sum(2).permute(1, 0);
  11.  distances = distances.reshape(batch, 1, height, width);
  12.  distances = distances.clamp(0).sqrt();
  13.  return distances;
  14. }

主要模块——评估:门限值的确定

这里需要确定图像的评估门限和像素值的评估门限。如果在评估时有负向样本,这个值会更准确,如果只有正向样本也是可以的。在Python下有个precision_recall_curve包,可以计算相关参数,但是在C#下时没有的,因此在此处仍旧只能造轮子,代码如下:

  1. private (float[] precisions, float[] recalls, float[] thresholds) _precision_recall_curve_compute_single_class(Tensor yTrue, Tensor yScores, int pos_label = 1)
  2. {
  3.  var (fps, tps, thresholds) = BinaryClfCurve(yScores, yTrue, pos_label);
  4.  var precision = tps / (tps + fps);
  5.  var recall = tps / tps[-1];
  6.  var lastInd = torch.where(tps == tps[-1])[0][0].ToInt32();
  7.  int[] sl = new int[lastInd + 1];
  8.  for (int i = 0; i < sl.Length; i++)
  9.  {
  10.   sl[i] = i;
  11.  }
  12.  var reversedPrecision = precision[sl].flip(0);
  13.  var reversedRecall = recall[sl].flip(0);
  14.  var reversedThresholds = thresholds[sl].flip(0);
  15.  precision = torch.cat(new Tensor[] { reversedPrecision, torch.ones(1, dtype: precision.dtype, device: precision.device) });
  16.  recall = torch.cat(new Tensor[] { reversedRecall, torch.zeros(1, dtype: recall.dtype, device: recall.device) });
  17.  return (precision.data<float>().ToArray(), recall.data<float>().ToArray(), reversedThresholds.data<float>().ToArray());
  18. }
  19. private (Tensor fps, Tensor tps, Tensor thresholds) BinaryClfCurve(Tensor preds, Tensor target, int posLabel = 1)
  20. {
  21.  using (torch.no_grad())
  22.  {
  23.   if (preds.ndim > target.ndim)
  24.   {
  25.    preds = preds[TensorIndex.Ellipsis, 0];
  26.   }
  27.   var descScoreIndices = torch.argsort(preds, descending: true);
  28.   preds = preds[descScoreIndices];
  29.   target = target[descScoreIndices];
  30.   Tensor weight = torch.tensor(1.0f);
  31.   var distinctValueIndices = torch.nonzero(preds[1..] - preds[..^1]).squeeze();
  32.   var thresholdIdxs = torch.cat(new Tensor[] { distinctValueIndices, torch.tensor(new long[] { target.shape[0] - 1 }, device: preds.device) });
  33.   target = (target == posLabel).to_type(ScalarType.Int64);
  34.   var tps = torch.cumsum(target * weight, dim: 0)[thresholdIdxs];
  35.   Tensor fps = 1 + thresholdIdxs - tps;
  36.   return (fps, tps, preds[thresholdIdxs]);
  37.  }
  38. }

主要模块——推理

这个过程与上面过程也十分相似,正向计算出图像的Anomaly Map后,取出这个张量中最大的值,与图像的门限值进行比较,即可评估图像是否是良品。然后对这个张量中每个元素与像素门限值做对比,即可得到按像素的异常区域,以便绘制Mask和热力图。

  1. Tensor orgImg = tensors["orgImage"].clone().to(device);
  2. Tensor t = anomaly_map > pixel_threshold;
  3. anomaly_map = (anomaly_map * t).squeeze(0);
  4. anomaly_map = torchvision.transforms.functional.resize(anomaly_map, (int)orgImg.size(2), (int)orgImg.size(1));
  5. Tensor heatmapNormalized = (anomaly_map - anomaly_map.min()) / (anomaly_map.max() - anomaly_map.min());
  6. Tensor coloredHeatmap = torch.zeros([3, (int)orgImg.size(2), (int)orgImg.size(1)],device:anomaly_map.device);
  7. coloredHeatmap[0] = heatmapNormalized.squeeze(0);
  8. float alpha = 0.3f;
  9. Tensor blendedImage = (1 - alpha) * (orgImg / 255.0f) + alpha * coloredHeatmap;
  10. var imageTensor = blendedImage.clamp(0, 1).mul(255).to(ScalarType.Byte);
  11. torchvision.io.write_jpeg(imageTensor.cpu(), "result.jpg");

写在最后

使用C#开发深度学习项目,尤其是训练的项目,是一个十分困难的过程。或者说除了Python平台,训练都十分困难。C#进行深度学习训练这个方向在国内基本很少有人开展,所以能查得到的资料很少。本人十分喜爱C#这门语言,又十分喜爱深度学习,因此仅半年一直在这方面努力。遇到了很多困难,也收获了很多。

这条路走的不容易,希望能有更多人能加入进来,一起开发,一起学习。

我在Github上已经将完整的代码发布了,项目地址为:

https://github.com/IntptrMax/PadimSharp

,期待你能在Github上送我一颗小星星。在我的Github里还GGMLSharp这个项目,这个项目也是C#平台下深度学习的开发包,希望能得到你的支持。

项目下载链接

https://download.csdn.net/download/qq_30270773/89897710
天天代码码天天
微信公众号
.NET 人工智能实践
注:本文转载自blog.csdn.net的天天代码码天天的文章"https://lw112190.blog.csdn.net/article/details/143033329"。版权归原作者所有,此博客不拥有其著作权,亦不承担相应法律责任。如有侵权,请联系我们删除。
复制链接
复制链接
相关推荐
发表评论
登录后才能发表评论和回复 注册

/ 登录

评论记录:

未查询到任何数据!
回复评论:

分类栏目

后端 (14832) 前端 (14280) 移动开发 (3760) 编程语言 (3851) Java (3904) Python (3298) 人工智能 (10119) AIGC (2810) 大数据 (3499) 数据库 (3945) 数据结构与算法 (3757) 音视频 (2669) 云原生 (3145) 云平台 (2965) 前沿技术 (2993) 开源 (2160) 小程序 (2860) 运维 (2533) 服务器 (2698) 操作系统 (2325) 硬件开发 (2492) 嵌入式 (2955) 微软技术 (2769) 软件工程 (2056) 测试 (2865) 网络空间安全 (2948) 网络与通信 (2797) 用户体验设计 (2592) 学习和成长 (2593) 搜索 (2744) 开发工具 (7108) 游戏 (2829) HarmonyOS (2935) 区块链 (2782) 数学 (3112) 3C硬件 (2759) 资讯 (2909) Android (4709) iOS (1850) 代码人生 (3043) 阅读 (2841)

热门文章

101
推荐
关于我们 隐私政策 免责声明 联系我们
Copyright © 2020-2025 蚁人论坛 (iYenn.com) All Rights Reserved.
Scroll to Top