1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
4.3 划分数据集
本文将数据集按照8:1:1的比例划分为训练集、验证集和测试集。
import os
import shutil
import random
from tqdm import tqdm
def split_img(img_path, label_path, split_list, out_dir="E:/data/dataset"):
    """Randomly split an image/label dataset into train, val and test subsets.

    Every entry of ``img_path`` is paired with the label path returned by
    ``toLabelPath`` and both files are copied into
    ``<out_dir>/images/<subset>`` and ``<out_dir>/labels/<subset>``.

    Args:
        img_path: Directory whose entries are the images to split.
        label_path: Directory that holds the label files.
        split_list: Three ratios ``(train, val, test)``. ``train`` is the
            fraction of all images; the rest is divided between val and test
            in the proportion ``val : test``.
        out_dir: Root of the generated dataset tree. Defaults to the location
            that used to be hard-coded in this function.

    Raises:
        Exception: Any error raised while creating the output directories is
            reported and re-raised instead of being swallowed.
    """
    subsets = ("train", "val", "test")
    img_dirs = {name: os.path.join(out_dir, "images", name) for name in subsets}
    label_dirs = {name: os.path.join(out_dir, "labels", name) for name in subsets}

    try:
        for name in subsets:
            os.makedirs(img_dirs[name], exist_ok=True)
            os.makedirs(label_dirs[name], exist_ok=True)
    except Exception as e:
        print(f"Error creating directories: {e} ")
        # Continuing without the target directories would make shutil.copy
        # write to a *file* named like the missing directory, so stop here.
        raise

    train, val, test = split_list
    all_img_path = [os.path.join(img_path, img) for img in os.listdir(img_path)]

    # Sample in the same order as before (train first, then val from the
    # remainder) so a seeded run still produces the same split.
    train_img = random.sample(all_img_path, int(train * len(all_img_path)))
    taken = set(train_img)
    remaining = [path for path in all_img_path if path not in taken]

    # Guard against val == test == 0 (e.g. a 1:0:0 split).
    val_share = val / (val + test) if (val + test) > 0 else 0
    val_img = random.sample(remaining, int(val_share * len(remaining)))
    taken = set(val_img)
    test_img = [path for path in remaining if path not in taken]

    plan = (
        ("train", 'train ', train_img),
        ("val", 'val ', val_img),
        ("test", 'test ', test_img),
    )
    for name, desc, images in plan:
        for img in tqdm(images, desc=desc, ncols=80, unit='img'):
            _copy(img, img_dirs[name])
            _copy(toLabelPath(img, label_path), label_dirs[name])
def _copy ( from_path, to_path) :
shutil. copy( from_path, to_path)
def toLabelPath(img_path, label_path):
    """Return the label file path that corresponds to an image path.

    The image's base name has its final extension replaced by ``.txt`` and is
    joined onto ``label_path``.

    Args:
        img_path: Path of the image file.
        label_path: Directory that contains the label files.

    Returns:
        ``label_path`` joined with ``<image stem>.txt``.
    """
    # splitext strips only the last extension, whatever it is; splitting on
    # the literal '.jpg' broke on .png/.JPG files and on names that contain
    # '.jpg' in the middle.
    stem, _ext = os.path.splitext(os.path.basename(img_path))
    return os.path.join(label_path, stem + '.txt')
def main():
    """Split E:/data/images with labels from E:/data/txt_label using an 8:1:1 ratio."""
    ratios = [0.8, 0.1, 0.1]  # train, val, test
    split_img("E:/data/images", "E:/data/txt_label", ratios)


if __name__ == '__main__':
    main()
class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}">
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
4.4 配置文件
在dataset文件夹下新建一个data.yaml文件(可以自定义命名),用来存放训练集和验证集的路径以及目标类别。
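data.yaml内容如下(以下仅为示例,路径与上文划分脚本生成的目录一致,类别名称请按自己的数据集填写):
path: E:/data/dataset
train: images/train
val: images/val
test: images/test
names:
  0: class0
  1: class1
至此,自定义数据集已创建完毕,接下来就是训练模型了。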
5. 模型训练
5.1 模型下载
在YOLOv8的GitHub开源网址上下载源代码以及对应版本的模型。
也可以在终端输入以下命令下载代码:git clone https://github.com/ultralytics/ultralytics.git
5.2 模型训练
一种方式是使用命令进行训练:
# Build the model from yolov8s.yaml with pretrained=False.
yolo detect train data=E:/data/dataset/data.yaml model=./yolov8s.yaml batch=32 epochs=150 imgsz=640 workers=12 device=0 pretrained=False
# Start from the yolov8s.pt weights file.
yolo detect train data=E:/data/dataset/data.yaml model=./yolov8s.pt batch=32 epochs=150 imgsz=640 workers=12 device=0
# Build the model from mymodel.yaml and pass yolov8s.pt as the pretrained weights.
yolo detect train data=E:/data/dataset/data.yaml model=./mymodel.yaml batch=32 epochs=150 imgsz=640 workers=12 device=0 pretrained=yolov8s.pt
class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}">
另一种方式Python实现:
import os
import torch
from ultralytics import YOLO
def main():
    """Load the 'yolov8s.pt' weights and train on E:/data/dataset/data.yaml.

    Uses CUDA when ``torch.cuda.is_available()`` reports it, otherwise the CPU.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(f"Using device: {device} ")

    model = YOLO('yolov8s.pt')
    model.train(
        data='E:/data/dataset/data.yaml',
        epochs=150,
        imgsz=640,
        batch=32,
        workers=12,
        device=device,
    )


if __name__ == "__main__":
    main()
class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}">
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
6. 模型验证
对验证集或测试集进行验证,命令如下:
# Evaluate best.pt on the validation split.
yolo task=detect mode=val split=val model=./runs/detect/train/weights/best.pt data=E:/data/dataset/data.yaml batch=1 workers=0 imgsz=640 device=0
# Evaluate best.pt on the test split.
yolo task=detect mode=val split=test model=./runs/detect/train/weights/best.pt data=E:/data/dataset/data.yaml batch=1 workers=0 imgsz=640 device=0
class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}">
或
from ultralytics import YOLO
# Evaluate the trained weights on the dataset's test split and report mAP@50.
model_path = r'.\runs\detect\train\weights\best.pt'
data_path = r'E:\data\dataset\data.yaml'
model = YOLO(model_path)
results = model.val(data=data_path, split='test', batch=1, workers=0, imgsz=640, device='0')
print(results)
# Detection metric keys in results_dict carry a "(B)" suffix
# ("metrics/mAP50(B)"), so the bare 'metrics/mAP50' lookup raises KeyError.
# The box metrics object exposes the same value directly.
mAP50 = results.box.map50
print(f"mAP@50: {mAP50} ")
class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}">
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
7. 模型预测
# Run detection on a single image and save the annotated result.
yolo task=detect mode=predict model=./runs/detect/train/weights/best.pt source=E:/data/1.jpg save=True device=0
class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}">
或
from ultralytics import YOLO
import os
# Run detection on one image, print the predicted classes and boxes, and
# report where the annotated image was written.
model_path = r'.\runs\detect\train\weights\best.pt'
image_path = r'E:\data\1.jpg'
device = '0'
model = YOLO(model_path)
results = model.predict(source=image_path, save=True, device=device)
for r in results:
    print(f"类别: {r.boxes.cls} ")
    print(f"边界框坐标: {r.boxes.xyxy} ")
# Ultralytics increments the run directory on repeated runs (predict,
# predict2, ...), so read the real directory from the result when it is
# available; fall back to the first-run default otherwise.
run_dir = getattr(results[0], 'save_dir', None) if results else None
save_dir = run_dir or os.path.join('runs', 'detect', 'predict')
save_path = os.path.join(save_dir, os.path.basename(image_path))
print(f"带有预测标注的图像已保存至: {save_path} ")
class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}">
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
8. 模型导出
# Export the trained weights to ONNX.
yolo task=detect mode=export model=./runs/detect/train/weights/best.pt format=onnx
class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}">
或
from ultralytics import YOLO
# Load the trained weights and export them in ONNX format.
weights_path = './runs/detect/train/weights/best.pt'
model = YOLO(weights_path)
model.export(format='onnx')
class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}">
评论记录:
回复评论: