yolov8训练自己的数据集遇到的问题

news/2024/4/25 0:38:04/

训练分类模型

1.如何更改模型的类别数nc

根据本地模型配置文件.yaml可以设置nc
在这里插入图片描述
但是,这里无法用到预训练模型.pt模型文件,预训练模型的权重参数是在大数据集上训练得到的,泛化性能可能比较好,所以,下载了官方的分类模型yolov8n-cls.pt

由于需要魔改yolov8,所以下载了官方源码,在default.yaml配置文件中是各种超参数,包括data与model路径的设置

将data路径设置为本地路径,或者调用数据配置文件coco128.yaml文件进行更改也可,这里有个问题,由于我也利用pip安装了yolov8的源码库,导致在利用coco128.yaml文件时,代码会在官方下载数据,导致报错,可能与这部分代码有关,目前暂未解释,直接定位本地文件夹更方便。

model路径设置为下载的yolov8n-cls.pt本地路径,或者模型配置文件yolov8n.yaml路径,这里选择前者。
问题来了,官方配置文件的nc是1000,由上图模型配置文件也可看出,先说方法,在task.py文件中找到attempt_load_one_weight这个函数,这个函数是用来下载.pt模型文件的,在train.py文件中的下面函数也可以看到

def setup_model(self):"""load/create/download model for any task"""# classification models require special handlingif isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup neededreturnmodel = str(self.model)# Load a YOLO model locally, from torchvision, or from Ultralytics assetsif model.endswith(".pt"):self.model, _ = attempt_load_one_weight(model, device='cpu')for p in self.model.parameters():p.requires_grad = True  # for trainingelif model.endswith(".yaml"):self.model = self.get_model(cfg=model)

解决办法:
attempt_load_one_weight这个函数中添加代码如下,

def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):# Loads a single model weightsckpt = torch_safe_load(weight)  # load ckptargs = {**DEFAULT_CFG_DICT, **ckpt['train_args']}  # combine model and default args, preferring model argsmodel = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model##########################change_nc####################################方法一:nc = 5ch = 256m = model.model[-1]  # last layer#ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels  # ch into modulec = Classify(ch, nc)  # Classify()c.i, c.f, c.type = m.i, m.f, 'models.common.Classify'  # index, from, typemodel.model[-1] = c  # replaceprint("############################################################")print(model)# print("model.layers:",len(model.layers))# for layer in model.layers[:-10]:#     layer.trainable = False#######################################################################

nc为自己数据的类别数,ch为模型最后一层的输入通道层数,这里由于模型没有layers参数,模型中卷积层没有in_channel参数,所以无法直接调用,所以咱不能进行冻结层训练,考虑到模型层数也不多,先暂且这样吧,ch可以将原.pt模型先转为onnx通过netron进行查看

这里要借鉴了ClassificationModel类下的_from_detection_model函数,根据onnx对比moudles.py文件中的Classify函数找到ch=256,因为这里的.pt模型文件中的最后一层为
![# YOLOv8.0n head
head:- [-1, 1, Classify, [nc]]  # Classify](https://img-blog.csdnimg.cn/9b83d6b87de54c26b044ed4eaac5b1aa.png)
class Classify(nn.Module):# YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2)def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groupssuper().__init__()c_ = 1280  # efficientnet_b0 sizeself.conv = Conv(c1, c_, k, s, autopad(k, p), g)self.pool = nn.AdaptiveAvgPool2d(1)  # to x(b,c_,1,1)self.drop = nn.Dropout(p=0.0, inplace=True)self.linear = nn.Linear(c_, c2)  # to x(b,c2)def forward(self, x):if isinstance(x, list):x = torch.cat(x, 1)x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))return x if self.training else x.softmax(1)

注意:在训练完之后需要注释掉添加的代码,否则在export,val,predict模型的时候,比如onnx会调用模型下载函数,导致模型破坏。

在进行测试的时候,竟然发现原来没有改动类别数训练之后的模型发现虽然类别数是1000.但是在进行测试的时候,以及转换为onnx之后进行测试发现结果没有变化,甚至高一个点,什么鬼,据我猜测,数据没有发生变化,我直接设置的是本地路径,所以以经是设置为5类了,现在就剩模型最后一层的类别数不一样,也就是说在训练的时候,无论你模型设置多少类,只要数据集类数确定了,那么最后训练的结果是一样的,就好像多余的类别数默认输出为0,但是前面的类别数输出结果不变。

2.如何将pt模型的三通道改为单通道

1.输入层的修改,将模型第一个卷积层的输入通道数从3改为1,权重改为单通道,具体代码如下:
原模型第一层的conv的输入通道数为3,权重通道为[:, :3, :, :]
attempt_load_one_weight这个函数中添加代码如下,

	#修改输入通道数,权重通道数model.model[0].conv.in_channels = 1model.model[0].conv.weight = torch.nn.Parameter(model.model[0].conv.weight[:, :1, :, :])#model.model[0].conv.weight.data = model.model[0].conv.weight.data[:, :1, :, :]

2.修改数据集图片的维度,将其由三通道rgb转为单通道灰度图
在dataset.py文件中修改def getitem(self, i)函数

def __getitem__(self, i):f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), imageif self.cache_ram and im is None:im = self.samples[i][3] = cv2.imread(f)elif self.cache_disk:if not fn.exists():  # load npynp.save(fn.as_posix(), cv2.imread(f))im = np.load(fn)else:  # read imageim = cv2.imread(f)  # BGRif self.album_transforms:sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]#源代码else:#修改为单通道import numpy as npimage=cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)sample = self.torch_transforms(image)#sample = self.torch_transforms(im)#原代码

