搭建YOLOv10环境 训练+推理+模型评估

搭建YOLOv10环境 训练+推理+模型评估

码农世界 2024-05-30 前端 83 次浏览 0个评论

文章目录

  • 前言
  • 一、环境搭建
    • 必要环境
    • 1. 创建yolov10虚拟环境
    • 2. 下载pytorch (pytorch版本>=1.8)
    • 3. 下载YOLOv10源码
    • 4. 安装所需要的依赖包
    • 二、推理测试
      • 1. 将如下代码复制到ultralytics文件夹同级目录下并运行 即可得到推理结果
      • 2. 关键参数
      • 三、训练及评估
        • 1. 数据结构介绍
        • 2. 配置文件修改
        • 3. 训练/评估模型
        • 4. 关键参数
        • 5. 单独对训练好的模型将进行评估
        • 总结

          前言

          本文将详细介绍跑通YOLOv10的流程,并给各位提供用于训练、评估和模型推理的脚本

          一、环境搭建

          必要环境

          本文使用Windows10+Python3.8+CUDA10.2+CUDNN8.0.4作为基础环境,使用30系或40系显卡的小伙伴请安装11.0以上版本的CUDA

          1. 创建yolov10虚拟环境

          conda create -n yolov10 python=3.8
          

          2. 下载pytorch (pytorch版本>=1.8)

          pip install torch==1.9.1+cu102 torchvision==0.10.1+cu102 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
          

          若使用的是AMD显卡或不使用GPU的同学 可以通过以下命令可以安装CPU版本

          pip install torch==1.9.1+cpu torchvision==0.10.1+cpu torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
          

          3. 下载YOLOv10源码

          地址:https://github.com/THU-MIG/yolov10

          4. 安装所需要的依赖包

          pip install -r requirements.txt
          

          二、推理测试

          1. 将如下代码复制到ultralytics文件夹同级目录下并运行 即可得到推理结果

          import cv2
          from ultralytics import YOLOv10
          import os
          import argparse
          import time
          import torch
          parser = argparse.ArgumentParser()
          # 检测参数
          parser.add_argument('--weights', default=r"yolov10n.pt", type=str, help='weights path')
          parser.add_argument('--source', default=r"images", type=str, help='img or video(.mp4)path')
          parser.add_argument('--save', default=r"./save", type=str, help='save img or video path')
          parser.add_argument('--vis', default=True, action='store_true', help='visualize image')
          parser.add_argument('--conf_thre', type=float, default=0.5, help='conf_thre')
          parser.add_argument('--iou_thre', type=float, default=0.5, help='iou_thre')
          opt = parser.parse_args()
          device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
          def get_color(idx):
              idx = idx * 3
              color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
              return color
          class Detector(object):
              def __init__(self, weight_path, conf_threshold=0.5, iou_threshold=0.5):
                  self.device = device
                  self.model = YOLOv10(weight_path)
                  self.conf_threshold = conf_threshold
                  self.iou_threshold = iou_threshold
                  self.names = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train',
                                7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign',
                                12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep',
                                19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella',
                                26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard',
                                32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard',
                                37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork',
                                43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange',
                                50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair',
                                57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv',
                                63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave',
                                69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase',
                                76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}
              def detect_image(self, img_bgr):
                  results = self.model(img_bgr, verbose=True, conf=self.conf_threshold,
                                       iou=self.iou_threshold, device=self.device)
                  bboxes_cls = results[0].boxes.cls
                  bboxes_conf = results[0].boxes.conf
                  bboxes_xyxy = results[0].boxes.xyxy.cpu().numpy().astype('uint32')
                  for idx in range(len(bboxes_cls)):
                      box_cls = int(bboxes_cls[idx])
                      bbox_xyxy = bboxes_xyxy[idx]
                      bbox_label = self.names[box_cls]
                      box_conf = f"{bboxes_conf[idx]:.2f}"
                      xmax, ymax, xmin, ymin = bbox_xyxy[2], bbox_xyxy[3], bbox_xyxy[0], bbox_xyxy[1]
                      img_bgr = cv2.rectangle(img_bgr, (xmin, ymin), (xmax, ymax), get_color(box_cls + 3), 2)
                      cv2.putText(img_bgr, f'{str(bbox_label)}/{str(box_conf)}', (xmin, ymin - 10),
                                  cv2.FONT_HERSHEY_SIMPLEX, 0.5, get_color(box_cls + 3), 2)
                  return img_bgr
          # Example usage
          if __name__ == '__main__':
              model = Detector(weight_path=opt.weights, conf_threshold=opt.conf_thre, iou_threshold=opt.iou_thre)
              images_format = ['.png', '.jpg', '.jpeg', '.JPG', '.PNG', '.JPEG']
              video_format = ['mov', 'MOV', 'mp4', 'MP4']
              if os.path.join(opt.source).split(".")[-1] not in video_format:
                  image_names = [name for name in os.listdir(opt.source) for item in images_format if
                                 os.path.splitext(name)[1] == item]
                  for img_name in image_names:
                      img_path = os.path.join(opt.source, img_name)
                      img_ori = cv2.imread(img_path)
                      img_vis = model.detect_image(img_ori)
                      img_vis = cv2.resize(img_vis, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_NEAREST)
                      cv2.imwrite(os.path.join(opt.save, img_name), img_vis)
                      if opt.vis:
                          cv2.imshow(img_name, img_vis)
                          cv2.waitKey(0)
                          cv2.destroyAllWindows()
              else:
                  capture = cv2.VideoCapture(opt.source)
                  fps = capture.get(cv2.CAP_PROP_FPS)
                  size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)),
                          int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
                  fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
                  outVideo = cv2.VideoWriter(os.path.join(opt.save, os.path.basename(opt.source).split('.')[-2] + "_out.mp4"),
                                             fourcc,
                                             fps, size)
                  while True:
                      ret, frame = capture.read()
                      if not ret:
                          break
                      start_frame_time = time.perf_counter()
                      img_vis = model.detect_image(frame)
                      # 结束计时
                      end_frame_time = time.perf_counter()  # 使用perf_counter进行时间记录
                      # 计算每帧处理的FPS
                      elapsed_time = end_frame_time - start_frame_time
                      if elapsed_time == 0:
                          fps_estimation = 0.0
                      else:
                          fps_estimation = 1 / elapsed_time
                      h, w, c = img_vis.shape
                      cv2.putText(img_vis, f"FPS: {fps_estimation:.2f}", (10, 35), cv2.FONT_HERSHEY_SIMPLEX, 1.3, (0, 0, 255), 2)
                      outVideo.write(img_vis)
                      cv2.imshow('detect', img_vis)
                      cv2.waitKey(1)
                  capture.release()
                  outVideo.release()
          

          2. 关键参数

          1. 测试图片:–source 变量后填写图像文件夹路径 如:default=r"images"

          2. 测试视频:–source 变量后填写视频路径 如:default=r"video.mp4"

          推理图像效果:

          推理视频效果:

          三、训练及评估

          1. 数据结构介绍

          这里使用的数据集是VOC2007,用留出法将数据按9:1的比例划分成了训练集和验证集

          下载地址如下:

          链接:https://pan.baidu.com/s/1FmbShVF1SQOZfjncj3OKJA?pwd=i7od

          提取码:i7od

          2. 配置文件修改

          3. 训练/评估模型

          将如下代码复制到ultralytics文件夹同级目录下并运行 即可开始训练

          # -*- coding:utf-8 -*-
          from ultralytics import YOLOv10
          import argparse
          # 解析命令行参数
          parser = argparse.ArgumentParser(description='Train or validate YOLO model.')
          # train用于训练原始模型  val 用于得到精度指标
          parser.add_argument('--mode', type=str, default='train', help='Mode of operation.')
          # 预训练模型
          parser.add_argument('--weights', type=str, default='yolov10n.pt', help='Path to model file.')
          # 数据集存放路径
          parser.add_argument('--data', type=str, default='VOC2007/data.yaml', help='Path to data file.')
          parser.add_argument('--epoch', type=int, default=200, help='Number of epochs.')
          parser.add_argument('--batch', type=int, default=8, help='Batch size.')
          parser.add_argument('--workers', type=int, default=0, help='Number of workers.')
          parser.add_argument('--device', type=str, default='0', help='Device to use.')
          parser.add_argument('--name', type=str, default='', help='Name data file.')
          args = parser.parse_args()
          def train(model, data, epoch, batch, workers, device, name):
              model.train(data=data, epochs=epoch, batch=batch, workers=workers, device=device, name=name)
          def validate(model, data, batch, workers, device, name):
              model.val(data=data, batch=batch, workers=workers, device=device, name=name)
          def main():
              model = YOLOv10(args.weights)
              if args.mode == 'train':
                  train(model, args.data, args.epoch, args.batch, args.workers, args.device, args.name)
              else:
                  validate(model, args.data, args.batch, args.workers, args.device, args.name)
          if __name__ == '__main__':
              main()
          

          4. 关键参数

          1. 模式选择:

          –mode train: 开始训练模型

          –mode val: 进行模型验证

          2. 训练轮数: 通过 --epoch 参数设置训练轮数,默认为200轮。该参数控制模型在训练集上迭代的次数,增加轮数有助于提升模型性能,但同时也会增加训练时间。

          3. 训练批次: 通过 --batch 参数设置训练批次大小,一般设置为2的倍数,如8或16。批次大小决定了每次参数更新时使用的样本数量,较大的批次有助于加速收敛,但会增加显存占用,需根据实际显存大小进行调整

          4. 训练数据加载进程数: 通过 --workers 参数设置数据加载进程数,默认为8。该参数控制了在训练期间用于加载和预处理数据的进程数量。增加进程数可以加快数据的加载速度,linux系统下一般设置为8或16,windows系统设置为0。

          训练过程:

          训练结束后模型已经训练过程默认会保存到runs/detect/exp路径下

          5. 单独对训练好的模型将进行评估

          1. 将 --mode变量后改为val 如:default=“val”

          2. 将 --weights变量后改为要单独评估的模型路径 如:default=r"runs/detect/exp/weights/best.pt"

          评估过程:


          总结

          yolo是真卷呐,版本号一会儿一变的,v9还没看呢v10已经出来了…

          最近经常在b站上更新一些有关目标检测的视频,大家感兴趣可以来看看 https://b23.tv/1upjbcG

          学习交流群:995760755


转载请注明来自码农世界,本文标题:《搭建YOLOv10环境 训练+推理+模型评估》

百度分享代码,如果开启HTTPS请参考李洋个人博客
每一天,每一秒,你所做的决定都会改变你的人生!

发表评论

快捷回复:

评论列表 (暂无评论,83人围观)参与讨论

还没有评论,来说两句吧...

Top