案例24-第六章 基于联邦学习的MNIST手写数字识别_第1页
案例24-第六章 基于联邦学习的MNIST手写数字识别_第2页
案例24-第六章 基于联邦学习的MNIST手写数字识别_第3页
案例24-第六章 基于联邦学习的MNIST手写数字识别_第4页
案例24-第六章 基于联邦学习的MNIST手写数字识别_第5页
已阅读5页,还剩2页未读 继续免费阅读

下载本文档

版权说明:本文档由用户提供并上传,收益归属内容提供方,若内容存在侵权,请进行举报或认领

文档简介

案例24——基于联邦学习的MNIST手写数字识别本案例探讨了联邦学习在解决数据孤岛问题中的应用,特别是在MNIST手写数字识别任务中。通过Pytorch实现了一个基于联邦学习的多机构合作框架,展示了在不共享原始数据的情况下,多个机构如何共同提升模型性能。实验结果显示,随着联邦轮数的增加,各机构的模型准确率显著提高。1.案例背景在传统的集中式训练模式中,所有数据需要汇总到同一位置进行建模,这在实际应用中会面临数据隐私、合规性和传输成本等问题。MNIST手写数字数据集虽然是公开数据,但在本案例中被人为划分到不同机构,模拟真实场景下数据分散存储的情况。通过引入联邦学习框架,各机构在本地训练模型并仅共享模型参数,既保证了数据安全,又能够在全局范围内不断优化模型性能。2.数据集与数据预处理MNIST数据集是计算机视觉领域中广泛使用的标准基准数据集,专门用于手写数字识别任务。该数据集包含60,000张用于训练的灰度图像和10,000张用于测试的灰度图像,每张图像大小为28×28像素,展示了从0到9的十个数字类别。由于图像尺寸适中且样本数量充足,MNIST数据集成为评估和比较各种机器学习和深度学习算法性能的理想选择。其数据格式统一且标注准确,便于在分布式或联邦学习环境中进行数据分割和模型训练,如下图1所示:图1手写数字本案例中,为模拟多机构协同训练场景,将MNIST训练数据集划分为三个独立子集,分别对应机构A、B和C。具体采用PyTorch中的Subset函数,从原始训练集中各选取1000条样本,构建三个均衡的数据子集。随后,将各子集封装进DataLoader中,便于模型训练时进行批量数据读取。这里设置的批量大小(batchsize)为整个训练子集的大小,即1000,意味着每个训练周期中,模型将一次性处理所有数据样本,保证训练过程的完整性和稳定性。由于数据一次性载入,无需进行随机打乱(shuffle设置为False),确保数据的顺序一致性,便于实验对比。该方法有效模拟了分布式数据环境下,各机构独立持有且不共享原始数据的实际情况,为后续联邦学习模型训练奠定数据基础。数据加载和划分代码如下:train_set=torchvision.datasets.MNIST(root="./data",train=True,transform=transforms.ToTensor(),download=True)train_set_A=Subset(train_set,range(0,1000))train_set_B=Subset(train_set,range(1000,2000))train_set_C=Subset(train_set,range(2000,3000))train_loader_A=dataloader.DataLoader(dataset=train_set_A,batch_size=1000,shuffle=False)train_loader_B=dataloader.DataLoader(dataset=train_set_B,batch_size=1000,shuffle=False)train_loader_C=dataloader.DataLoader(dataset=train_set_C,batch_size=1000,shuffle=False)test_set=torchvision.datasets.MNIST(root="./data",train=False,transform=transforms.ToTensor(),download=True)test_set=Subset(test_set,range(0,2000))test_loader=dataloader.DataLoader(dataset=test_set,shuffle=True)3.训练与测试本案例包含两类训练测试流程,代码如下:deftrain_and_test_1(train_loader,test_loader):classNeuralNet(nn.Module):def__init__(self,input_num,hidden_num,output_num):super(NeuralNet,self).__init__()self.fc1=nn.Linear(input_num,hidden_num)#服从正态分布的权重wself.fc2=nn.Linear(hidden_num,output_num)nn.init.normal_(self.fc1.weight)nn.init.normal_(self.fc2.weight)nn.init.constant_(self.fc1.bias,val=0)#初始化bias为0nn.init.constant_(self.fc2.bias,val=0)self.relu=nn.ReLU()#Relu激励函数defforward(self,x):x=self.fc1(x)x=self.relu(x)y=self.fc2(x)returnyepoches=20#迭代20轮lr=0.01#学习率,即步长input_num=784hidden_num=12output_num=10device=torch.device("cuda"iftorch.cuda.is_available()else"cpu")

model=NeuralNet(input_num,hidden_num,output_num)model.to(device)loss_func=nn.CrossEntropyLoss()#损失函数的类型:交叉熵损失函数optimizer=optim.Adam(model.parameters(),lr=lr)#Adam优化,也可以用SGD随机梯度下降法#optimizer=optim.SGD(model.parameters(),lr=lr)forepochinrange(epoches):flag=0forimages,labelsintrain_loader:images=images.reshape(-1,28*28).to(device)labels=labels.to(device)output=model(images)

