深度學(xué)習(xí)案例教程 課件 第7章 文本翻譯_第1頁
深度學(xué)習(xí)案例教程 課件 第7章 文本翻譯_第2頁
深度學(xué)習(xí)案例教程 課件 第7章 文本翻譯_第3頁
深度學(xué)習(xí)案例教程 課件 第7章 文本翻譯_第4頁
深度學(xué)習(xí)案例教程 課件 第7章 文本翻譯_第5頁
已閱讀5頁,還剩61頁未讀 繼續(xù)免費(fèi)閱讀

下載本文檔

版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請進(jìn)行舉報或認(rèn)領(lǐng)

文檔簡介

文本翻譯第七章01理解深度學(xué)習(xí)在文本翻譯中的應(yīng)用02掌握批處理的概念和應(yīng)用03理解BatchNormalization的原理和作用04

熟悉Seq2Seq網(wǎng)絡(luò)和注意力機(jī)制學(xué)習(xí)目標(biāo)CONTENTS05

掌握文本翻譯模型的搭建和訓(xùn)練01培養(yǎng)學(xué)習(xí)深度學(xué)習(xí)框架和模型設(shè)計(jì)的能力02培養(yǎng)解模型搭建和調(diào)優(yōu)能力03提高創(chuàng)新能力04培養(yǎng)團(tuán)隊(duì)合作和協(xié)作能力素質(zhì)目標(biāo)CONTENTS構(gòu)建翻譯模型訓(xùn)練翻譯模型實(shí)踐任務(wù)準(zhǔn)備翻譯數(shù)據(jù)集評估翻譯模型第一節(jié)

學(xué)習(xí)情景假設(shè)你是一家國際旅行社的員工,你的工作是幫助客戶安排旅行行程并提供相關(guān)信息。任務(wù)需求描述第一節(jié)

學(xué)習(xí)情景你需要利用深度學(xué)習(xí)模型進(jìn)行文本翻譯,將中文的旅行信息翻譯成英文。具體而言,你需要搭建一個基于Seq2Seq網(wǎng)絡(luò)和注意力機(jī)制的文本翻譯模型,使用PyTorch框架進(jìn)行模型的搭建和訓(xùn)練。通過訓(xùn)練這個模型,你將能夠?qū)⑤斎氲闹形奈谋巨D(zhuǎn)化為相應(yīng)的英文翻譯文本。任務(wù)需求描述第一節(jié)

學(xué)習(xí)情景在學(xué)習(xí)過程中,你需要掌握批處理的概念和使用方法,了解BatchNormalization的原理和應(yīng)用,理解Seq2Seq網(wǎng)絡(luò)的結(jié)構(gòu)和工作原理,以及注意力機(jī)制的作用和實(shí)現(xiàn)方法。任務(wù)需求描述批處理在文本翻譯任務(wù)中,使用批處理可以有效地處理大量的文本數(shù)據(jù)。通過將數(shù)據(jù)分成小批次進(jìn)行處理,可以提高訓(xùn)練效率和模型的穩(wěn)定性。批處理還能夠充分利用硬件資源,加速模型訓(xùn)練過程。BatchNormalization是一種常用的正則化技術(shù),用于加速模型的訓(xùn)練和提高模型的穩(wěn)定性。在文本翻譯任務(wù)中,可以將BatchNormalization應(yīng)用于神經(jīng)網(wǎng)絡(luò)的隱藏層,使得網(wǎng)絡(luò)更易于訓(xùn)練并減少模型的過擬合。Seq2Seq網(wǎng)絡(luò)是一種用于序列到序列(sequence-to-sequence)任務(wù)的神經(jīng)網(wǎng)絡(luò)模型,廣泛應(yīng)用于文本翻譯、語音識別等領(lǐng)域。在文本翻譯任務(wù)中,Seq2Seq網(wǎng)絡(luò)可以將輸入的英文文本序列轉(zhuǎn)化為相應(yīng)的中文文本序列,實(shí)現(xiàn)文本的翻譯功能學(xué)習(xí)情景-技術(shù)分析注意力機(jī)制注意力機(jī)制是Seq2Seq網(wǎng)絡(luò)中的關(guān)鍵組成部分,用于處理長序列的信息傳遞和對齊問題。在文本翻譯任務(wù)中,注意力機(jī)制可以幫助模型更好地理解和翻譯輸入文本的內(nèi)容,提高翻譯的準(zhǔn)確性和流暢度。第二節(jié)