然后,torch_transforms函数索引到augment.py函数中的classify_transforms函数,在这个函数中分别修改CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)函数

def classify_transforms(size=224):# Transforms to apply if albumentations not installedif not isinstance(size, int):raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")# T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])

CenterCrop(size), ToTensor()两个函数中分别添加im = np.array(im)[ :, :, np.newaxis]代码添加维度,否则会报错
T.Normalize(IMAGENET_MEAN, IMAGENET_STD)修改IMAGENET_MEAN, IMAGENET_STD参数的数值由三个通道改为单个通道数值

如果要进行验证集测试的话会遇到报错RuntimeError: Given groups=1, weight of size [16, 1, 3, 3], expected input[1, 3, 64, 64] to have 1 channels, but got 3 channels instead
将validator.py文件中__call__函数中

model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))
修改为
model.warmup(imgsz=(1 if pt else self.args.batch, 1, imgsz, imgsz))

http://www.ppmy.cn/news/47484.html

相关文章

“衰老标志物”重磅综述:细胞衰老、器官衰老、衰老时钟及其应用

大家好,这里是专注表观组学十余年,领跑多组学科研服务的易基因。 随着人口老龄化程度不断加深,实现“健康老龄化(healthy aging)”已成为我国乃至世界迫切需要解决的重大社会和科学问题。据测算,我国60岁及…

上市公司碳排放测算数据(1992-2022年)

根据《温室气体核算体系》,企业的碳排放可以分为三个范围。 范围一是直接温室气体排放,产生于企业拥有或控制的排放源,例如企业拥有或控制的锅炉、熔炉、车辆等产生的燃烧排放;拥有或控制的工艺设备进行化工生产所产生的排放。 范…

如何优化Python网络爬虫,提高爬取速度?

目录 一、提升爬虫的速度二、并发和并行三、同步和异步四、多线程爬虫五、简单单线程爬虫 多线程简单的多线程爬虫实例使用Queue的多线程爬虫多进程爬虫使用multiprocessing的多进程爬虫最后 一、提升爬虫的速度 爬虫可以从获取网页、解析网页、存储数据来实现一些基本的。现在…

基于html+css的图片展示11

准备项目 项目开发工具 Visual Studio Code 1.44.2 版本: 1.44.2 提交: ff915844119ce9485abfe8aa9076ec76b5300ddd 日期: 2020-04-16T16:36:23.138Z Electron: 7.1.11 Chrome: 78.0.3904.130 Node.js: 12.8.1 V8: 7.8.279.23-electron.0 OS: Windows_NT x64 10.0.19044 项目…

腾讯云4核8G轻量服务器12M支持多少访客同时在线?并发数怎么算?

腾讯云轻量4核8G12M轻量应用服务器支持多少人同时在线?通用型-4核8G-180G-2000G,2000GB月流量,系统盘为180GB SSD盘,12M公网带宽,下载速度峰值为1536KB/s,即1.5M/秒,假设网站内页平均大小为60KB…

成功上岸字节35K,技术4面+HR面,耗时20天,真是不容易

这次字节的面试,给我的感触很深,意识到基础的重要性。一共经历了五轮面试:技术4面+HR面。 下面看正文 本人自动专业毕业,压抑了五个多月,终于鼓起勇气,去字节面试,下面是我的面试过…

(排序11)排序的时间复杂度,空间复杂度,稳定性总结

图片总结 内排序时间复杂度总结 内部排序:数据元素全部放在内存中的排序。. 在内排序当中比较快的有希尔排序,堆排序,快速排序,归并排序,这四个排序的时间复杂度都是O(n*logn)。其中希尔排序的时间复杂度更加准确的来…

Windows实现在桌面上双击图标,自动进入到指定网址

功能实现步骤 创建一个快捷方式,右键点击桌面上的空白区域,选择“新建”->“快捷方式”。在弹出的“创建快捷方式”对话框中,输入你想要打开的网站的URL,例如 https://www.bing.com/?mktzh-cn&mktzh-CN ,然后…

