深度學(xué)習(xí)案例教程 課件3.4.2全連接神經(jīng)網(wǎng)絡(luò)搭建_第1頁(yè)
深度學(xué)習(xí)案例教程 課件3.4.2全連接神經(jīng)網(wǎng)絡(luò)搭建_第2頁(yè)
深度學(xué)習(xí)案例教程 課件3.4.2全連接神經(jīng)網(wǎng)絡(luò)搭建_第3頁(yè)
深度學(xué)習(xí)案例教程 課件3.4.2全連接神經(jīng)網(wǎng)絡(luò)搭建_第4頁(yè)
深度學(xué)習(xí)案例教程 課件3.4.2全連接神經(jīng)網(wǎng)絡(luò)搭建_第5頁(yè)
已閱讀5頁(yè),還剩7頁(yè)未讀 繼續(xù)免費(fèi)閱讀

下載本文檔

版權(quán)說(shuō)明:本文檔由用戶(hù)提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請(qǐng)進(jìn)行舉報(bào)或認(rèn)領(lǐng)

文檔簡(jiǎn)介

實(shí)踐任務(wù)2-全連接神經(jīng)網(wǎng)絡(luò)模型搭建1定義全連接神經(jīng)網(wǎng)絡(luò)

在PyTorch中,

torch.nn是專(zhuān)門(mén)為神經(jīng)網(wǎng)絡(luò)設(shè)計(jì)的模塊化接口,,可以用于定義和運(yùn)行神經(jīng)網(wǎng)絡(luò)。

nn.Module是nn庫(kù)中十分重要的類(lèi),它包含網(wǎng)絡(luò)各層的定義以及forward函數(shù)。

只要在nn.Module的子類(lèi)中定義forward函數(shù),backward函數(shù)就會(huì)被自動(dòng)實(shí)現(xiàn)(利用autograd)。

實(shí)踐任務(wù)2-全連接神經(jīng)網(wǎng)絡(luò)模型搭建1定義全連接神經(jīng)網(wǎng)絡(luò)

importtorch.nnasnn#導(dǎo)入nn庫(kù)

classNeuralNet(nn.Module):

def__init__(self,input_num,hidden_num,output_num):

super(NeuralNet,self).__init__()

self.fc1=nn.Linear(input_num,hidden_num)

self.fc2=nn.Linear(hidden_num,output_num)

self.relu=nn.ReLU()

defforward(self,x):

x=self.fc1(x)

x=self.relu(x)

y=self.fc2(x)

returny

實(shí)踐任務(wù)2-全連接神經(jīng)網(wǎng)絡(luò)模型搭建1定義全連接神經(jīng)網(wǎng)絡(luò)

#設(shè)置參數(shù)

epoches=20

lr=0.001

input_num=784

hidden_num=500

output_num=10

device=torch.device("cuda"iftorch.cuda.is_available()else"cpu")

#創(chuàng)建網(wǎng)絡(luò)模型

model=NeuralNet(input_num,hidden_num,output_num)

print(model)實(shí)踐任務(wù)2-全連接神經(jīng)網(wǎng)絡(luò)模型搭建1定義全連接神經(jīng)網(wǎng)絡(luò)

運(yùn)行后輸出如下:NeuralNet((fc1):Linear(in_features=784,out_features=500,bias=True)(fc2):Linear(in_features=500,out_features=10,bias=True)(relu):ReLU())實(shí)踐任務(wù)2-全連接神經(jīng)網(wǎng)絡(luò)模型搭建2前向傳播

定義好網(wǎng)絡(luò)模型后,我們會(huì)將所有的數(shù)據(jù)按照batch的方式進(jìn)行輸入,得到對(duì)應(yīng)的網(wǎng)絡(luò)輸出,這就是所有的前向傳播。

實(shí)踐任務(wù)2-全連接神經(jīng)網(wǎng)絡(luò)模型搭建2前向傳播

#前向傳播

images=images.reshape(-1,28*28)

image=images[:2]

label=labels[:2]

print(image.size())

print(label)

out=model(image)

print(out)實(shí)踐任務(wù)2-全連接神經(jīng)網(wǎng)絡(luò)模型搭建2前向傳播