批處理在深度學(xué)習(xí)任務(wù)中,批處理(BatchProcessing)是一種重要的技術(shù),用于有效地處理大量的數(shù)據(jù)并進(jìn)行模型訓(xùn)練。下面將詳細(xì)介紹批處理的背景、原理、應(yīng)用以及如何使用。批處理的定義第二節(jié)

批處理批處理是指將一定數(shù)量的樣本一起輸入到模型中進(jìn)行前向傳播和反向傳播的過程。通常,一個批次由多個樣本組成,每個樣本都是一個輸入特征和對應(yīng)的標(biāo)簽。通過批處理,模型可以根據(jù)每個批次的誤差進(jìn)行參數(shù)更新,從而逐漸優(yōu)化模型。批處理的定義第二節(jié)

批處理批處理廣泛應(yīng)用于深度學(xué)習(xí)任務(wù)的訓(xùn)練階段,包括圖像分類、目標(biāo)檢測、語音識別等。在這些任務(wù)中,大量的數(shù)據(jù)被劃分為小批次,并通過反向傳播算法更新模型的參數(shù),從而實(shí)現(xiàn)模型的訓(xùn)練和優(yōu)化。批處理的定義將數(shù)據(jù)集劃分為小批次,并對數(shù)據(jù)進(jìn)行預(yù)處理(如歸一化、數(shù)據(jù)增強(qiáng)等)。數(shù)據(jù)準(zhǔn)備步驟對每個批次進(jìn)行前向傳播和反向傳播,計(jì)算損失函數(shù)并更新模型參數(shù)。批處理迭代重復(fù)批處理迭代的過程,直到完成所有的訓(xùn)練輪次或達(dá)到停止條件。訓(xùn)練迭代批處理

批處理importtorchfromtorch.utils.dataimportDataLoader#加載數(shù)據(jù)集并進(jìn)行預(yù)處理dataset=YourDataset(...)preprocess=YourPreprocessing()#定義批大小和數(shù)據(jù)加載器batch_size=32data_loader=DataLoader(dataset,batch_size=batch_size,shuffle=True)代碼實(shí)現(xiàn)批處理#模型定義和訓(xùn)練model=YourModel(...)optimizer=torch.optim.Adam(model.parameters(),lr=0.001)forepochinrange(num_epochs):forbatch_samples,batch_labelsindata_loader:outputs=model(batch_samples)loss=compute_loss(outputs,batch_labels)optimizer.zero_grad()loss.backward()optimizer.step()第三節(jié)

BatchNormalization批歸一化(BatchNormalization)是一種常用的深度學(xué)習(xí)技術(shù),用于加快神經(jīng)網(wǎng)絡(luò)的訓(xùn)練速度并提高模型的性能和穩(wěn)定性。?批歸一化是一種在神經(jīng)網(wǎng)絡(luò)中對輸入進(jìn)行規(guī)范化的技術(shù)?批歸一化可以應(yīng)用于卷積層、全連接層等神經(jīng)網(wǎng)絡(luò)中的任意層BatchNormalization的定義第三節(jié)

BatchNormalization為什么需要批歸一化??梯度消失與梯度爆炸問題?內(nèi)部協(xié)變量偏移BatchNormalization的定義BatchNormalization在深度學(xué)習(xí)框架中,批歸一化通常以層的形式提供,可以直接添加到神經(jīng)網(wǎng)絡(luò)的定義中。在每個批次的前向傳播過程中,批歸一化層會計(jì)算批次數(shù)據(jù)的均值和標(biāo)準(zhǔn)差,并對輸入進(jìn)行標(biāo)準(zhǔn)化。標(biāo)準(zhǔn)化后的數(shù)據(jù)會通過激活函數(shù)傳遞給下一層進(jìn)行進(jìn)一步的計(jì)算。在訓(xùn)練過程中,批歸一化層會維護(hù)一個移動平均值,用于在測試時對輸入進(jìn)行標(biāo)準(zhǔn)化。下面是一個使用PyTorch框架進(jìn)行批歸一化的示例代碼:第四節(jié)

