C# yolov8 OpenVINO+ByteTrack Demo
目录
效果
项目
代码
Form2.cs
using ByteTrack;
using OpenCvSharp;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Drawing;
using System.Threading;
using System.Windows.Forms;
namespace yolov8_OpenVINO_Demo
{
    /// <summary>
    /// Demo window: YOLOv8 (OpenVINO) detection on a single image, and detection
    /// plus ByteTrack tracking on a video file shown in an OpenCV preview window.
    /// </summary>
    public partial class Form2 : Form
    {
        /// <summary>Title of the OpenCV preview window; create/resize/show must all use the same string.</summary>
        const string PreviewWindowName = "DetectionResult 按下ESC,退出";

        /// <summary>Frame rate assumed when the video file does not report a positive one.</summary>
        const double FallbackFps = 30.0;

        public Form2()
        {
            InitializeComponent();
        }

        // "*.tif" added: the filter previously listed "*.tiff" twice.
        string imgFilter = "*.*|*.bmp;*.jpg;*.jpeg;*.tif;*.tiff;*.png";

        YoloV8 yoloV8;
        Mat image;

        string image_path = "";
        string model_path;

        string video_path = "";
        string videoFilter = "*.mp4|*.mp4;";
        VideoCapture vcapture;
        VideoWriter vwriter;
        bool saveDetVideo = false;
        ByteTracker tracker;

        // True while the video worker thread is alive. The detector, tracker and capture
        // fields are shared, so every handler that would touch them checks this flag first.
        volatile bool videoRunning;

        /// <summary>
        /// Runs detection on the currently selected image and shows the annotated result.
        /// </summary>
        /// <param name="sender">Event source (unused).</param>
        /// <param name="e">Event data (unused).</param>
        private void button2_Click(object sender, EventArgs e)
        {
            if (image_path == "" || videoRunning)
            {
                return;
            }

            button2.Enabled = false;
            ReplaceImage(pictureBox2, null);
            textBox1.Text = "";

            // Let the form repaint before the blocking inference call.
            Application.DoEvents();

            try
            {
                if (image != null)
                {
                    image.Dispose();
                }
                image = new Mat(image_path);
                if (image.Empty())
                {
                    textBox1.Text = "读取图片失败";
                    return;
                }

                List<DetectionResult> detResults = yoloV8.Detect(image);

                using (Mat result_image = image.Clone())
                {
                    foreach (DetectionResult r in detResults)
                    {
                        DrawBox(result_image, r.Rect, $"{r.Class}:{r.Confidence:P0}");
                    }

                    // The stream is intentionally left open: the Bitmap is built on top of it.
                    ReplaceImage(pictureBox2, new Bitmap(result_image.ToMemoryStream()));
                }

                textBox1.Text = yoloV8.DetectTime();
            }
            finally
            {
                // Re-enable the button even when inference throws.
                button2.Enabled = true;
            }
        }

        /// <summary>
        /// Form load: shows the bundled sample image when it exists and creates the detector.
        /// </summary>
        /// <param name="sender">Event source (unused).</param>
        /// <param name="e">Event data (unused).</param>
        private void Form1_Load(object sender, EventArgs e)
        {
            image_path = "test/dog.jpg";
            if (System.IO.File.Exists(image_path))
            {
                ReplaceImage(pictureBox1, new Bitmap(image_path));
            }
            else
            {
                // No sample available: leave "no image selected" instead of failing on startup.
                image_path = "";
            }

            model_path = "model/yolov8n.onnx";
            yoloV8 = new YoloV8(model_path, "model/lable.txt");
        }

        /// <summary>
        /// Lets the user pick an image file and shows it in the input picture box.
        /// </summary>
        /// <param name="sender">Event source (unused).</param>
        /// <param name="e">Event data (unused).</param>
        private void button1_Click_1(object sender, EventArgs e)
        {
            using (OpenFileDialog ofd = new OpenFileDialog())
            {
                ofd.Filter = imgFilter;
                if (ofd.ShowDialog() != DialogResult.OK)
                {
                    return;
                }
                image_path = ofd.FileName;
            }

            ReplaceImage(pictureBox1, new Bitmap(image_path));
            textBox1.Text = "";
            ReplaceImage(pictureBox2, null);
        }

        /// <summary>
        /// Lets the user pick a video file and immediately starts video inference on it.
        /// </summary>
        /// <param name="sender">Event source (unused).</param>
        /// <param name="e">Event data (unused).</param>
        private void button4_Click(object sender, EventArgs e)
        {
            if (videoRunning)
            {
                return;
            }

            using (OpenFileDialog ofd = new OpenFileDialog())
            {
                ofd.Filter = videoFilter;
                ofd.InitialDirectory = Application.StartupPath + "\\test";
                if (ofd.ShowDialog() != DialogResult.OK)
                {
                    return;
                }
                video_path = ofd.FileName;
            }

            textBox1.Text = "";
            ReplaceImage(pictureBox1, null);
            ReplaceImage(pictureBox2, null);

            button3_Click(null, null);
        }

        /// <summary>
        /// Starts video inference on a background thread so the form stays responsive.
        /// Does nothing when no video is selected or a run is already in progress.
        /// </summary>
        /// <param name="sender">Event source (unused; may be null).</param>
        /// <param name="e">Event data (unused; may be null).</param>
        private void button3_Click(object sender, EventArgs e)
        {
            if (video_path == "" || videoRunning)
            {
                return;
            }

            textBox1.Text = "开始检测";

            // Controls are read here, on the UI thread, never from the worker.
            saveDetVideo = checkBox1.Checked;
            videoRunning = true;

            Thread thread = new Thread(new ThreadStart(VideoDetectionWorker));
            // Background thread: closing the application must not wait for the video to end.
            thread.IsBackground = true;
            thread.Start();
        }

        /// <summary>
        /// Thread entry point: runs <see cref="VideoDetection"/>, then clears the running flag
        /// and reports completion on the UI thread.
        /// </summary>
        void VideoDetectionWorker()
        {
            try
            {
                VideoDetection();
            }
            finally
            {
                videoRunning = false;
                NotifyVideoFinished();
            }
        }