loss=loss_func(output,labels)optimizer.zero_grad()loss.backward()#误差反向传播,计算参数更新值optimizer.step()#将参数更新值施加到net的parameters上#以下两步可以看每轮损失函数具体的变化情况#if(flag+1)%10==0:#print('Epoch[{}/{}],Loss:{:.4f}'.format(epoch+1,epoches,loss.item()))flag+=1params=list(d_parameters())#获取模型参数#测试,评估准确率correct=0total=0forimages,labelsintest_loader:images=images.reshape(-1,28*28).to(device)labels=labels.to(device)output=model(images)values,predicte=torch.max(output,1)#0是每列的最大值,1是每行的最大值total+=labels.size(0)#predicte==labels返回每张图片的布尔类型correct+=(predicte==labels).sum().item()print("Theaccuracyoftotal{}images:{}%".format(total,100*correct/total))returnparamsdeftrain_and_test_2(train_loader,test_loader,com_para_fc1,com_para_fc2):classNeuralNet(nn.Module):def__init__(self,input_num,hidden_num,output_num,com_para_fc1,com_para_fc2):super(NeuralNet,self).__init__()self.fc1=nn.Linear(input_num,hidden_num)self.fc2=nn.Linear(hidden_num,output_num)self.fc1.weight=Parameter(com_para_fc1)self.fc2.weight=Parameter(com_para_fc2)nn.init.constant_(self.fc1.bias,val=0)nn.init.constant_(self.fc2.bias,val=0)self.relu=nn.ReLU()defforward(self,x):x=self.fc1(x)x=self.relu(x)y=self.fc2(x)returnyepoches=20lr=0.01input_num=784hidden_num=12output_num=10device=torch.device("cuda"iftorch.cuda.is_available()else"cpu")model=NeuralNet(input_num,hidden_num,output_num,com_para_fc1,com_para_fc2)model.to(device)loss_func=nn.CrossEntropyLoss()optimizer=optim.Adam(model.parameters(),lr=lr)#optimizer=optim.SGD(model.parameters(),lr=lr)forepochinrange(epoches):flag=0forimages,labelsintrain_loader:#(images,labels)=dataimages=images.reshape(-1,28*28).to(device)labels=labels.to(device)output=model(images)loss=loss_func(output,labels)optimizer.zero_grad()loss.backward()optimizer.step()

#if(flag+1)%10==0:#print('Epoch[{}/{}],Loss:{:.4f}'.format(epoch+1,epoches,loss.item()))flag+=1params=list(d_parameters())#gettheindexbydebuging

correct=0total=0forimages,labelsintest_loader:images=images.reshape(-1,28*28).to(device)labels=labels.to(device)output=model(images)values,predicte=torch.max(output,1)total+=labels.size(0)correct+=(predicte==labels).sum().item()print("Theaccuracyoftotal{}images:{}%".format(total,100*correct/total))returnparamsdefcombine_params(para_A,para_B,para_C):fc1_wA=para_A[0][1].datafc1_wB=para_B[0][1].datafc1_wC=para_C[0][1].datafc2_wA=para_A[2][1].datafc2_wB=para_B[2][1].datafc2_wC=para_C[2][1].data

com_para_fc1=(fc1_wA+fc1_wB+fc1_wC)/3com_para_fc2=(fc2_wA+fc2_wB+fc2_wC)/3returncom_para_fc1,com_para_fc2para_A=train_and_test_1(train_loader_A,test_loader)para_B=train_and_test_1(train_loader_B,test_loader)para_C=train_and_test_1(train_loader_C,test_loader)foriinrange(10):print("The{}roundtobefederated!!!".format(i+1))com_para_fc1,com_para_fc2=combine_params(para_A,para_B,para_C)para_A=train_and_test_2(train_loader_A,test_loader,com_para_fc1,com_para_fc2)para_B=train_and_test_2(train_loader_B,test_loader,com_para_fc1,com_para_fc2)para_C=train_and_test_2(train_loader_C,test_loader,com_para_fc1

温馨提示

  • 1. 本站所有资源如无特殊说明,都需要本地电脑安装OFFICE2007和PDF阅读器。图纸软件为CAD,CAXA,PROE,UG,SolidWorks等.压缩文件请下载最新的WinRAR软件解压。
  • 2. 本站的文档不包含任何第三方提供的附件图纸等,如果需要附件,请联系上传者。文件的所有权益归上传用户所有。
  • 3. 本站RAR压缩包中若带图纸,网页内容里面会有图纸预览,若没有图纸预览就没有图纸。
  • 4. 未经权益所有人同意不得将文件中的内容挪作商业或盈利用途。
  • 5. 人人文库网仅提供信息存储空间,仅对用户上传内容的表现方式做保护处理,对用户上传分享的文档内容本身不做任何修改或编辑,并不能对任何下载内容负责。
  • 6. 下载文件中如有侵权或不适当内容,请与我们联系,我们立即纠正。
  • 7. 本站不保证下载资源的准确性、安全性和完整性, 同时也不承担用户因使用这些下载资源对自己和他人造成任何形式的伤害或损失。

评论

0/150

提交评论