




版权说明:本文档由用户提供并上传,收益归属内容提供方,若内容存在侵权,请进行举报或认领
文档简介
第Pytorch从0实现Transformer的实践目录摘要一、构造数据1.1句子长度1.2生成句子1.3生成字典1.4得到向量化的句子二、位置编码2.1计算括号内的值2.2得到位置编码三、多头注意力3.1selfmask
摘要
Withthecontinuousdevelopmentoftimeseriesprediction,Transformer-likemodelshavegraduallyreplacedtraditionalmodelsinthefieldsofCVandNLPbyvirtueoftheirpowerfuladvantages.Amongthem,theInformerisfarsuperiortothetraditionalRNNmodelinlong-termprediction,andtheSwinTransformerissignificantlystrongerthanthetraditionalCNNmodelinimagerecognition.AdeepgraspofTransformerhasbecomeaninevitablerequirementinthefieldofartificialintelligence.ThisarticlewillusethePytorchframeworktoimplementthepositionencoding,multi-headattentionmechanism,self-mask,causalmaskandotherfunctionsinTransformer,andbuildaTransformernetworkfrom0.
随着时序预测的不断发展,Transformer类模型凭借强大的优势,在CV、NLP领域逐渐取代传统模型。其中Informer在长时序预测上远超传统的RNN模型,SwinTransformer在图像识别上明显强于传统的CNN模型。深层次掌握Transformer已经成为从事人工智能领域的必然要求。本文将用Pytorch框架,实现Transformer中的位置编码、多头注意力机制、自掩码、因果掩码等功能,从0搭建一个Transformer网络。
一、构造数据
1.1句子长度
#关于wordembedding,以序列建模为例
#输入句子有两个,第一个长度为2,第二个长度为4
src_len=torch.tensor([2,4]).to(32)
#目标句子有两个。第一个长度为4,第二个长度为3
tgt_len=torch.tensor([4,3]).to(32)
print(src_len)
print(tgt_len)
输入句子(src_len)有两个,第一个长度为2,第二个长度为4
目标句子(tgt_len)有两个。第一个长度为4,第二个长度为3
1.2生成句子
用随机数生成句子,用0填充空白位置,保持所有句子长度一致
src_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,max_num_src_words,(L,)),(0,max(src_len)-L)),0)forLinsrc_len])
tgt_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,max_num_tgt_words,(L,)),(0,max(tgt_len)-L)),0)forLintgt_len])
print(src_seq)
print(tgt_seq)
src_seq为输入的两个句子,tgt_seq为输出的两个句子。
为什么句子是数字?在做中英文翻译时,每个中文或英文对应的也是一个数字,只有这样才便于处理。
1.3生成字典
在该字典中,总共有8个字(行),每个字对应8维向量(做了简化了的)。注意在实际应用中,应当有几十万个字,每个字可能有512个维度。
#构造wordembedding
src_embedding_table=nn.Embedding(9,model_dim)
tgt_embedding_table=nn.Embedding(9,model_dim)
#输入单词的字典
print(src_embedding_table)
#目标单词的字典
print(tgt_embedding_table)
字典中,需要留一个维度给classtoken,故是9行。
1.4得到向量化的句子
通过字典取出1.2中得到的句子
#得到向量化的句子
src_embedding=src_embedding_table(src_seq)
tgt_embedding=tgt_embedding_table(tgt_seq)
print(src_embedding)
print(tgt_embedding)
该阶段总程序
importtorch
#句子长度
src_len=torch.tensor([2,4]).to(32)
tgt_len=torch.tensor([4,3]).to(32)
#构造句子,用0填充空白处
src_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,8,(L,)),(0,max(src_len)-L)),0)forLinsrc_len])
tgt_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,8,(L,)),(0,max(tgt_len)-L)),0)forLintgt_len])
#构造字典
src_embedding_table=nn.Embedding(9,8)
tgt_embedding_table=nn.Embedding(9,8)
#得到向量化的句子
src_embedding=src_embedding_table(src_seq)
tgt_embedding=tgt_embedding_table(tgt_seq)
print(src_embedding)
print(tgt_embedding)
二、位置编码
位置编码是transformer的一个重点,通过加入transformer位置编码,代替了传统RNN的时序信息,增强了模型的并发度。位置编码的公式如下:(其中pos代表行,i代表列)
2.1计算括号内的值
#得到分子pos的值
pos_mat=torch.arange(4).reshape((-1,1))
#得到分母值
i_mat=torch.pow(10000,torch.arange(0,8,2).reshape((1,-1))/8)
print(pos_mat)
print(i_mat)
2.2得到位置编码
#初始化位置编码矩阵
pe_embedding_table=torch.zeros(4,8)
#得到偶数行位置编码
pe_embedding_table[:,0::2]=torch.sin(pos_mat/i_mat)
#得到奇数行位置编码
pe_embedding_table[:,1::2]=torch.cos(pos_mat/i_mat)
pe_embedding=nn.Embedding(4,8)
#设置位置编码不可更新参数
pe_embedding.weight=nn.Parameter(pe_embedding_table,requires_grad=False)
print(pe_embedding.weight)
三、多头注意力
3.1selfmask
有些位置是空白用0填充的,训练时不希望被这些位置所影响,那么就需要用到selfmask。selfmask的原理是令这些位置的值为无穷小,经过softmax后,这些值会变为0,不会再影响结果。
3.1.1得到有效位置矩阵
#得到有效位置矩阵
vaild_encoder_pos=torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L),(0,max(src_len)-L)),0)forLinsrc_len]),2)
valid_encoder_pos_matrix=torch.bmm(vaild_encoder_pos,vaild_encoder_pos.transpose(1,2))
print(valid_encoder_pos_matrix)
3.1.2得到无效位置矩阵
invalid_encoder_pos_matrix=1-valid_encoder_pos_matrix
mask_encoder_self_attention=invalid_encoder_pos_matrix.to(torch.bool)
print(mask_encoder_self_attention)
True代表需要对该位置mask
3.1.3得到mask矩阵
用极小数填充需要被mask的位置
#初始化mask矩阵
score=torch.randn(2,max(
温馨提示
- 1. 本站所有资源如无特殊说明,都需要本地电脑安装OFFICE2007和PDF阅读器。图纸软件为CAD,CAXA,PROE,UG,SolidWorks等.压缩文件请下载最新的WinRAR软件解压。
- 2. 本站的文档不包含任何第三方提供的附件图纸等,如果需要附件,请联系上传者。文件的所有权益归上传用户所有。
- 3. 本站RAR压缩包中若带图纸,网页内容里面会有图纸预览,若没有图纸预览就没有图纸。
- 4. 未经权益所有人同意不得将文件中的内容挪作商业或盈利用途。
- 5. 人人文库网仅提供信息存储空间,仅对用户上传内容的表现方式做保护处理,对用户上传分享的文档内容本身不做任何修改或编辑,并不能对任何下载内容负责。
- 6. 下载文件中如有侵权或不适当内容,请与我们联系,我们立即纠正。
- 7. 本站不保证下载资源的准确性、安全性和完整性, 同时也不承担用户因使用这些下载资源对自己和他人造成任何形式的伤害或损失。
最新文档
- 安徽省阜阳市颍州区2025届数学三年级第一学期期末质量跟踪监视模拟试题含解析
- 2025届西藏山南地区扎囊县数学三年级第一学期期末模拟试题含解析
- 行政管理的公共关系学备考试题及答案
- 2022 年中级会计师考试《中级经济法》真题及解析(9月5日)
- 剧组协调员助理场记聘用合同
- 长期公寓租赁合同
- 中级经济师考试对行业发展的影响与试题及答案
- 农民信息技术应用服务合同
- 知识产权转让与保密协议细节展开说明文档
- 心理学应用知识练习题
- 法律法规合规性评价记录表
- 初中历史资本主义制度的初步确立 作业设计
- 能源英语面面观 知到智慧树网课答案
- 电脑时代需要练字辩论材料
- MOOC 职业生涯开发与管理-南京邮电大学 中国大学慕课答案
- 中国书法艺术智慧树知到期末考试答案2024年
- 2024年4月自考00015英语(二)试题
- 上汽大众电子说明书
- 数学建模与系统仿真智慧树知到期末考试答案2024年
- 足球鞋推广方案
- 论三农工作培训课件
评论
0/150
提交评论