        /// <summary>
        /// Posts the "finished" status text to the UI thread if the form still exists.
        /// </summary>
        void NotifyVideoFinished()
        {
            try
            {
                if (IsHandleCreated && !IsDisposed)
                {
                    BeginInvoke(new Action(() => textBox1.Text = "检测完成!"));
                }
            }
            catch (InvalidOperationException)
            {
                // The form was closed between the check and the call; nothing left to update.
            }
        }

        /// <summary>
        /// Reads the selected video frame by frame, runs detection and tracking on each frame,
        /// draws timing statistics and track boxes, optionally writes the annotated frames to
        /// "out.mp4", and previews them until the video ends or ESC is pressed.
        /// All native resources are released on every exit path.
        /// </summary>
        void VideoDetection()
        {
            Mat frame = null;
            bool windowCreated = false;

            try
            {
                vcapture = new VideoCapture(video_path);
                if (!vcapture.IsOpened())
                {
                    MessageBox.Show("打开视频文件失败");
                    return;
                }

                // Some files report 0 or NaN; dividing by that would produce a meaningless delay.
                double videoFps = vcapture.Get(VideoCaptureProperties.Fps);
                if (double.IsNaN(videoFps) || videoFps <= 0)
                {
                    videoFps = FallbackFps;
                }
                int frameInterval = (int)(1000 / videoFps);

                tracker = new ByteTracker((int)videoFps, 200);

                if (saveDetVideo)
                {
                    vwriter = new VideoWriter("out.mp4", FourCC.X264, videoFps, new OpenCvSharp.Size(vcapture.FrameWidth, vcapture.FrameHeight));
                    if (!vwriter.IsOpened())
                    {
                        // Typically the codec is unavailable; keep previewing without saving.
                        MessageBox.Show("创建输出视频失败");
                        vwriter.Dispose();
                        vwriter = null;
                    }
                }

                Cv2.NamedWindow(PreviewWindowName, WindowFlags.Normal);
                windowCreated = true;
                Cv2.ResizeWindow(PreviewWindowName, vcapture.FrameWidth / 2, vcapture.FrameHeight / 2);

                frame = new Mat();
                Stopwatch frameTimer = new Stopwatch();

                while (vcapture.Read(frame))
                {
                    if (frame.Empty())
                    {
                        MessageBox.Show("读取失败");
                        return;
                    }

                    frameTimer.Restart();

                    List<DetectionResult> detResults = yoloV8.Detect(frame);

                    DrawStat(frame, "preprocessTime:" + yoloV8.preprocessTime.ToString("F2") + "ms", 30);
                    DrawStat(frame, "inferTime:" + yoloV8.inferTime.ToString("F2") + "ms", 70);
                    DrawStat(frame, "postprocessTime:" + yoloV8.postprocessTime.ToString("F2") + "ms", 110);
                    DrawStat(frame, "totalTime:" + yoloV8.totalTime.ToString("F2") + "ms", 150);
                    DrawStat(frame, "video fps:" + videoFps.ToString("F2"), 190);
                    DrawStat(frame, "det fps:" + yoloV8.detFps.ToString("F2"), 230);

                    // Wrap every detection as a Track, carrying class id and name as extra properties.
                    List<Track> track = new List<Track>(detResults.Count);
                    foreach (DetectionResult r in detResults)
                    {
                        RectBox box = new RectBox(r.Rect.X, r.Rect.Y, r.Rect.Width, r.Rect.Height);
                        track.Add(new Track(box, r.Confidence, ("label", r.ClassId), ("name", r.Class)));
                    }

                    var trackOutputs = tracker.Update(track);

                    foreach (var t in trackOutputs)
                    {
                        Rect rect = new Rect((int)t.RectBox.X, (int)t.RectBox.Y, (int)t.RectBox.Width, (int)t.RectBox.Height);
                        DrawBox(frame, rect, $"{t["name"]}-{t.TrackId}");
                    }

                    if (vwriter != null)
                    {
                        vwriter.Write(frame);
                    }

                    Cv2.ImShow(PreviewWindowName, frame);

                    // Wait for whatever is left of this frame's time slot (at least 1 ms so
                    // the preview window keeps processing its events).
                    int delay = (int)(frameInterval - frameTimer.ElapsedMilliseconds);
                    if (delay <= 0)
                    {
                        delay = 1;
                    }

                    if (Cv2.WaitKey(delay) == 27)
                    {
                        break; // ESC pressed
                    }
                }
            }
            finally
            {
                if (windowCreated)
                {
                    Cv2.DestroyAllWindows();
                }
                if (frame != null)
                {
                    frame.Dispose();
                }
                if (vwriter != null)
                {
                    // Releasing finalises the output file.
                    vwriter.Release();
                    vwriter.Dispose();
                    vwriter = null;
                }
                if (vcapture != null)
                {
                    vcapture.Release();
                    vcapture.Dispose();
                    vcapture = null;
                }
            }
        }

        /// <summary>
        /// Assigns <paramref name="next"/> to the picture box and disposes the image it replaced.
        /// </summary>
        static void ReplaceImage(PictureBox box, System.Drawing.Image next)
        {
            System.Drawing.Image previous = box.Image;
            box.Image = next;
            if (previous != null && !ReferenceEquals(previous, next))
            {
                previous.Dispose();
            }
        }