轻松掌握Qt FTP 机制:实现高效文件传输

轻松掌握Qt FTP:实现高效文件传输 一、简介(Introduction)1.1 文件传输协议(FTP)Qt及其网络模块(Qt and its Network Module) QNetwork:二、QNetworkAccessManager上传实例(Qt FTP U…

Ubuntu中使用vscode+cmake进行编译调试

首先新建一个文件夹作为工作空间 mkdir test 进入工作空间文件夹&#xff0c;在vscode中打开 cd test code . 创建一个c文件 #include<iostream>using namespace std;int main(){int a 23;int b a3;for(int i 0; i<10; i){cout<<"hello vs code &a…

《程序员面试金典(第6版)》面试题 16.01. 交换数字(位运算符,异或性质)

题目描述 编写一个函数&#xff0c;不用临时变量&#xff0c;直接交换numbers [a, b]中a与b的值。 示例&#xff1a; 输入: numbers [1,2]输出: [2,1] 提示&#xff1a; numbers.length 2-2147483647 < numbers[i] < 2147483647 解题思路与代码 这道题不让使用额外…

linux_管道学习-pipe函数-管道的读写-fpathconf函数

接上一篇&#xff1a;linux_何为IPC-进程间常用的通信方式 今天来分享linux的管道学习&#xff0c;希望我的笔记能对大家有用&#xff0c;开始上菜&#xff1a; 目录 1.管道的概念&#xff1a;2.pipe函数3.管道的读写行为4.管道缓冲区大小5.管道的优劣 1.管道的概念&#xff1…

云通讯服务商有哪些?

随着语聊、视频通话、直播等行业的兴起&#xff0c;云通讯厂商的作用越来越凸显&#xff0c;解决画面卡顿、解决声音延迟以及基于互动领域更多的行业解决方案已经成为开发者和企业所需。 从长远来看&#xff0c;随着5G的不断普及&#xff0c;低延迟、高质量的网络环境不断催生线…

kafka--python

文章目录 1、kafka是什么2、docker上部署kafka3、在kafka容器内部署python&#xff0c;并跑通生产者-消费者简单代码4、最新接口4.1、kafka_config.py4.2、kafka_interface.py4.3、run.py4、测试 1、kafka是什么 Producer&#xff1a;即生产者&#xff0c;消息的产生者&#xf…

编译和引用so库

编译和引用so库 1.两种编译方式 ndk-build Android.mk Application.mkCMake CMakeList 2.Android.mk Application.mk (1)javac java文件的绝对路径 → 生成so库 (2)javah com.xxx.xxx.tesAdd → 生成头文件 (3) 修改头文件的后缀&#xff0c;并添加实现 (4)Applicat…

git教程

Git是目前最流行的分布式版本控制系统之一&#xff0c;它可以帮助开发者更好地管理代码和协作开发。以下是Git教程的一些内容&#xff1a; Git入门&#xff1a;介绍Git的基本概念、Git工作流程和Git常用命令。 Git分支&#xff1a;讲解Git分支的用法&#xff0c;包括新建分支、…

Flutter与Android开发:构建跨平台移动应用的新选择

Flutter与Android开发&#xff1a;构建跨平台移动应用的新选择 本文内容提纲如下&#xff1a; 介绍Flutter技术&#xff1a;Flutter是一种由Google推出的开源UI工具包&#xff0c;用于构建高性能、跨平台的移动应用。文章将介绍Flutter的基本概念、特点和优势&#xff0c;包括其…

Python面向对象详解(非常详细)

非常详细的讲解&#xff08;爆肝1w字&#xff09;&#x1f44f;&#x1f3fb;&#x1f44f;&#x1f3fb;&#x1f44f;&#x1f3fb; 零基础一样学得会&#x1f44c;&#x1f3fb; 干货满满不看后悔&#x1f44d;&#x1f44d;&#x1f44d; &#x1f4dd;个人主页→数据…

可能你已经刷了很多01背包的题,但是真的对01背包领悟透彻了吗?,看我这一篇,使君对01背包的理解更进一步【代码+图解+文字描述】

一.概念理解&#xff1a;什么是01背包 关于01背包的概念理解如下&#xff1a;01背包是在M件物品取出若干件放在空间为W的背包里&#xff0c;每件物品的体积为W1&#xff0c;W2至Wn&#xff0c;与之相对应的价值为P1,P2至Pn。001背包的约束条件是给定几种物品&#xff0c;每种物…

数组篇刷题总结

二分查找&#xff1a; 给定一个 n 个元素有序的&#xff08;升序&#xff09;整型数组 nums 和一个目标值 target &#xff0c;写一个函数搜索 nums 中的 target&#xff0c;如果目标值存在返回下标&#xff0c;否则返回 -1。 示例 1: 输入: nums [-1,0,3,5,9,12], target …