下载本文档
版权说明:本文档由用户提供并上传,收益归属内容提供方,若内容存在侵权,请进行举报或认领
文档简介
第Python实现LeNet网络模型的训练及预测这部分代码是显示测试集当中前五张图片,运行后会显示5张拼接的图片
由于这个数据集的图片都比较小都是32x32的尺寸,有些可能也看的不太清楚,图中显示的是真实标签,注:显示图片的代码可能会这个报警(ClippinginputdatatothevalidrangeforimshowwithRGBdata([0…1]forfloatsor[0…255]forintegers).),警告解决的方法:将图片数组转成uint8类型即可,即plt.imshow(npimg.astype(‘uint8'),但是那样显示出来的图片会变,所以暂时可以先不用管。
(5).初始化模型
数据图片处理完了,下面就是我们的正式训练过程
net=LeNet()
#定义损失函数,nn.CrossEntropyLoss()自带softmax函数,所以模型的最后一层不需要softmax进行激活
loss_function=nn.CrossEntropyLoss()
#定义优化器,优化模型所有参数
optimizer=optim.Adam(net.parameters(),lr=0.001)
首先初始化LeNet网络,定义交叉熵损失函数,以及Adam优化器,关于注释写的,我们可以ctrl+鼠标左键查看CrossEntropyLoss(),翻到CrossEntropyLoss类,可以看到注释写的这个标准包含LogSoftmax函数,所以搭建LetNet模型的最后一层没有使用softmax激活函数
(6).训练模型及保存模型参数
forepochinrange(5):
#初始损失设置为0
running_loss=0
#循环训练集,从1开始
forstep,datainenumerate(train_loader,start=1):
inputs,labels=data
#优化器的梯度清零,每次循环都需要清零,否则梯度会无限叠加,相当于增加批次大小
optimizer.zero_grad()
#将图片数据输入模型中得到输出
outputs=net(inputs)
#传入预测值和真实值,计算当前损失值
loss=loss_function(outputs,labels)
#损失反向传播
loss.backward()
#进行梯度更新(更新W,b)
optimizer.step()
#计算该轮的总损失,因为loss是tensor类型,所以需要用item()取到值
running_loss+=loss.item()
#每500次进行日志的打印,对测试集进行测试
ifstep%500==0:
#torch.no_grad()就是上下文管理,测试时不需要梯度更新,不跟踪梯度
withtorch.no_grad():
#传入所有测试集图片进行预测
outputs=net(test_img)
#torch.max()中dim=1是因为结果为(batch,10)的形式,我们只需要取第二个维度的最大值,第二个维度是包含十个类别每个类别的概率的向量
#max这个函数返回[最大值,最大值索引],我们只需要取索引就行了,所以用[1]
predict_y=torch.max(outputs,dim=1)[1]
#(predict_y==test_label)相同返回True,不相等返回False,sum()对正确结果进行叠加,最后除测试集标签的总个数
#因为计算的变量都是tensor,所以需要用item()拿到取值
accuracy=(predict_y==test_label).sum().item()/test_label.size(0)
#running_loss/500是计算每一个step的loss,即每一步的损失
print('[%d,%5d]train_loss:%.3ftest_accuracy:%.3f'%
(epoch+1,step,running_loss/500,accuracy))
running_loss=0.0
print('FinishedTraining!')
save_path='lenet.pth'
#保存模型,字典形式
torch.save(net.state_dict(),save_path)
这段代码注释写的很清楚,大家仔细看就能看懂,流程不复杂,多看几遍就能理解,最后再对训练好的模型进行保存就好了(* ̄︶ ̄)
2.预测脚本
上面已经训练好了模型,得到了lenet.pth参数文件,预测就很简单了,可以去网上随便找一张数据集包含的类别图片,将模型参数文件载入模型,通过对图像进行一点处理,喂入模型即可,下面奉上代码:
importtorch
importnumpyasnp
importtorchvision.transformsastransforms
fromPILimportImage
frompytorch.lenet.modelimportLeNet
classes=('plane','car','bird','cat','deer',
'dog','frog','horse','ship','truck')
transforms=transforms.Compose(
#对数据图片调整大小
[transforms.Resize([32,32]),
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]
net=LeNet()
#加载预训练模型
net.load_state_dict(torch.load('lenet.pth'))
#网上随便找的猫的图片
img_path='../../Photo/cat2.jpg'
img=Image.open(img_path)
#图片的处理
img=transforms(img)
#增加一个维度,(channels,height,width)-------(batch,channels,height,width),pytorch要求必须输入这样的shape
img=torch.unsqueeze(img,dim=0)
withtorch.no_grad():
output=net(img)
#dim=1,只取[batch,10]中10个类别的那个维度,取预测结果的最大值索引,并转换为numpy类型
prediction1=torch.max(output,dim=1)[1].data.numpy()
#用softmax()预测出一个概率矩阵
prediction2=torch.softmax(output,dim=1)
#得到概率最大的值得索引
prediction2=np.argmax(predictio
温馨提示
- 1. 本站所有资源如无特殊说明,都需要本地电脑安装OFFICE2007和PDF阅读器。图纸软件为CAD,CAXA,PROE,UG,SolidWorks等.压缩文件请下载最新的WinRAR软件解压。
- 2. 本站的文档不包含任何第三方提供的附件图纸等,如果需要附件,请联系上传者。文件的所有权益归上传用户所有。
- 3. 本站RAR压缩包中若带图纸,网页内容里面会有图纸预览,若没有图纸预览就没有图纸。
- 4. 未经权益所有人同意不得将文件中的内容挪作商业或盈利用途。
- 5. 人人文库网仅提供信息存储空间,仅对用户上传内容的表现方式做保护处理,对用户上传分享的文档内容本身不做任何修改或编辑,并不能对任何下载内容负责。
- 6. 下载文件中如有侵权或不适当内容,请与我们联系,我们立即纠正。
- 7. 本站不保证下载资源的准确性、安全性和完整性, 同时也不承担用户因使用这些下载资源对自己和他人造成任何形式的伤害或损失。
最新文档
- 八年级地理下册 6.4 长江三角洲地区-城市密集的区域教学设计 晋教版
- 人教部编版第1课 隋朝的统一与灭亡教学设计
- 顶棚构造教学设计中职专业课-建筑识图与构造-建筑类-土木建筑大类
- 2026年四川省南充市社区工作者招聘考试备考试题及答案解析
- 2026年浙江省湖州市社区工作者招聘考试参考题库及答案解析
- 2026年思茅地区社区工作者招聘考试备考题库及答案解析
- 时间像小马车教学设计小学音乐人音版五线谱一年级下册-人音版(五线谱)
- 山东省临清市高中数学 3.2 函数的奇偶性全套教案 新人教A版必修1
- 第一单元 丰富多彩的化学物质教学设计高中化学苏教版必修1-苏教版2004
- 第二单元第11课一、《艺术相框效果》教学设计 人教版初中信息技术七年级下册
- LNG液化天然气卸车标准作业流程
- 索尼微单相机A7 II(ILCE-7M2)使用说明书
- 三体系认证培训课件
- 2026年高考英语-2024年新课标II卷词汇清单
- 2025年机械设计与自动化测试题及答案
- (2024)电梯安全管理员考试题及参考答案
- 做自强不息的中国人+说课课件2024-2025学年统编版道德与法治七年级下册
- T/CECS 10235-2022绿色建材评价人造石
- 日常教学体例格式1-12:工学一体化课程标准校本转化建议- 工学一体化课程教案
- 陕西文化课件
- 《西南交通大学》课件
评论
0/150
提交评论