        /// <summary>Draws a red rectangle with a red label just above its top-left corner.</summary>
        static void DrawBox(Mat canvas, Rect rect, string label)
        {
            Cv2.PutText(canvas, label, new OpenCvSharp.Point(rect.TopLeft.X, rect.TopLeft.Y - 10), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
            Cv2.Rectangle(canvas, rect, Scalar.Red, thickness: 2);
        }

        /// <summary>Draws one line of red statistics text at x = 10 and the given baseline y.</summary>
        static void DrawStat(Mat canvas, string text, int y)
        {
            Cv2.PutText(canvas, text, new OpenCvSharp.Point(10, y), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
        }
    }
}
using ByteTrack;
using OpenCvSharp;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Drawing;
using System.Threading;
using System.Windows.Forms;

namespace yolov8_OpenVINO_Demo
{
    /// <summary>
    /// Demo window: YOLOv8 (OpenVINO) detection on a single image, and detection
    /// plus ByteTrack tracking on a video file shown in an OpenCV preview window.
    /// </summary>
    public partial class Form2 : Form
    {
        /// <summary>Title of the OpenCV preview window; create/resize/show must all use the same string.</summary>
        const string PreviewWindowName = "DetectionResult 按下ESC,退出";

        /// <summary>Frame rate assumed when the video file does not report a positive one.</summary>
        const double FallbackFps = 30.0;

        public Form2()
        {
            InitializeComponent();
        }

        // "*.tif" added: the filter previously listed "*.tiff" twice.
        string imgFilter = "*.*|*.bmp;*.jpg;*.jpeg;*.tif;*.tiff;*.png";

        YoloV8 yoloV8;
        Mat image;

        string image_path = "";
        string model_path;

        string video_path = "";
        string videoFilter = "*.mp4|*.mp4;";
        VideoCapture vcapture;
        VideoWriter vwriter;
        bool saveDetVideo = false;
        ByteTracker tracker;

        // True while the video worker thread is alive. The detector, tracker and capture
        // fields are shared, so every handler that would touch them checks this flag first.
        volatile bool videoRunning;

        /// <summary>
        /// Runs detection on the currently selected image and shows the annotated result.
        /// </summary>
        /// <param name="sender">Event source (unused).</param>
        /// <param name="e">Event data (unused).</param>
        private void button2_Click(object sender, EventArgs e)
        {
            if (image_path == "" || videoRunning)
            {
                return;
            }

            button2.Enabled = false;
            ReplaceImage(pictureBox2, null);
            textBox1.Text = "";

            // Let the form repaint before the blocking inference call.
            Application.DoEvents();

            try
            {
                if (image != null)
                {
                    image.Dispose();
                }
                image = new Mat(image_path);
                if (image.Empty())
                {
                    textBox1.Text = "读取图片失败";
                    return;
                }

                List<DetectionResult> detResults = yoloV8.Detect(image);

                using (Mat result_image = image.Clone())
                {
                    foreach (DetectionResult r in detResults)
                    {
                        DrawBox(result_image, r.Rect, $"{r.Class}:{r.Confidence:P0}");
                    }

                    // The stream is intentionally left open: the Bitmap is built on top of it.
                    ReplaceImage(pictureBox2, new Bitmap(result_image.ToMemoryStream()));
                }

                textBox1.Text = yoloV8.DetectTime();
            }
            finally
            {
                // Re-enable the button even when inference throws.
                button2.Enabled = true;
            }
        }

        /// <summary>
        /// Form load: shows the bundled sample image when it exists and creates the detector.
        /// </summary>
        /// <param name="sender">Event source (unused).</param>
        /// <param name="e">Event data (unused).</param>
        private void Form1_Load(object sender, EventArgs e)
        {
            image_path = "test/dog.jpg";
            if (System.IO.File.Exists(image_path))
            {
                ReplaceImage(pictureBox1, new Bitmap(image_path));
            }
            else
            {
                // No sample available: leave "no image selected" instead of failing on startup.
                image_path = "";
            }

            model_path = "model/yolov8n.onnx";
            yoloV8 = new YoloV8(model_path, "model/lable.txt");
        }

        /// <summary>
        /// Lets the user pick an image file and shows it in the input picture box.
        /// </summary>
        /// <param name="sender">Event source (unused).</param>
        /// <param name="e">Event data (unused).</param>
        private void button1_Click_1(object sender, EventArgs e)
        {
            using (OpenFileDialog ofd = new OpenFileDialog())
            {
                ofd.Filter = imgFilter;
                if (ofd.ShowDialog() != DialogResult.OK)
                {
                    return;
                }
                image_path = ofd.FileName;
            }

            ReplaceImage(pictureBox1, new Bitmap(image_path));
            textBox1.Text = "";
            ReplaceImage(pictureBox2, null);
        }

        /// <summary>
        /// Lets the user pick a video file and immediately starts video inference on it.
        /// </summary>
        /// <param name="sender">Event source (unused).</param>
        /// <param name="e">Event data (unused).</param>
        private void button4_Click(object sender, EventArgs e)
        {
            if (videoRunning)
            {
                return;
            }

            using (OpenFileDialog ofd = new OpenFileDialog())
            {
                ofd.Filter = videoFilter;
                ofd.InitialDirectory = Application.StartupPath + "\\test";
                if (ofd.ShowDialog() != DialogResult.OK)
                {
                    return;
                }
                video_path = ofd.FileName;
            }

            textBox1.Text = "";
            ReplaceImage(pictureBox1, null);
            ReplaceImage(pictureBox2, null);

            button3_Click(null, null);
        }

        /// <summary>
        /// Starts video inference on a background thread so the form stays responsive.
        /// Does nothing when no video is selected or a run is already in progress.
        /// </summary>
        /// <param name="sender">Event source (unused; may be null).</param>
        /// <param name="e">Event data (unused; may be null).</param>
        private void button3_Click(object sender, EventArgs e)
        {
            if (video_path == "" || videoRunning)
            {
                return;
            }

            textBox1.Text = "开始检测";

            // Controls are read here, on the UI thread, never from the worker.
            saveDetVideo = checkBox1.Checked;
            videoRunning = true;

            Thread thread = new Thread(new ThreadStart(VideoDetectionWorker));
            // Background thread: closing the application must not wait for the video to end.
            thread.IsBackground = true;
            thread.Start();
        }

        /// <summary>
        /// Thread entry point: runs <see cref="VideoDetection"/>, then clears the running flag
        /// and reports completion on the UI thread.
        /// </summary>
        void VideoDetectionWorker()
        {
            try
            {
                VideoDetection();
            }
            finally
            {
                videoRunning = false;
                NotifyVideoFinished();
            }
        }