Seq2Seq網(wǎng)絡(luò)Seq2Seq網(wǎng)絡(luò)(Sequence-to-Sequence網(wǎng)絡(luò))是一種用于處理序列數(shù)據(jù)的神經(jīng)網(wǎng)絡(luò)模型,它由編碼器(Encoder)和解碼器(Decoder)組成,常用于機(jī)器翻譯、對話生成等任務(wù)。Seq2Seq網(wǎng)絡(luò)的定義將輸入序列(源語言句子)編碼為固定長度的上下文向量,捕捉輸入序列的語義信息。常用的編碼器模型有循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)和長短期記憶網(wǎng)絡(luò)(LSTM)。編碼器(Encoder)原理結(jié)構(gòu)將上下文向量作為初始狀態(tài),逐步生成輸出序列(目標(biāo)語言句子)。解碼器可以是RNN或LSTM模型,每一步都根據(jù)上一步的輸出和當(dāng)前的隱藏狀態(tài)生成下一個單詞。

解碼器(Decoder)用于處理長句子的信息衰減問題,允許解碼器在生成每個單詞時對輸入序列的不同部分進(jìn)行不同程度的關(guān)注。注意力機(jī)制(Attention)Seq2Seq網(wǎng)絡(luò)

Seq2Seq網(wǎng)絡(luò)在文本翻譯中被廣泛應(yīng)用。它可以將源語言句子編碼為一個上下文向量,并通過解碼器生成目標(biāo)語言的句子。在訓(xùn)練過程中,給定源語言句子和目標(biāo)語言句子,將源語言句子作為輸入送入編碼器,得到上下文向量。然后,將上下文向量作為解碼器的初始狀態(tài),并通過解碼器逐步生成目標(biāo)語言句子。在測試過程中,給定源語言句子,通過編碼器得到上下文向量,然后使用解碼器生成目標(biāo)語言句子。生成過程中,解碼器會根據(jù)注意力機(jī)制對不同部分的輸入進(jìn)行關(guān)注,以便更好地翻譯長句子或處理句子中的歧義。Seq2Seq網(wǎng)絡(luò)Seq2Seq網(wǎng)絡(luò)LSTM介紹LSTM,全稱LongShortTermMemory(長短期記憶)是一種特殊的遞歸神經(jīng)網(wǎng)絡(luò)Seq2Seq網(wǎng)絡(luò)importtorchimporttorch.nnasnn#定義Seq2Seq模型classSeq2Seq(nn.Module):def__init__(self,input_dim,output_dim,hidden_dim):super(Seq2Seq,self).__init__()self.encoder=nn.LSTM(input_dim,hidden_dim)self.decoder=nn.LSTM(output_dim,hidden_dim)self.fc=nn.Linear(hidden_dim,output_dim)代碼示例Seq2Seq網(wǎng)絡(luò)defforward(self,input_seq,output_seq):_,hidden_state=self.encoder(input_seq)output_seq,_=self.decoder(output_seq,hidden_state)output_seq=self.fc(output_seq)returnoutput_seq#創(chuàng)建Seq2Seq模型實(shí)例input_dim=100#輸入維度output_dim=200#輸出維度hidden_dim=256#隱層維度model=Seq2Seq(input_dim,output_dim,hidden_dim)#定義損失函數(shù)和優(yōu)化器criterion=nn.CrossEntropyLoss()optimizer=torch.optim.Adam(model.parameters(),lr=0.001)Seq2Seq網(wǎng)絡(luò)#執(zhí)行訓(xùn)練過程forepochinrange(num_epochs):optimizer.zero_grad()input_seq,output_seq=get_batch()#獲取批次數(shù)據(jù)pred_seq=model(input_seq,output_seq)loss=criterion(pred_seq,output_seq)loss.backward()optimizer.step()第五節(jié)

注意力機(jī)制Attention注意力機(jī)制(Attention)是一種在序列到序列(Seq2Seq)模型中使用的機(jī)制,用于處理長句子的信息衰減問題。它允許解碼器在生成每個單詞時對輸入序列的不同部分進(jìn)行不同程度的關(guān)注,以便更好地翻譯長句子或處理句子中的歧義。注意力機(jī)制Attention的定義第五節(jié)

