Pytorch从0实现Transformer的实践_第1页
Pytorch从0实现Transformer的实践_第2页
Pytorch从0实现Transformer的实践_第3页
Pytorch从0实现Transformer的实践_第4页
Pytorch从0实现Transformer的实践_第5页
已阅读5页,还剩2页未读 继续免费阅读

下载本文档

版权说明:本文档由用户提供并上传,收益归属内容提供方,若内容存在侵权,请进行举报或认领

文档简介

第Pytorch从0实现Transformer的实践目录摘要一、构造数据1.1句子长度1.2生成句子1.3生成字典1.4得到向量化的句子二、位置编码2.1计算括号内的值2.2得到位置编码三、多头注意力3.1selfmask

摘要

Withthecontinuousdevelopmentoftimeseriesprediction,Transformer-likemodelshavegraduallyreplacedtraditionalmodelsinthefieldsofCVandNLPbyvirtueoftheirpowerfuladvantages.Amongthem,theInformerisfarsuperiortothetraditionalRNNmodelinlong-termprediction,andtheSwinTransformerissignificantlystrongerthanthetraditionalCNNmodelinimagerecognition.AdeepgraspofTransformerhasbecomeaninevitablerequirementinthefieldofartificialintelligence.ThisarticlewillusethePytorchframeworktoimplementthepositionencoding,multi-headattentionmechanism,self-mask,causalmaskandotherfunctionsinTransformer,andbuildaTransformernetworkfrom0.

随着时序预测的不断发展,Transformer类模型凭借强大的优势,在CV、NLP领域逐渐取代传统模型。其中Informer在长时序预测上远超传统的RNN模型,SwinTransformer在图像识别上明显强于传统的CNN模型。深层次掌握Transformer已经成为从事人工智能领域的必然要求。本文将用Pytorch框架,实现Transformer中的位置编码、多头注意力机制、自掩码、因果掩码等功能,从0搭建一个Transformer网络。

一、构造数据

1.1句子长度

#关于wordembedding,以序列建模为例

#输入句子有两个,第一个长度为2,第二个长度为4

src_len=torch.tensor([2,4]).to(32)

#目标句子有两个。第一个长度为4,第二个长度为3

tgt_len=torch.tensor([4,3]).to(32)

print(src_len)

print(tgt_len)

输入句子(src_len)有两个,第一个长度为2,第二个长度为4

目标句子(tgt_len)有两个。第一个长度为4,第二个长度为3

1.2生成句子

用随机数生成句子,用0填充空白位置,保持所有句子长度一致

src_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,max_num_src_words,(L,)),(0,max(src_len)-L)),0)forLinsrc_len])

tgt_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,max_num_tgt_words,(L,)),(0,max(tgt_len)-L)),0)forLintgt_len])

print(src_seq)

print(tgt_seq)

src_seq为输入的两个句子,tgt_seq为输出的两个句子。

为什么句子是数字?在做中英文翻译时,每个中文或英文对应的也是一个数字,只有这样才便于处理。

1.3生成字典

在该字典中,总共有8个字(行),每个字对应8维向量(做了简化了的)。注意在实际应用中,应当有几十万个字,每个字可能有512个维度。

#构造wordembedding

src_embedding_table=nn.Embedding(9,model_dim)

tgt_embedding_table=nn.Embedding(9,model_dim)

#输入单词的字典

print(src_embedding_table)

#目标单词的字典

print(tgt_embedding_table)

字典中,需要留一个维度给classtoken,故是9行。

1.4得到向量化的句子

通过字典取出1.2中得到的句子

#得到向量化的句子

src_embedding=src_embedding_table(src_seq)

tgt_embedding=tgt_embedding_table(tgt_seq)

print(src_embedding)

print(tgt_embedding)

该阶段总程序

importtorch

#句子长度

src_len=torch.tensor([2,4]).to(32)

tgt_len=torch.tensor([4,3]).to(32)

#构造句子,用0填充空白处

src_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,8,(L,)),(0,max(src_len)-L)),0)forLinsrc_len])

tgt_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,8,(L,)),(0,max(tgt_len)-L)),0)forLintgt_len])

#构造字典

src_embedding_table=nn.Embedding(9,8)

tgt_embedding_table=nn.Embedding(9,8)

#得到向量化的句子

src_embedding=src_embedding_table(src_seq)

tgt_embedding=tgt_embedding_table(tgt_seq)

print(src_embedding)

print(tgt_embedding)

二、位置编码

位置编码是transformer的一个重点,通过加入transformer位置编码,代替了传统RNN的时序信息,增强了模型的并发度。位置编码的公式如下:(其中pos代表行,i代表列)

2.1计算括号内的值

#得到分子pos的值

pos_mat=torch.arange(4).reshape((-1,1))

#得到分母值

i_mat=torch.pow(10000,torch.arange(0,8,2).reshape((1,-1))/8)

print(pos_mat)

print(i_mat)

2.2得到位置编码

#初始化位置编码矩阵

pe_embedding_table=torch.zeros(4,8)

#得到偶数行位置编码

pe_embedding_table[:,0::2]=torch.sin(pos_mat/i_mat)

#得到奇数行位置编码

pe_embedding_table[:,1::2]=torch.cos(pos_mat/i_mat)

pe_embedding=nn.Embedding(4,8)

#设置位置编码不可更新参数

pe_embedding.weight=nn.Parameter(pe_embedding_table,requires_grad=False)

print(pe_embedding.weight)

三、多头注意力

3.1selfmask

有些位置是空白用0填充的,训练时不希望被这些位置所影响,那么就需要用到selfmask。selfmask的原理是令这些位置的值为无穷小,经过softmax后,这些值会变为0,不会再影响结果。

3.1.1得到有效位置矩阵

#得到有效位置矩阵

vaild_encoder_pos=torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L),(0,max(src_len)-L)),0)forLinsrc_len]),2)

valid_encoder_pos_matrix=torch.bmm(vaild_encoder_pos,vaild_encoder_pos.transpose(1,2))

print(valid_encoder_pos_matrix)

3.1.2得到无效位置矩阵

invalid_encoder_pos_matrix=1-valid_encoder_pos_matrix

mask_encoder_self_attention=invalid_encoder_pos_matrix.to(torch.bool)

print(mask_encoder_self_attention)

True代表需要对该位置mask

3.1.3得到mask矩阵

用极小数填充需要被mask的位置

#初始化mask矩阵

score=torch.randn(2,max(

温馨提示

  • 1. 本站所有资源如无特殊说明,都需要本地电脑安装OFFICE2007和PDF阅读器。图纸软件为CAD,CAXA,PROE,UG,SolidWorks等.压缩文件请下载最新的WinRAR软件解压。
  • 2. 本站的文档不包含任何第三方提供的附件图纸等,如果需要附件,请联系上传者。文件的所有权益归上传用户所有。
  • 3. 本站RAR压缩包中若带图纸,网页内容里面会有图纸预览,若没有图纸预览就没有图纸。
  • 4. 未经权益所有人同意不得将文件中的内容挪作商业或盈利用途。
  • 5. 人人文库网仅提供信息存储空间,仅对用户上传内容的表现方式做保护处理,对用户上传分享的文档内容本身不做任何修改或编辑,并不能对任何下载内容负责。
  • 6. 下载文件中如有侵权或不适当内容,请与我们联系,我们立即纠正。
  • 7. 本站不保证下载资源的准确性、安全性和完整性, 同时也不承担用户因使用这些下载资源对自己和他人造成任何形式的伤害或损失。

评论

0/150

提交评论