        /// <summary>
        /// Posts the "finished" status text to the UI thread if the form still exists.
        /// </summary>
        void NotifyVideoFinished()
        {
            try
            {
                if (IsHandleCreated && !IsDisposed)
                {
                    BeginInvoke(new Action(() => textBox1.Text = "检测完成!"));
                }
            }
            catch (InvalidOperationException)
            {
                // The form was closed between the check and the call; nothing left to update.
            }
        }

        /// <summary>
        /// Reads the selected video frame by frame, runs detection and tracking on each frame,
        /// draws timing statistics and track boxes, optionally writes the annotated frames to
        /// "out.mp4", and previews them until the video ends or ESC is pressed.
        /// All native resources are released on every exit path.
        /// </summary>
        void VideoDetection()
        {
            Mat frame = null;
            bool windowCreated = false;

            try
            {
                vcapture = new VideoCapture(video_path);
                if (!vcapture.IsOpened())
                {
                    MessageBox.Show("打开视频文件失败");
                    return;
                }

                // Some files report 0 or NaN; dividing by that would produce a meaningless delay.
                double videoFps = vcapture.Get(VideoCaptureProperties.Fps);
                if (double.IsNaN(videoFps) || videoFps <= 0)
                {
                    videoFps = FallbackFps;
                }
                int frameInterval = (int)(1000 / videoFps);

                tracker = new ByteTracker((int)videoFps, 200);

                if (saveDetVideo)
                {
                    vwriter = new VideoWriter("out.mp4", FourCC.X264, videoFps, new OpenCvSharp.Size(vcapture.FrameWidth, vcapture.FrameHeight));
                    if (!vwriter.IsOpened())
                    {
                        // Typically the codec is unavailable; keep previewing without saving.
                        MessageBox.Show("创建输出视频失败");
                        vwriter.Dispose();
                        vwriter = null;
                    }
                }

                Cv2.NamedWindow(PreviewWindowName, WindowFlags.Normal);
                windowCreated = true;
                Cv2.ResizeWindow(PreviewWindowName, vcapture.FrameWidth / 2, vcapture.FrameHeight / 2);

                frame = new Mat();
                Stopwatch frameTimer = new Stopwatch();

                while (vcapture.Read(frame))
                {
                    if (frame.Empty())
                    {
                        MessageBox.Show("读取失败");
                        return;
                    }

                    frameTimer.Restart();

                    List<DetectionResult> detResults = yoloV8.Detect(frame);

                    DrawStat(frame, "preprocessTime:" + yoloV8.preprocessTime.ToString("F2") + "ms", 30);
                    DrawStat(frame, "inferTime:" + yoloV8.inferTime.ToString("F2") + "ms", 70);
                    DrawStat(frame, "postprocessTime:" + yoloV8.postprocessTime.ToString("F2") + "ms", 110);
                    DrawStat(frame, "totalTime:" + yoloV8.totalTime.ToString("F2") + "ms", 150);
                    DrawStat(frame, "video fps:" + videoFps.ToString("F2"), 190);
                    DrawStat(frame, "det fps:" + yoloV8.detFps.ToString("F2"), 230);

                    // Wrap every detection as a Track, carrying class id and name as extra properties.
                    List<Track> track = new List<Track>(detResults.Count);
                    foreach (DetectionResult r in detResults)
                    {
                        RectBox box = new RectBox(r.Rect.X, r.Rect.Y, r.Rect.Width, r.Rect.Height);
                        track.Add(new Track(box, r.Confidence, ("label", r.ClassId), ("name", r.Class)));
                    }

                    var trackOutputs = tracker.Update(track);

                    foreach (var t in trackOutputs)
                    {
                        Rect rect = new Rect((int)t.RectBox.X, (int)t.RectBox.Y, (int)t.RectBox.Width, (int)t.RectBox.Height);
                        DrawBox(frame, rect, $"{t["name"]}-{t.TrackId}");
                    }

                    if (vwriter != null)
                    {
                        vwriter.Write(frame);
                    }

                    Cv2.ImShow(PreviewWindowName, frame);

                    // Wait for whatever is left of this frame's time slot (at least 1 ms so
                    // the preview window keeps processing its events).
                    int delay = (int)(frameInterval - frameTimer.ElapsedMilliseconds);
                    if (delay <= 0)
                    {
                        delay = 1;
                    }

                    if (Cv2.WaitKey(delay) == 27)
                    {
                        break; // ESC pressed
                    }
                }
            }
            finally
            {
                if (windowCreated)
                {
                    Cv2.DestroyAllWindows();
                }
                if (frame != null)
                {
                    frame.Dispose();
                }
                if (vwriter != null)
                {
                    // Releasing finalises the output file.
                    vwriter.Release();
                    vwriter.Dispose();
                    vwriter = null;
                }
                if (vcapture != null)
                {
                    vcapture.Release();
                    vcapture.Dispose();
                    vcapture = null;
                }
            }
        }

        /// <summary>
        /// Assigns <paramref name="next"/> to the picture box and disposes the image it replaced.
        /// </summary>
        static void ReplaceImage(PictureBox box, System.Drawing.Image next)
        {
            System.Drawing.Image previous = box.Image;
            box.Image = next;
            if (previous != null && !ReferenceEquals(previous, next))
            {
                previous.Dispose();
            }
        }

        /// <summary>Draws a red rectangle with a red label just above its top-left corner.</summary>
        static void DrawBox(Mat canvas, Rect rect, string label)
        {
            Cv2.PutText(canvas, label, new OpenCvSharp.Point(rect.TopLeft.X, rect.TopLeft.Y - 10), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
            Cv2.Rectangle(canvas, rect, Scalar.Red, thickness: 2);
        }

        /// <summary>Draws one line of red statistics text at x = 10 and the given baseline y.</summary>
        static void DrawStat(Mat canvas, string text, int y)
        {
            Cv2.PutText(canvas, text, new OpenCvSharp.Point(10, y), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
        }
    }
}
YoloV8.cs
using OpenCvSharp;
using OpenCvSharp.Dnn;
using Sdcb.OpenVINO;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;

namespace yolov8_OpenVINO_Demo
{
    /// <summary>
    /// YOLOv8 object detector executed through Sdcb.OpenVINO.
    /// One instance owns one inference request and per-call scratch state, so a single
    /// instance must not be used from two threads at the same time.
    /// </summary>
    public class YoloV8
    {
        float[] input_tensor_data;
        float[] outputData;
        List<DetectionResult> detectionResults;