注意力機(jī)制Attention注意力機(jī)制是Seq2Seq模型中的關(guān)鍵技術(shù),它能夠提升模型在文本翻譯任務(wù)中的性能,特別是處理長句子和復(fù)雜語義結(jié)構(gòu)時的能力。通過根據(jù)輸入序列的不同部分計(jì)算注意力權(quán)重,模型可以更好地關(guān)注關(guān)鍵信息,從而改善翻譯質(zhì)量。在實(shí)際應(yīng)用中,不同的注意力機(jī)制實(shí)現(xiàn)方式可以根據(jù)具體任務(wù)和數(shù)據(jù)集的特點(diǎn)進(jìn)行選擇和調(diào)整。注意力機(jī)制Attention的定義注意力機(jī)制的核心思想是在解碼器的每個時間步驟中,根據(jù)輸入序列的不同部分對目標(biāo)序列進(jìn)行加權(quán)關(guān)注,以便更準(zhǔn)確地生成目標(biāo)序列的單詞。通過計(jì)算注意力權(quán)重,解碼器可以確定在生成當(dāng)前單詞時對輸入序列的哪些部分更重點(diǎn)關(guān)注,以便捕捉關(guān)鍵信息。注意力機(jī)制的作用是改善模型對長句子的處理能力,減輕信息衰減問題,并提高翻譯質(zhì)量和準(zhǔn)確性。注意力機(jī)制Attention-原理與作用注意力機(jī)制Attention?常見的注意力機(jī)制實(shí)現(xiàn)方式有多種,包括點(diǎn)積注意力(DotProductAttention)、加性注意力(AdditiveAttention)和縮放點(diǎn)積注意力(ScaledDotProductAttention)等。?這些實(shí)現(xiàn)方式的核心是通過計(jì)算注意力權(quán)重,將輸入序列中與解碼器當(dāng)前狀態(tài)相關(guān)的部分加權(quán)求和,作為解碼器的上下文向量,用于生成目標(biāo)語言序列的下一個單詞。Attention的實(shí)現(xiàn)方式第六節(jié)

實(shí)踐任務(wù)該項(xiàng)目的數(shù)據(jù)集是英語與中文的翻譯對的集合。Tatoeba的數(shù)據(jù)源網(wǎng)頁提供了下載,該文件以制表符\t分隔的翻譯對列表,數(shù)據(jù)樣本如下:Hello!你好。CC-BY2.0(France)Attribution:#373330(CK)�(musclegirlxyp)我們需要對每個單詞構(gòu)建唯一索引,以便稍后用作網(wǎng)絡(luò)的輸入和目標(biāo),下面我們看一個輔助類Lang實(shí)踐1準(zhǔn)備數(shù)據(jù)實(shí)踐任務(wù)from__future__importunicode_literals,print_function,divisionfromioimportopenimportunicodedataimportstringimportre#正則表達(dá)式匹配importrandom#隨機(jī)抽樣importnumpyasnpimporttorchimporttorch.nnasnnfromtorchimportoptimimporttorch.nn.functionalasFimportmatplotlibaspltfrommatplotlibimportticker實(shí)踐任務(wù)device=torch.device("cuda"iftorch.cuda.is_available()else"cpu")SOS_token=0EOS_token=1#以上兩個分別代表一個序列的開始和結(jié)束實(shí)踐任務(wù)classLang:def__init__(self,name):=nameself.word2index={}self.word2count={}self.index2word={0:"SOS",1:"EOS"}self.n_words=2下面定義Lang數(shù)據(jù)類,需要設(shè)置3個方法。__init__:初始化addSentence:加入一句話addWord:加入一個詞語實(shí)踐任務(wù)defaddSentence(self,sentence):forwordinsentence.split(""): self.addWord(word)defaddWord(self,word):ifwordnotinself.word2index:self.word2index[word]=self.n_wordsself.word2count[word]=1self.index2word[self.n_words]=wordself.n_words+=1else:self.word2count[word]+=1實(shí)踐任務(wù)defunicodeToAscii(s): return''.join(cforcinunicodedata.normalize('NFD',s)ifunicodedata.category(c)!='Mn')defnormalizeString(s):s=unicodeToAscii(s.lower().strip())s=re.sub(r"([.!?])",r"\1",s)returns另外定義三個函數(shù),分別是:unicodeToAscii:文本編碼格式轉(zhuǎn)換normalizeString:文本標(biāo)準(zhǔn)化readLangs:讀取文本實(shí)踐任務(wù)defreadLangs(lang1,lang2,reverse=False):#讀取文本函數(shù)lines=open(r"data/cmn.txt",encoding="utf-8").read().strip().split("\n")pairs=[[normalizeString(s)forsinl.split("\t")]forlinlines]pairs=np.delete(pairs,2,axis=1)ifreverse:pairs=[list(reversed(p))forpinpairs]input_lang=Lang(lang2)output_lang=Lang(lang1)else:input_lang=Lang(lang1)output_lang=Lang(lang2) returninput_lang,output_lang,pairs實(shí)踐任務(wù)#測試lang1="cmn"lang2="eng"input_lang,output_lang,pairs=readLangs(lang1,lang2)print("pairs中的前五個:",pairs[:5])#運(yùn)行結(jié)果=》pairs中的前五個:[['hi.''嗨。']['hi.''你好。']['run.''你用跑的。']['stop!''住手!']['wait!''等等!']]第七節(jié)