運(yùn)行程序輸出的結(jié)果如下:torch.Size([2,784])tensor([0,6])tensor([[0.1336,0.2989,0.1140,-0.0331,0.1986,-0.1656,-0.1346,0.1204,-0.3536,0.2364],[0.3198,0.3422,-0.2137,0.2526,0.3694,-0.0444,-0.1710,-0.0321,0.1679,-0.2004]],grad_fn=<AddmmBackward0>)實(shí)踐任務(wù)2-全連接神經(jīng)網(wǎng)絡(luò)模型搭建

3計(jì)算損失

損失函數(shù)需要一對(duì)輸入:模型輸出和目標(biāo),用來(lái)評(píng)估輸出值和目標(biāo)值之間的差距,損失函數(shù)用loss表示,損失函數(shù)的作用就是計(jì)算神經(jīng)網(wǎng)絡(luò)每次迭代的前向計(jì)算結(jié)果和真實(shí)值之間的差距,從而指導(dǎo)模型下一步訓(xùn)練往正確的方向進(jìn)行。常見(jiàn)的損失函數(shù)有交叉熵?fù)p失函數(shù)和均方誤差損失函數(shù)。實(shí)踐任務(wù)2-全連接神經(jīng)網(wǎng)絡(luò)模型搭建3計(jì)算損失在PyTorch中,nn庫(kù)模塊提供了多種損失函數(shù),常用的有以下幾種:1.處理回歸問(wèn)題的nn.MSELoss函數(shù),2.處理分類(lèi)問(wèn)題的nn.BCELoss函數(shù),3.處理多分類(lèi)問(wèn)題的nn.CrossEntropyLoss函數(shù)。實(shí)踐任務(wù)2-全連接神經(jīng)網(wǎng)絡(luò)模型搭建3計(jì)算損失#定義損失函數(shù)

criterion=nn.CrossEntropyLoss()

loss=criterion(out,label)

print(loss)實(shí)踐任務(wù)2-全連接神經(jīng)網(wǎng)絡(luò)模型搭建

4反向傳播與參數(shù)更新

當(dāng)計(jì)算出一次前向傳播的loss值之后,可進(jìn)行反向傳播計(jì)算梯度,以此來(lái)更新參數(shù)。在PyTorch中,對(duì)loss調(diào)用backward函數(shù)即可。backward函數(shù)屬于torch.autograd函數(shù)庫(kù),在深度學(xué)習(xí)過(guò)程中進(jìn)行反向傳播,計(jì)算輸出變量關(guān)于輸入變量的梯度。最后要做的事情就是更新神經(jīng)網(wǎng)絡(luò)的參數(shù),最簡(jiǎn)單的規(guī)則就是隨機(jī)梯度下降,公式如下:weight=weight-learningrate×gradient當(dāng)然,還有很多不同的更新規(guī)則,類(lèi)似于SGD、Adam、RMSProp等,為了讓這些可行,PyTorch建立了一個(gè)torch.optim包,調(diào)用它可以實(shí)現(xiàn)上述任意一種優(yōu)化器。實(shí)踐任務(wù)2-全連接神經(jīng)網(wǎng)絡(luò)模型搭建

4反向傳播與參數(shù)更新#創(chuàng)建優(yōu)化器

importtorch.optimasoptim

optimizer=optim.SGD(model.parameters(),lr=0.01)#lr代表學(xué)習(xí)率

criterion=nn.CrossEntropyLoss()

#在訓(xùn)練過(guò)程中

image=images[:2]

l

溫馨提示

  • 1. 本站所有資源如無(wú)特殊說(shuō)明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請(qǐng)下載最新的WinRAR軟件解壓。
  • 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請(qǐng)聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶(hù)所有。
  • 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁(yè)內(nèi)容里面會(huì)有圖紙預(yù)覽,若沒(méi)有圖紙預(yù)覽就沒(méi)有圖紙。
  • 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
  • 5. 人人文庫(kù)網(wǎng)僅提供信息存儲(chǔ)空間,僅對(duì)用戶(hù)上傳內(nèi)容的表現(xiàn)方式做保護(hù)處理,對(duì)用戶(hù)上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對(duì)任何下載內(nèi)容負(fù)責(zé)。
  • 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請(qǐng)與我們聯(lián)系,我們立即糾正。
  • 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時(shí)也不承擔(dān)用戶(hù)因使用這些下載資源對(duì)自己和他人造成任何形式的傷害或損失。

評(píng)論

0/150

提交評(píng)論