        int input_height;
        int input_width;

        InferRequest ir;

        public string[] class_names;
        int class_num;
        int box_num;

        float conf_threshold;
        float nms_threshold;

        // Factors mapping coordinates in the (possibly down-scaled) network input back to the source image.
        float ratio_height;
        float ratio_width;

        // Timings of the most recent Detect call, in milliseconds (detFps: calls per second).
        public double preprocessTime;
        public double inferTime;
        public double postprocessTime;
        public double totalTime;
        public double detFps;

        /// <summary>
        /// Formats the timings of the most recent <see cref="Detect"/> call, one stage per line.
        /// </summary>
        /// <returns>Multi-line text with preprocess, infer, postprocess and total time in ms.</returns>
        public String DetectTime()
        {
            StringBuilder stringBuilder = new StringBuilder();
            stringBuilder.AppendLine($"Preprocess: {preprocessTime:F2}ms");
            stringBuilder.AppendLine($"Infer: {inferTime:F2}ms");
            stringBuilder.AppendLine($"Postprocess: {postprocessTime:F2}ms");
            stringBuilder.AppendLine($"Total: {totalTime:F2}ms");

            return stringBuilder.ToString();
        }

        /// <summary>
        /// Loads and compiles the model on the "AUTO" device and reads the class names.
        /// </summary>
        /// <param name="model_path">Path of the model file passed to OpenVINO.</param>
        /// <param name="classer_path">UTF-8 text file with one class name per line.</param>
        /// <exception cref="InvalidOperationException">The class-name file contains no names.</exception>
        public YoloV8(string model_path, string classer_path)
        {
            // NOTE(review): rawModel/pp/m/cm are left undisposed as in the original code;
            // whether the inference request keeps the compiled model alive on its own was not verified.
            Model rawModel = OVCore.Shared.ReadModel(model_path);
            PrePostProcessor pp = rawModel.CreatePrePostProcessor();
            Model m = pp.BuildModel();
            CompiledModel cm = OVCore.Shared.CompileModel(m, "AUTO");
            ir = cm.CreateInferRequest();

            // Trailing blank lines would inflate class_num and corrupt the decoding stride.
            string[] lines = File.ReadAllLines(classer_path, Encoding.UTF8);
            int count = lines.Length;
            while (count > 0 && string.IsNullOrWhiteSpace(lines[count - 1]))
            {
                count--;
            }
            Array.Resize(ref lines, count);
            if (count == 0)
            {
                throw new InvalidOperationException("Class name file '" + classer_path + "' contains no class names.");
            }

            class_names = lines;
            class_num = class_names.Length;

            input_height = 640;
            input_width = 640;

            box_num = 8400;

            conf_threshold = 0.5f;
            nms_threshold = 0.5f;

            detectionResults = new List<DetectionResult>();
        }

        /// <summary>
        /// Builds the network input: shrinks the image to fit the input size when it is larger
        /// (keeping aspect ratio), pads it with zeros at the bottom/right, scales pixel values
        /// by 1/255 and stores the extracted tensor data in <c>input_tensor_data</c>.
        /// </summary>
        /// <param name="image">Source image; it is not modified.</param>
        void Preprocess(Mat image)
        {
            int height = image.Rows;
            int width = image.Cols;

            using (Mat fitted = FitToInput(image, height, width))
            using (Mat input_img = new Mat())
            {
                ratio_height = (float)height / fitted.Rows;
                ratio_width = (float)width / fitted.Cols;

                Cv2.CopyMakeBorder(fitted, input_img, 0, input_height - fitted.Rows, 0, input_width - fitted.Cols, BorderTypes.Constant, 0);

                input_img.ConvertTo(input_img, MatType.CV_32FC3, 1.0 / 255);

                // NOTE(review): no BGR->RGB conversion happens here; confirm Common.ExtractMat handles channel order.
                input_tensor_data = Common.ExtractMat(input_img);
            }
        }

        /// <summary>
        /// Returns a new Mat that fits inside the network input: a down-scaled copy when the
        /// image is larger than the input in either dimension, otherwise a plain clone.
        /// </summary>
        Mat FitToInput(Mat image, int height, int width)
        {
            if (height <= input_height && width <= input_width)
            {
                return image.Clone();
            }

            float scale = Math.Min((float)input_height / height, (float)input_width / width);
            // Clamp to 1 so extreme aspect ratios cannot produce a zero-sized target.
            OpenCvSharp.Size new_size = new OpenCvSharp.Size(
                Math.Max(1, (int)(width * scale)),
                Math.Max(1, (int)(height * scale)));

            Mat resized = new Mat();
            try
            {
                Cv2.Resize(image, resized, new_size);
                return resized;
            }
            catch
            {
                resized.Dispose();
                throw;
            }
        }