實(shí)踐任務(wù)RNN是一個對序列進(jìn)行操作的網(wǎng)絡(luò),它使用自己的輸出作為后續(xù)步驟的輸入。我們對Seq2Seq網(wǎng)絡(luò)及Attention機(jī)制有一定的了解,也就是利用Encoder和Decoder兩個RNN網(wǎng)絡(luò)一同構(gòu)成的模型進(jìn)行編碼和解碼。編碼器用來讀取輸入序列并輸出上下文向量。而解碼器讀取該向量并產(chǎn)生輸出序列。與使用單個RNN的序列預(yù)測不同,其中每個輸入對應(yīng)于輸出Seq2Seq模型使我們從序列長度和順序中解放出來,這使得其成為兩種語言之間轉(zhuǎn)換的理想選。實(shí)踐2構(gòu)建模型實(shí)踐任務(wù)編碼器的RNN將句子中輸入的每個單詞編碼或隱藏狀態(tài),并逐步循環(huán)傳遞,得到最后一層的隱藏狀態(tài),即上下文向量c。在此我們利用的模型是GRU,GRU模型類似于LSTM,GRU可以說是LSTM的一種優(yōu)化或者變體,LSTM存在三個門,而GRU只有兩個門,參數(shù)更少,更容易收斂。GRU在本節(jié)中的結(jié)構(gòu)如圖所示。編碼器實(shí)踐任務(wù)classEncoderRNN(nn.Module):def__init__(self,input_size,hidden_size):super(EncoderRNN,self).__init__()#初始化必須的變量self.hidden_size=hidden_sizeself.embedding=nn.Embedding(input_size,hidden_size)self.gru=nn.GRU(hidden_size,hidden_size)defforward(self,input,hidden):embedded=self.embedding(input).view(1,1,-1)output=embeddedoutput,hidden=self.gru(output,hidden)returnoutput,hiddendefinitHidden(self):returntorch.zeros(1,1,self.hidden_size,device=device)實(shí)踐任務(wù)解碼器的主要模塊是Decoder模塊,它接手編碼器輸出的上下文向量,并得到一系列的隱藏狀態(tài)用于轉(zhuǎn)換。在解碼的每個步驟中,對Decoder輸入token和hidden,初始token是<SOS>標(biāo)志位,初始hidden是上下文向量。網(wǎng)絡(luò)結(jié)構(gòu)如圖所示。解碼器實(shí)踐任務(wù)classDecoderRNN(nn.Module):def__init__(self,hidden_size,output_size):super(DecoderRNN,self).__init__()self.hidden_size=hidden_sizeself.embedding=nn.Embedding(output_size,hidden_size)self.gru=nn.GRU(hidden_size,hidden_size)self.out=nn.Linear(hidden_size,output_size)self.softmax=nn.LogSoftmax(dim=1)defforward(self,input,hidden):output=self.embedding(input).view(1,1,-1)output=F.relu(output)output,hidden=self.gru(output,hidden)output=self.softmax(self.out(output[0]))returnoutput,hiddendefinitHidden(self):returntorch.zeros(1,1,self.hidden_size,device=device)實(shí)踐任務(wù)如果僅在編碼器和解碼器之間傳遞上下文向量,這單個向量將承擔(dān)整個句子的所有信息,而注意力機(jī)制允許Decoder去針對Encoder的每步輸出進(jìn)行聚焦,最后在形成有針對性的上下文向量,也就是上下文向量會隨著每一步而變換。實(shí)現(xiàn)代碼如下:注意力機(jī)制解碼器實(shí)踐任務(wù)classAttnDecoderRNN(nn.Module): def__init__(self,hidden_size,output_size,dropout_p=0.1,max_length=MAX_LENGTH):super(AttnDecoderRNN,self).__init__()self.hidden_size=hidden_sizeself.output_size=output_sizeself.dropout_p=dropout_pself.max_length=max_lengthself.embedding=nn.Embedding(self.output_size,self.hidden_size)self.attn=nn.Linear(self.hidden_size*2,self.max_length)self.attn_combine=nn.Linear(self.hidden_size*2,self.hidden_size)self.dropout=nn.Dropout(self.dropout_p)self.gru=nn.GRU(self.hidden_size,self.hidden_size)self.out=nn.Linear(self.hidden_size,self.output_size)實(shí)踐任務(wù)defforward(self,input,hidden,encoder_outputs):embedded=self.embedding(input).view(1,1,-1)embedded=self.dropout(embedded)attn_weights=F.softmax(self.attn(torch.cat((embedded[0],hidden[0]),1)),dim=1)attn_applied=torch.bmm(attn_weights.unsqueeze(0),encoder_outputs.unsqueeze(0))output=torch.cat((embedded[0],attn_applied[0]),1)output=self.attn_combine(output).unsqueeze(0)output=F.relu(output)output,hidden=self.gru(output,hidden)output=F.log_softmax(self.out(output[0]),dim=1)returnoutput,hidden,attn_weightsdefinitHidden(self):returntorch.zeros(1,1,self.hidden_size,device=device)實(shí)踐任務(wù)1.準(zhǔn)備訓(xùn)練數(shù)據(jù)我們將數(shù)據(jù)進(jìn)行導(dǎo)入修剪,最后存儲在pairs里,存放著中英文語句對中,為了將數(shù)據(jù)輸入神經(jīng)網(wǎng)絡(luò)中,需要將其轉(zhuǎn)換為張量tensor,包括輸入tensor和目標(biāo)tensor。另外,在創(chuàng)建時,會將EOS標(biāo)志添加到兩個序列中。代碼如下:實(shí)踐3訓(xùn)練模型實(shí)踐任務(wù)defindexesFromSentence(lang,sentence):return[lang.word2index[word]forwordinsentence.split('')]deftensorFromSentence(lang,sentence):indexes=indexesFromSentence(lang,sentence)indexes.append(EOS_token)returntorch.tensor(indexes,dtype=torch.long,device=device).view(-1,1)deftensorsFromPair(pair):input_tensor=tensorFromSentence(input_lang,pair[0])target_tensor=tensorFromSentence(output_lang,pair[1])return(input_tensor,target_tensor)實(shí)踐任務(wù)2.訓(xùn)練技巧TeacherForcingEncoder編碼完成后,開始訓(xùn)練Decoder,那么Decoder每一步的輸入應(yīng)該是什么呢?這里有兩種方法,一是用當(dāng)前步的輸出,也就是模型學(xué)習(xí)后的預(yù)測結(jié)果當(dāng)作下一步的輸入;二是將真實(shí)值用作下一步的輸入,這種概念就是TeacherForcing。這兩種方法都可以采取,但是第二種方法能夠使模型更快收斂,當(dāng)使用受過訓(xùn)練的網(wǎng)絡(luò)時,容易出現(xiàn)不穩(wěn)定性,沒有較好的泛化能力。設(shè)置一個閾值,如果一個隨機(jī)數(shù)比它小,則使用TeacharForcing,代碼如下:teacher_forcing_ratio=0.5實(shí)踐任務(wù)3.訓(xùn)練模型為了讓函數(shù)使用方便,我們將訓(xùn)練包裝成一個train函數(shù),后面迭代時逐一將向量傳進(jìn)去。代碼如下:實(shí)踐任務(wù)deftrain(input_tensor,target_tensor,encoder,decoder,encoder_optimizer,decoder_optimizer,criterion,max_length=MAX_LENGTH):#初始化隱藏狀態(tài)encoder_hidden=encoder.initHidden()#梯度清零encoder_optimizer.zero_grad()decoder_optimizer.zero_grad()input_length=input_tensor.size(0)target_length=target_tensor.size(0)#初始化,等會替換encoder_outputs=torch.zeros(max_length,encoder.hidden_size,device=device)loss=0實(shí)踐任務(wù)foreiinrange(input_length):encoder_output,encoder_hidden=encoder(input_tensor[ei],encoder_hidden)#encoder_output.size()==>tensor([1,1,hidden_size])encoder_outputs[ei]=encoder_output[0,0]#輸入為<sos>,decoder初始隱藏狀態(tài)為encoder的decoder_input=torch.tensor([[SOS_token]],device=device)decoder_hidden=encoder_hidden#隨機(jī)決定是否采用teacher_forcinguse_teacher_forcing=Trueifrandom.random()<teacher_forcing_ratioelse

False實(shí)踐任務(wù)ifuse_teacher_forcing:fordiinrange(target_length):decoder_output,decoder_hidden,decoder_attention=decoder(decoder_input,decoder_hidden,encoder_outputs)loss+=criterion(decoder_output,target_tensor[di])else:

#若不用,則用預(yù)測出的結(jié)果作為Decoder下一個輸入fordiinrange(target_length):decoder_output,decoder_hidden,decoder_attention=decoder(decoder_input,decoder_hidden,encoder_outputs)topv,topi=decoder_output.topk(1)decoder_input=topi.squeeze().detach()loss+=criterion(decoder_output,target_tensor[di])ifdecoder_input.item()==EOS_token:break實(shí)踐任務(wù)loss.backward()#參數(shù)更新encoder_optimizer.step()decoder_optimizer.step()#返回平均lossreturnloss.item()/target_length實(shí)踐任務(wù)接下來開始訓(xùn)練模型,為此定義了輔助函數(shù)timeSince用來計(jì)時,整個訓(xùn)練過程如下:

(1)啟動計(jì)時器; (2)初始化優(yōu)化器和損失函數(shù); (3)創(chuàng)建一組訓(xùn)練對; (4)代入train函數(shù)進(jìn)行迭代訓(xùn)練。實(shí)踐任務(wù)importtimeimportmathdefasMinutes(s):m=math.floor(s/60)s-=m*60return'%dm%ds'%(m,s)deftimeSince(since,percent):now=time.time()s=now-sincees=s/(percent)rs=es-sreturn'%s(-%s)'%(asMinutes(s),asMinutes(rs))實(shí)踐任務(wù)deftrainIters(encoder,decoder,n_iters,print_every=1000,plot_every=100,learning_rate=0.01):start=time.time()plot_losses=[]#每一次重置print_loss_total=0plot_loss_total=0#定義優(yōu)化器encoder_optimizer=optim.SGD(encoder.parameters(),lr=learning_rate)

