




版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請進(jìn)行舉報或認(rèn)領(lǐng)
文檔簡介
第Pytorch從0實(shí)現(xiàn)Transformer的實(shí)踐目錄摘要一、構(gòu)造數(shù)據(jù)1.1句子長度1.2生成句子1.3生成字典1.4得到向量化的句子二、位置編碼2.1計(jì)算括號內(nèi)的值2.2得到位置編碼三、多頭注意力3.1selfmask
摘要
Withthecontinuousdevelopmentoftimeseriesprediction,Transformer-likemodelshavegraduallyreplacedtraditionalmodelsinthefieldsofCVandNLPbyvirtueoftheirpowerfuladvantages.Amongthem,theInformerisfarsuperiortothetraditionalRNNmodelinlong-termprediction,andtheSwinTransformerissignificantlystrongerthanthetraditionalCNNmodelinimagerecognition.AdeepgraspofTransformerhasbecomeaninevitablerequirementinthefieldofartificialintelligence.ThisarticlewillusethePytorchframeworktoimplementthepositionencoding,multi-headattentionmechanism,self-mask,causalmaskandotherfunctionsinTransformer,andbuildaTransformernetworkfrom0.
隨著時序預(yù)測的不斷發(fā)展,Transformer類模型憑借強(qiáng)大的優(yōu)勢,在CV、NLP領(lǐng)域逐漸取代傳統(tǒng)模型。其中Informer在長時序預(yù)測上遠(yuǎn)超傳統(tǒng)的RNN模型,SwinTransformer在圖像識別上明顯強(qiáng)于傳統(tǒng)的CNN模型。深層次掌握Transformer已經(jīng)成為從事人工智能領(lǐng)域的必然要求。本文將用Pytorch框架,實(shí)現(xiàn)Transformer中的位置編碼、多頭注意力機(jī)制、自掩碼、因果掩碼等功能,從0搭建一個Transformer網(wǎng)絡(luò)。
一、構(gòu)造數(shù)據(jù)
1.1句子長度
#關(guān)于wordembedding,以序列建模為例
#輸入句子有兩個,第一個長度為2,第二個長度為4
src_len=torch.tensor([2,4]).to(32)
#目標(biāo)句子有兩個。第一個長度為4,第二個長度為3
tgt_len=torch.tensor([4,3]).to(32)
print(src_len)
print(tgt_len)
輸入句子(src_len)有兩個,第一個長度為2,第二個長度為4
目標(biāo)句子(tgt_len)有兩個。第一個長度為4,第二個長度為3
1.2生成句子
用隨機(jī)數(shù)生成句子,用0填充空白位置,保持所有句子長度一致
src_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,max_num_src_words,(L,)),(0,max(src_len)-L)),0)forLinsrc_len])
tgt_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,max_num_tgt_words,(L,)),(0,max(tgt_len)-L)),0)forLintgt_len])
print(src_seq)
print(tgt_seq)
src_seq為輸入的兩個句子,tgt_seq為輸出的兩個句子。
為什么句子是數(shù)字?在做中英文翻譯時,每個中文或英文對應(yīng)的也是一個數(shù)字,只有這樣才便于處理。
1.3生成字典
在該字典中,總共有8個字(行),每個字對應(yīng)8維向量(做了簡化了的)。注意在實(shí)際應(yīng)用中,應(yīng)當(dāng)有幾十萬個字,每個字可能有512個維度。
#構(gòu)造wordembedding
src_embedding_table=nn.Embedding(9,model_dim)
tgt_embedding_table=nn.Embedding(9,model_dim)
#輸入單詞的字典
print(src_embedding_table)
#目標(biāo)單詞的字典
print(tgt_embedding_table)
字典中,需要留一個維度給classtoken,故是9行。
1.4得到向量化的句子
通過字典取出1.2中得到的句子
#得到向量化的句子
src_embedding=src_embedding_table(src_seq)
tgt_embedding=tgt_embedding_table(tgt_seq)
print(src_embedding)
print(tgt_embedding)
該階段總程序
importtorch
#句子長度
src_len=torch.tensor([2,4]).to(32)
tgt_len=torch.tensor([4,3]).to(32)
#構(gòu)造句子,用0填充空白處
src_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,8,(L,)),(0,max(src_len)-L)),0)forLinsrc_len])
tgt_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,8,(L,)),(0,max(tgt_len)-L)),0)forLintgt_len])
#構(gòu)造字典
src_embedding_table=nn.Embedding(9,8)
tgt_embedding_table=nn.Embedding(9,8)
#得到向量化的句子
src_embedding=src_embedding_table(src_seq)
tgt_embedding=tgt_embedding_table(tgt_seq)
print(src_embedding)
print(tgt_embedding)
二、位置編碼
位置編碼是transformer的一個重點(diǎn),通過加入transformer位置編碼,代替了傳統(tǒng)RNN的時序信息,增強(qiáng)了模型的并發(fā)度。位置編碼的公式如下:(其中pos代表行,i代表列)
2.1計(jì)算括號內(nèi)的值
#得到分子pos的值
pos_mat=torch.arange(4).reshape((-1,1))
#得到分母值
i_mat=torch.pow(10000,torch.arange(0,8,2).reshape((1,-1))/8)
print(pos_mat)
print(i_mat)
2.2得到位置編碼
#初始化位置編碼矩陣
pe_embedding_table=torch.zeros(4,8)
#得到偶數(shù)行位置編碼
pe_embedding_table[:,0::2]=torch.sin(pos_mat/i_mat)
#得到奇數(shù)行位置編碼
pe_embedding_table[:,1::2]=torch.cos(pos_mat/i_mat)
pe_embedding=nn.Embedding(4,8)
#設(shè)置位置編碼不可更新參數(shù)
pe_embedding.weight=nn.Parameter(pe_embedding_table,requires_grad=False)
print(pe_embedding.weight)
三、多頭注意力
3.1selfmask
有些位置是空白用0填充的,訓(xùn)練時不希望被這些位置所影響,那么就需要用到selfmask。selfmask的原理是令這些位置的值為無窮小,經(jīng)過softmax后,這些值會變?yōu)?,不會再影響結(jié)果。
3.1.1得到有效位置矩陣
#得到有效位置矩陣
vaild_encoder_pos=torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L),(0,max(src_len)-L)),0)forLinsrc_len]),2)
valid_encoder_pos_matrix=torch.bmm(vaild_encoder_pos,vaild_encoder_pos.transpose(1,2))
print(valid_encoder_pos_matrix)
3.1.2得到無效位置矩陣
invalid_encoder_pos_matrix=1-valid_encoder_pos_matrix
mask_encoder_self_attention=invalid_encoder_pos_matrix.to(torch.bool)
print(mask_encoder_self_attention)
True代表需要對該位置mask
3.1.3得到mask矩陣
用極小數(shù)填充需要被mask的位置
#初始化mask矩陣
score=torch.randn(2,max(
溫馨提示
- 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
- 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
- 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
- 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
- 5. 人人文庫網(wǎng)僅提供信息存儲空間,僅對用戶上傳內(nèi)容的表現(xiàn)方式做保護(hù)處理,對用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對任何下載內(nèi)容負(fù)責(zé)。
- 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請與我們聯(lián)系,我們立即糾正。
- 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時也不承擔(dān)用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。
最新文檔
- 光學(xué)軟件測試題及答案
- 美術(shù)培訓(xùn)講座
- 2025年 阜陽臨泉城關(guān)街道桃花源幼兒園教師招聘考試筆試試卷附答案
- 2025年 北京公務(wù)員考試筆試考試試卷附答案
- 2025年主題團(tuán)日活動策劃與實(shí)施
- 小學(xué)交通教育課件
- 左膝關(guān)節(jié)置換術(shù)后護(hù)理
- 2025年中國墨西哥胡椒鹽行業(yè)市場全景分析及前景機(jī)遇研判報告
- 子宮畸形超聲分類及診斷
- 支氣管肺炎相關(guān)疾病知識
- 《大學(xué)英語》課件-UNIT 3 In the workplace
- 路燈安全生產(chǎn)培訓(xùn)
- 疑難病例討論制度流程
- 痛經(jīng)課件完整版本
- 2025高考數(shù)學(xué)考點(diǎn)鞏固卷01集合與常用邏輯用語(7大考點(diǎn))【含答案】
- 廣西南寧市(2024年-2025年小學(xué)六年級語文)統(tǒng)編版小升初真題((上下)學(xué)期)試卷及答案
- 旅游景區(qū)管理制度完整匯編
- 人教小學(xué)英語一起點(diǎn)新起點(diǎn)sl版6上 單元知識點(diǎn)歸納總結(jié)
- 《毛澤東思想和中國特色社會主義理論體系概論》微課之課件-1.2.2毛澤東思想活的靈魂
- 鈣化性岡上肌腱炎病因介紹
- UL1561標(biāo)準(zhǔn)中文版-2019變壓器UL中文版標(biāo)準(zhǔn)
評論
0/150
提交評論