        /// <summary>
        /// Decodes the raw network output into detections: picks the best class per box,
        /// drops boxes at or below the confidence threshold, maps boxes back to source-image
        /// coordinates and applies non-maximum suppression. The result is stored in a new
        /// list, so lists returned by earlier calls are never modified.
        /// </summary>
        /// <param name="outputData">Raw output with (class_num + 4) rows of box_num values.</param>
        /// <exception cref="InvalidOperationException">The output size does not match the class count.</exception>
        void Postprocess(float[] outputData)
        {
            int stride = class_num + 4;
            if (outputData.Length != stride * box_num)
            {
                throw new InvalidOperationException(
                    "Model output has " + outputData.Length + " values but " + (stride * box_num) +
                    " were expected for " + class_num + " classes; check that the label file matches the model.");
            }

            float[] data = Common.Transpose(outputData, stride, box_num);

            List<DetectionResult> candidates = new List<DetectionResult>();

            for (int i = 0; i < box_num; i++)
            {
                int offset = i * stride;

                // First index holding the highest class score (same tie rule as Max + IndexOf).
                int bestClass = 0;
                float bestScore = data[offset + 4];
                for (int c = 1; c < class_num; c++)
                {
                    float s = data[offset + 4 + c];
                    if (s > bestScore)
                    {
                        bestScore = s;
                        bestClass = c;
                    }
                }

                // NMSBoxes only keeps scores strictly above the threshold, so boxes at or
                // below it can be skipped here without changing the result.
                if (!(bestScore > conf_threshold))
                {
                    continue;
                }

                int _centerX = (int)(data[offset] * ratio_width);
                int _centerY = (int)(data[offset + 1] * ratio_height);
                int _width = (int)(data[offset + 2] * ratio_width);
                int _height = (int)(data[offset + 3] * ratio_height);

                candidates.Add(new DetectionResult(
                    bestClass,
                    class_names[bestClass],
                    new Rect(_centerX - _width / 2, _centerY - _height / 2, _width, _height),
                    bestScore));
            }

            if (candidates.Count == 0)
            {
                detectionResults = candidates;
                return;
            }

            CvDnn.NMSBoxes(candidates.Select(x => x.Rect), candidates.Select(x => x.Confidence), conf_threshold, nms_threshold, out int[] indices);

            // Set lookup keeps the filter linear; results stay in box order as before.
            HashSet<int> kept = new HashSet<int>(indices);
            detectionResults = candidates.Where((x, index) => kept.Contains(index)).ToList();
        }