decoder_optimizer=optim.SGD(decoder.parameters(),lr=learning_rate)#random.choice(pairs)隨機(jī)選擇training_pairs=[tensorsFromPair(random.choice(pairs))foriinrange(n_iters)]criterion=nn.NLLLoss()實(shí)踐任務(wù)foriterinrange(1,n_iters+1):training_pair=training_pairs[iter-1]input_tensor=training_pair[0]target_tensor=training_pair[1]loss=train(input_tensor,target_tensor,encoder,decoder,encoder_optimizer,decoder_optimizer,criterion)print_loss_total+=lossplot_loss_total+=lossifiter%print_every==0:print_loss_avg=print_loss_total/print_everyprint_loss_total=0print('%s(%d%d%%)%.4f'%(timeSince(start,iter/n_iters),iter,iter/n_iters*100,print_loss_avg))實(shí)踐任務(wù)defshowPlot(points):plt.figure()fig,ax=plt.subplots()#thislocatorputsticksatregularintervalsloc=ticker.MultipleLocator(base=0.2)ax.yaxis.set_major_locator(loc)plt.plot(points)hidden_size=256encoder1=EncoderRNN(input_lang.n_words,hidden_size).to(device)attn_decoder1=AttnDecoderRNN(hidden_size,output_lang.n_words,dropout_p=0.1).to(device)trainIters(encoder1,attn_decoder1,75000,print_every=5000)實(shí)踐任務(wù)運(yùn)行后,輸出結(jié)果如下:0m28s(-6m41s)(50006%)3.10630m55s(-6m3s)(1000013%)2.07651m23s(-5m35s)(1500020%)0.67331m51s(-5m6s)(2000026%)0.16932m18s(-4m37s)(2500033%)0.08442m45s(-4m7s)(3000040%)0.06673m11s(-3m38s)(3500046%)0.05493m37s(-3m

溫馨提示

  • 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
  • 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
  • 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
  • 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
  • 5. 人人文庫網(wǎng)僅提供信息存儲空間,僅對用戶上傳內(nèi)容的表現(xiàn)方式做保護(hù)處理,對用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對任何下載內(nèi)容負(fù)責(zé)。
  • 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請與我們聯(lián)系,我們立即糾正。
  • 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時也不承擔(dān)用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。

評論

0/150

提交評論