        /// <summary>
        /// Runs the full pipeline (preprocess, inference, postprocess) on one image and
        /// updates the public timing fields.
        /// </summary>
        /// <param name="image">Non-empty source image.</param>
        /// <returns>Detections in source-image coordinates.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="image"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="image"/> is empty.</exception>
        internal List<DetectionResult> Detect(Mat image)
        {
            if (image == null)
            {
                throw new ArgumentNullException(nameof(image));
            }
            if (image.Empty())
            {
                throw new ArgumentException("Input image is empty.", nameof(image));
            }

            Stopwatch total = Stopwatch.StartNew();
            Stopwatch stage = Stopwatch.StartNew();

            Preprocess(image);

            preprocessTime = stage.Elapsed.TotalMilliseconds;
            stage.Restart();

            // The tensor wraps input_tensor_data, so it must stay alive until Run() has finished.
            using (Tensor input_x = Tensor.FromArray(input_tensor_data, new Shape(1, 3, input_height, input_width)))
            {
                ir.Inputs[0] = input_x;
                ir.Run();
            }

            inferTime = stage.Elapsed.TotalMilliseconds;
            stage.Restart();

            using (Tensor output = ir.Outputs[0])
            {
                outputData = output.GetData<float>().ToArray();
            }

            Postprocess(outputData);

            postprocessTime = stage.Elapsed.TotalMilliseconds;
            stage.Stop();

            totalTime = preprocessTime + inferTime + postprocessTime;

            total.Stop();
            double elapsedSeconds = total.Elapsed.TotalSeconds;
            detFps = elapsedSeconds > 0 ? 1 / elapsedSeconds : 0;

            return detectionResults;
        }
    }
}
ByteTracker.cs
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Text;
- using System.Threading.Tasks;
-
- namespace ByteTrack
- {
- public class ByteTracker // Multi-object tracker: each Update call associates one frame's detections with the tracks kept from earlier frames.
- {
- readonly float _trackThresh; // detections scoring at or above this join the first association; lower ones are kept for the second
- readonly float _highThresh; // an unmatched detection must score at least this to start a new track
- readonly float _matchThresh; // threshold handed to LinearAssignment for the first association
- readonly int _maxTimeLost; // a lost track is removed once more than this many frames have passed since its FrameId
- 
- int _frameId = 0; // incremented at the start of every Update call
- int _trackIdCount = 0; // last track id handed out; incremented before each new track is activated
- 
- readonly List<Track> _trackedTracks = new List<Track>(100); // rebuilt at the end of every Update
- readonly List<Track> _lostTracks = new List<Track>(100); // rebuilt at the end of every Update
- List<Track> _removedTracks = new List<Track>(100); // every track ever removed; only ever added to
- // Stores the thresholds; _maxTimeLost is trackBuffer scaled by frameRate / 30 and truncated to int.
- public ByteTracker(int frameRate = 30, int trackBuffer = 30, float trackThresh = 0.5f, float highThresh = 0.6f, float matchThresh = 0.8f)
- {
- _trackThresh = trackThresh;
- _highThresh = highThresh;
- _matchThresh = matchThresh;
- _maxTimeLost = (int)(frameRate / 30.0 * trackBuffer);
- }
- 
- /// <summary>
- /// Runs one tracking step: associates this frame's detections with existing tracks in three passes (high-score detections, low-score detections, unconfirmed tracks), starts new tracks, and rebuilds the tracked and lost pools.
- /// </summary>
- /// <param name="tracks">Detections of the current frame, already wrapped as Track objects; they are split by Score against the track threshold.</param>
- /// <returns>A new array with the entries of the tracked pool whose IsActivated flag is set after this step.</returns>
- public IList<Track> Update(List<Track> tracks)
- {
- #region Step 1: Get detections
- _frameId++;
- 
- // Create new Tracks using the result of object detection
- List<Track> detTracks = new List<Track>();
- List<Track> detLowTracks = new List<Track>();
- // Split detections by score: >= _trackThresh go to detTracks, the rest to detLowTracks.
- foreach (var obj in tracks)
- {
- if (obj.Score >= _trackThresh)
- {
- detTracks.Add(obj);
- }
- else
- {
- detLowTracks.Add(obj);
- }
- }
- 
- // Create lists of existing STrack
- List<Track> activeTracks = new List<Track>();
- List<Track> nonActiveTracks = new List<Track>();
- // Split the tracked pool by the IsActivated flag.
- foreach (var trackedTrack in _trackedTracks)
- {
- if (!trackedTrack.IsActivated)
- {
- nonActiveTracks.Add(trackedTrack);
- }
- else
- {
- activeTracks.Add(trackedTrack);
- }
- }
- // Candidate pool for the first association: activated tracks plus lost tracks (Union drops duplicates).
- var trackPool = activeTracks.Union(_lostTracks).ToArray();
- 
- // Predict current pose by KF
- foreach (var track in trackPool)
- {
- track.Predict();
- }
- #endregion
- 
- #region Step 2: First association, with IoU
- List<Track> currentTrackedTracks = new List<Track>();
- Track[] remainTrackedTracks;
- Track[] remainDetTracks;
- List<Track> refindTracks = new List<Track>();
- {
- var dists = CalcIouDistance(trackPool, detTracks);
- LinearAssignment(dists, trackPool.Length, detTracks.Count, _matchThresh,
- out var matchesIdx,
- out var unmatchTrackIdx,
- out var unmatchDetectionIdx);
- // A matched track in the Tracked state is updated in place; a track in any other state is re-activated.
- foreach (var matchIdx in matchesIdx)
- {
- var track = trackPool[matchIdx[0]];
- var det = detTracks[matchIdx[1]];
- if (track.State == TrackState.Tracked)
- {
- track.Update(det, _frameId);
- currentTrackedTracks.Add(track);
- }
- else
- {
- track.ReActivate(det, _frameId);
- refindTracks.Add(track);
- }
- }
- // Carry over the unmatched detections, and only those unmatched pool tracks that are still in the Tracked state.
- remainDetTracks = unmatchDetectionIdx.Select(unmatchIdx => detTracks[unmatchIdx]).ToArray();
- remainTrackedTracks = unmatchTrackIdx
- .Where(unmatchIdx => trackPool[unmatchIdx].State == TrackState.Tracked)
- .Select(unmatchIdx => trackPool[unmatchIdx])
- .ToArray();
- }
- #endregion
- 
- #region Step 3: Second association, using low score dets
- List<Track> currentLostTracks = new List<Track>();
- {
- var dists = CalcIouDistance(remainTrackedTracks, detLowTracks);
- LinearAssignment(dists, remainTrackedTracks.Length, detLowTracks.Count, 0.5f,
- out var matchesIdx,
- out var unmatchTrackIdx,
- out var unmatchDetectionIdx);
- // Same update / re-activate rule as in the first association.
- foreach (var matchIdx in matchesIdx)
- {
- var track = remainTrackedTracks[matchIdx[0]];
- var det = detLowTracks[matchIdx[1]];
- if (track.State == TrackState.Tracked)
- {
- track.Update(det, _frameId);
- currentTrackedTracks.Add(track);
- }
- else
- {
- track.ReActivate(det, _frameId);
- refindTracks.Add(track);
- }
- }
- // Tracks still unmatched after both associations are marked as lost.
- foreach (var unmatchTrack in unmatchTrackIdx)
- {
- var track = remainTrackedTracks[unmatchTrack];
- if (track.State != TrackState.Lost)
- {
- track.MarkAsLost();
- currentLostTracks.Add(track);
- }
- }
- }
- #endregion
- 
- #region Step 4: Init new tracks
- List<Track> currentRemovedTracks = new List<Track>();
- {
- // Deal with unconfirmed tracks, usually tracks with only one beginning frame
- var dists = CalcIouDistance(nonActiveTracks, remainDetTracks);
- LinearAssignment(dists, nonActiveTracks.Count, remainDetTracks.Length, 0.7f,
- out var matchesIdx,
- out var unmatchUnconfirmedIdx,
- out var unmatchDetectionIdx);
- 
- foreach (var matchIdx in matchesIdx)
- {
- nonActiveTracks[matchIdx[0]].Update(remainDetTracks[matchIdx[1]], _frameId);
- currentTrackedTracks.Add(nonActiveTracks[matchIdx[0]]);
- }
- // Unconfirmed tracks without a matching detection are removed.
- foreach (var unmatchIdx in unmatchUnconfirmedIdx)
- {
- var track = nonActiveTracks[unmatchIdx];
- track.MarkAsRemoved();
- currentRemovedTracks.Add(track);
- }
- 
- // Add new stracks
- foreach (var unmatchIdx in unmatchDetectionIdx)
- {
- var track = remainDetTracks[unmatchIdx];
- if (track.Score < _highThresh)
- continue;
- // Allocate the next track id and start a new track from this detection.
- _trackIdCount++;
- track.Activate(_frameId, _trackIdCount);
- currentTrackedTracks.Add(track);
- }
- }
- #endregion
- 
- #region Step 5: Update state
- foreach (var lostTrack in _lostTracks)
- {
- if (_frameId - lostTrack.FrameId > _maxTimeLost)
- {
- lostTrack.MarkAsRemoved();
- currentRemovedTracks.Add(lostTrack);
- }
- }
- // NOTE(review): Except(_removedTracks) below runs before this frame's removals are merged in, and _removedTracks is never pruned, so it grows for the lifetime of the tracker.
- var trackedTracks = currentTrackedTracks.Union(refindTracks).ToArray();
- var lostTracks = _lostTracks.Except(trackedTracks).Union(currentLostTracks).Except(_removedTracks).ToArray();
- _removedTracks = _removedTracks.Union(currentRemovedTracks).ToList();
- RemoveDuplicateStracks(trackedTracks, lostTracks);
- #endregion
- 
- return _trackedTracks.Where(track => track.IsActivated).ToArray();
- }
- 
- /// <summary>
- /// Rebuilds _trackedTracks and _lostTracks from the two candidate lists, dropping one track of every pair whose IoU distance is below 0.15.
- /// </summary>
- /// <param name="aTracks">Candidate tracked tracks; the survivors become _trackedTracks.</param>
- /// <param name="bTracks">Candidate lost tracks; the survivors become _lostTracks.</param>
- /// <remarks>Within a pair the track with the larger FrameId - StartFrameId is kept;
- /// on a tie the entry from <paramref name="bTracks"/> is kept.</remarks>
- void RemoveDuplicateStracks(IList<Track> aTracks, IList<Track> bTracks)
- {
- _trackedTracks.Clear();
- _lostTracks.Clear();
- 
- List<(int, int)> overlappingCombinations;
- var ious = CalcIouDistance(aTracks, bTracks);
- // A null matrix means one of the lists is empty, so there is nothing to compare.
- if (ious is null)
- overlappingCombinations = new List<(int, int)>();
- else
- {
- var rows = ious.GetLength(0);
- var cols = ious.GetLength(1);
- overlappingCombinations = new List<(int, int)>(rows * cols / 2);
- for (var i = 0; i < rows; i++)
- for (var j = 0; j < cols; j++)
- if (ious[i, j] < 0.15f)
- overlappingCombinations.Add((i, j));
- }
- // Per-list flags marking the entries that are dropped as duplicates.
- var aOverlapping = aTracks.Select(x => false).ToArray();
- var bOverlapping = bTracks.Select(x => false).ToArray();
- 
- foreach (var (aIdx, bIdx) in overlappingCombinations)
- {
- var timep = aTracks[aIdx].FrameId - aTracks[aIdx].StartFrameId;
- var timeq = bTracks[bIdx].FrameId - bTracks[bIdx].StartFrameId;
- if (timep > timeq)
- bOverlapping[bIdx] = true;
- else
- aOverlapping[aIdx] = true;
- }
- 
- for (var ai = 0; ai < aTracks.Count; ai++)
- if (!aOverlapping[ai])
- _trackedTracks.Add(aTracks[ai]);
- 
- for (var bi = 0; bi < bTracks.Count; bi++)
- if (!bOverlapping[bi])
- _lostTracks.Add(bTracks[bi]);
- }
- 
- /// <summary>
- /// Solves the assignment between rows and columns of a cost matrix by calling Lapjv.Exec, and reports matched pairs plus the unmatched rows and columns.
- /// </summary>
- /// <param name="costMatrix">Cost matrix, or null when either side is empty; with null nothing is matched and every index is reported as unmatched.</param>
- /// <param name="costMatrixSize">Number of rows; only used to enumerate unmatched rows when the matrix is null.</param>
- /// <param name="costMatrixSizeSize">Number of columns; only used to enumerate unmatched columns when the matrix is null.</param>
- /// <param name="thresh">Threshold forwarded unchanged to Lapjv.Exec.</param>
- /// <param name="matches">Receives the matched pairs as { row index, column index }.</param>
- /// <param name="aUnmatched">Receives the row indices without a match.</param>
- /// <param name="bUnmatched">Receives the column indices without a match.</param>
- void LinearAssignment(float[,] costMatrix, int costMatrixSize, int costMatrixSizeSize, float thresh, out IList<int[]> matches, out IList<int> aUnmatched, out IList<int> bUnmatched)
- {
- matches = new List<int[]>();
- if (costMatrix is null)
- {
- aUnmatched = Enumerable.Range(0, costMatrixSize).ToArray();
- bUnmatched = Enumerable.Range(0, costMatrixSizeSize).ToArray();
- return;
- }
- 
- bUnmatched = new List<int>();
- aUnmatched = new List<int>();
- // Lapjv.Exec is defined elsewhere; below, a negative entry in rowsol / colsol is treated as "unassigned".
- var (rowsol, colsol) = Lapjv.Exec(costMatrix, true, thresh);
- 
- for (var i = 0; i < rowsol.Length; i++)
- {
- if (rowsol[i] >= 0)
- matches.Add(new int[] { i, rowsol[i] });
- else
- aUnmatched.Add(i);
- }
- 
- for (var i = 0; i < colsol.Length; i++)
- if (colsol[i] < 0)
- bUnmatched.Add(i);
- }
- 
- /// <summary>
- /// Computes the pairwise IoU matrix of two rectangle lists through RectBox.CalcIoU.
- /// </summary>
- /// <param name="aRects">Rectangles indexing the first dimension of the result.</param>
- /// <param name="bRects">Rectangles indexing the second dimension of the result.</param>
- /// <returns>Matrix of size [aRects.Count, bRects.Count], or null when either list is empty.</returns>
- static float[,] CalcIous(IList<RectBox> aRects, IList<RectBox> bRects)
- {
- if (aRects.Count * bRects.Count == 0) return null;
- 
- var ious = new float[aRects.Count, bRects.Count];
- for (var bi = 0; bi < bRects.Count; bi++)
- for (var ai = 0; ai < aRects.Count; ai++)
- ious[ai, bi] = bRects[bi].CalcIoU(aRects[ai]);
- 
- return ious;
- }
- 
- /// <summary>
- /// Computes the pairwise IoU distance (1 - IoU) between the RectBox of every track in the two sequences.
- /// </summary>
- /// <param name="aTtracks">Tracks indexing the first dimension of the result.</param>
- /// <param name="bTracks">Tracks indexing the second dimension of the result.</param>
- /// <returns>Distance matrix, or null when either sequence is empty.</returns>
- static float[,] CalcIouDistance(IEnumerable<Track> aTtracks, IEnumerable<Track> bTracks)
- {
- var aRects = aTtracks.Select(x => x.RectBox).ToArray();
- var bRects = bTracks.Select(x => x.RectBox).ToArray();
- 
- var ious = CalcIous(aRects, bRects);
- if (ious is null) return null;
- 
- var rows = ious.GetLength(0);
- var cols = ious.GetLength(1);
- var matrix = new float[rows, cols];
- for (var i = 0; i < rows; i++)
- for (var j = 0; j < cols; j++)
- matrix[i, j] = 1 - ious[i, j];
- 
- return matrix;
- }
- }
- }
下载
参考


评论记录:
回复评论: