版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請(qǐng)進(jìn)行舉報(bào)或認(rèn)領(lǐng)
文檔簡介
第10章航班乘客數(shù)預(yù)測第10章航班乘客數(shù)預(yù)測10.1PyTorch簡介10.2安裝PyTorch10.3導(dǎo)入相關(guān)庫10.4PyTorch基礎(chǔ)知識(shí)10.5讀取數(shù)據(jù)10.6數(shù)據(jù)預(yù)處理10.7定義網(wǎng)絡(luò)模型10.8定義損失函數(shù)和優(yōu)化器10.9訓(xùn)練模型10.10測試模型第10章航班乘客數(shù)預(yù)測10.1PyTorch簡介PyTorch是由Facebook開發(fā),基于Torch開發(fā),從并不常用的Lua語言轉(zhuǎn)為Python語言開發(fā)的深度學(xué)習(xí)框架,可以用于構(gòu)建深度神經(jīng)網(wǎng)絡(luò)。Pytorch是一個(gè)基于Python的科學(xué)計(jì)算庫,它面向以下兩種人群:希望將其代替Numpy來利用GPUs的威力;一個(gè)可以提供更加靈活和快速的深度學(xué)習(xí)研究平臺(tái)。第10章航班乘客數(shù)預(yù)測10.2安裝PyTorchPyTorch的安裝可以直接查看官網(wǎng)教程,如下所示,官網(wǎng)地址:/get-started/locally/第10章航班乘客數(shù)預(yù)測10.2安裝PyTorch第10章航班乘客數(shù)預(yù)測10.3導(dǎo)入相關(guān)庫import
torchimport
torch.nn
as
nn
import
numpy
as
npimport
matplotlib.pyplot
as
pltplt.rcParams['font.sans-serif']=['simsun']
#設(shè)置加載的字體名plt.rcParams['axes.unicode_minus']=False
#解決保存圖像是負(fù)號(hào)'-'顯示為方塊的問題第10章航班乘客數(shù)預(yù)測10.4PyTorch基礎(chǔ)知識(shí)10.4.1張量(1)創(chuàng)建一個(gè)張量x=torch.Tensor([1,2,3])
#創(chuàng)建一個(gè)1維張量y=torch.Tensor([[1,2],[3,4]])
#創(chuàng)建一個(gè)2維張量z=torch.Tensor([[[1,2],[3,4]],[[5,6],[7,8]]])
#創(chuàng)建一個(gè)3維張量xyz輸出結(jié)果:第10章航班乘客數(shù)預(yù)測10.4PyTorch基礎(chǔ)知識(shí)10.4.1張量(2)張量的形狀z.shape
#獲取張量的形狀z.size()
#獲取張量的形狀z.view(2,4)
#改變張量的形狀,2行,4列z.reshape(1,8)
#改變張量的形狀,1行,8列z.resize_(2,4)
#直接修改原始張量的形狀,2行,4列輸出結(jié)果:第10章航班乘客數(shù)預(yù)測10.4PyTorch基礎(chǔ)知識(shí)10.4.2自動(dòng)微分PyTorch提供了自動(dòng)微分功能,可以自動(dòng)計(jì)算梯度,這使得模型訓(xùn)練更加容易。我們使用torch.tensor()來定義張量,然后使用.backward()函數(shù)計(jì)算梯度。x=torch.tensor(2.0,requires_grad=True)#定義張量x,并將requires_grad設(shè)置為True,以便PyTorch跟蹤它的計(jì)算歷史y=x**2
#定義新的張量y,它是x的平方y(tǒng).backward()#調(diào)用y.backward()來計(jì)算y相對(duì)于x的導(dǎo)數(shù)x.grad
#打印出結(jié)果為tensor(4.)第10章航班乘客數(shù)預(yù)測10.4PyTorch基礎(chǔ)知識(shí)10.4.3神經(jīng)網(wǎng)絡(luò)PyTorch提供了torch.nn模塊,可以幫助開發(fā)者更輕松地構(gòu)建和訓(xùn)練神經(jīng)網(wǎng)絡(luò)模型??梢允褂胻orch.nn.Module()類定義神經(jīng)網(wǎng)絡(luò)模型,然后使用torch.optim優(yōu)化器進(jìn)行訓(xùn)練。10.4.4數(shù)據(jù)加載PyTorch提供了torch.utils.data模塊,可以幫助開發(fā)者更輕松地加載和處理數(shù)據(jù)。可以使用torch.utils.data.Dataset()類定義數(shù)據(jù)集,然后使用torch.utils.data,DataLoader()函數(shù)加載數(shù)據(jù)。10.4.5GPU加速PyTorch可以使用GPU加速,可以使用torch.cuda模塊將張量和模型移動(dòng)到GPU上運(yùn)行。第10章航班乘客數(shù)預(yù)測10.5讀取數(shù)據(jù)with
open("data\international-airline-passengers.csv","r",encoding="utf-8")asf:
next(f)#跳過第1行
data_csv=f.read()
#將文件內(nèi)容讀取到變量data中data=[row.split(',')forrowin
data_csv.split("\n")]#將字符串變量data_csv中的每一行按逗號(hào)分隔并返回一個(gè)列表。這個(gè)列表包含了每一行的元素。passengers=[int(each[1])foreachindata]#將列表變量data中的每個(gè)元素的第二個(gè)字符轉(zhuǎn)換為整數(shù)并返回一個(gè)新的列表。Passengers#打印前10個(gè)月中每月的航班乘客數(shù)輸出結(jié)果:第10章航班乘客數(shù)預(yù)測10.6數(shù)據(jù)預(yù)處理接下來,我們首先使用滑動(dòng)窗口方法創(chuàng)建基于航班乘客數(shù)的時(shí)間序列數(shù)據(jù)。然后,將序列數(shù)據(jù)轉(zhuǎn)換成滿足模型輸入要求的訓(xùn)練數(shù)據(jù)集和測試數(shù)據(jù)集。這樣,我們就可以使用前2天的航班乘客數(shù)來預(yù)測第3天的航班乘客數(shù)。第10章航班乘客數(shù)預(yù)測10.7定義網(wǎng)絡(luò)模型class
Net(nn.Module):
#初始化函數(shù),定義網(wǎng)絡(luò)結(jié)構(gòu)
def
__init__(self):
#調(diào)用父類的初始化函數(shù)
super(Net,self).__init__()
#定義一個(gè)LSTM層,輸入特征為1(只有乘客數(shù)),隱藏狀態(tài)大小為32,層數(shù)為1,batch_first為True
self.lstm=nn.LSTM(input_size=1,hidden_size=32,num_layers=1,batch_first=True)
#定義一個(gè)線性層,將32*seq_len個(gè)輸入特征映射到1個(gè)輸出特征(預(yù)測下一月乘客數(shù))
self.linear=nn.Linear(32*seq_len,1)
#前向傳播函數(shù)
def
forward(self,input):
#將輸入input輸入到LSTM層中,得到輸出結(jié)果lstm_out,隱藏狀態(tài)h和單元狀態(tài)c
lstm_out,(h,c)=self.lstm(input)
#將lstm_out進(jìn)行reshape,變成一個(gè)形狀為(-1,32*seq_len)的張量
x=lstm_out.reshape(-1,32*seq_len)
#將x輸入到線性層中,得到輸出pred
pred=self.linear(x)
#返回輸出pred
return
pred第10章航班乘客數(shù)預(yù)測10.8定義損失函數(shù)和優(yōu)化器model=Net()#定義一個(gè)Adam優(yōu)化器,用于更新模型參數(shù),學(xué)習(xí)率為0.003optimizer=torch.optim.Adam(model.parameters(),lr=0.003)#定義一個(gè)均方誤差損失函數(shù),用于計(jì)算模型預(yù)測值與真實(shí)值之間的誤差loss_fun=nn.MSELoss()第10章航班乘客數(shù)預(yù)測10.9訓(xùn)練模型#將模型設(shè)置為訓(xùn)練模式model.train()#進(jìn)行300輪訓(xùn)練for
epoch
in
range(300):
#將訓(xùn)練數(shù)據(jù)train_x輸入到模型中,得到模型的輸出output
output=model(train_x)
#計(jì)算模型輸出output與訓(xùn)練標(biāo)簽train_y之間的均方誤差損失
loss=loss_fun(output,train_y)
#將優(yōu)化器的梯度清零
optimizer.zero_grad()
#反向傳播計(jì)算梯度
loss.backward()
#使用優(yōu)化器更新模型參數(shù)
optimizer.step()
#每20輪輸出一次訓(xùn)練損失和測試損失
if
epoch%20==0
and
epoch>0:
#將測試數(shù)據(jù)test_x輸入到模型中,得到模型的輸出output
#計(jì)算模型輸出output與測試標(biāo)簽test_y之間的均方誤差損失
test_loss=loss_fun(model(test_x),test_y)
#輸出當(dāng)前輪數(shù)、訓(xùn)練損失和測試損失
print("epoch:{},loss:{},test_loss:{}".format(epoch,loss,test_loss))第10章航班乘客數(shù)預(yù)測10.9訓(xùn)練模型第10章航班乘客數(shù)預(yù)測10.10測試模型第10章航班乘客數(shù)預(yù)測10.10測試模型#將模型設(shè)置為評(píng)估模式model.eval()#構(gòu)造預(yù)測結(jié)果result=X[0][:seq_len-1]
溫馨提示
- 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請(qǐng)下載最新的WinRAR軟件解壓。
- 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請(qǐng)聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
- 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會(huì)有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
- 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
- 5. 人人文庫網(wǎng)僅提供信息存儲(chǔ)空間,僅對(duì)用戶上傳內(nèi)容的表現(xiàn)方式做保護(hù)處理,對(duì)用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對(duì)任何下載內(nèi)容負(fù)責(zé)。
- 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請(qǐng)與我們聯(lián)系,我們立即糾正。
- 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時(shí)也不承擔(dān)用戶因使用這些下載資源對(duì)自己和他人造成任何形式的傷害或損失。
最新文檔
- 《 以α-羥基酸鈰為前驅(qū)體制備二氧化鈰及其性能的研究》范文
- 《 不同放牧強(qiáng)度對(duì)錫林郭勒草甸草原群落及植物功能性狀的影響》范文
- 通信設(shè)備在電子商務(wù)交易的安全保障考核試卷
- 酒吧服務(wù)飲品市場營銷策略分析考核試卷
- 陶瓷生產(chǎn)設(shè)備維護(hù)與管理考核試卷
- 醫(yī)療設(shè)備租賃行業(yè)發(fā)展趨勢考核試卷
- 地質(zhì)勘探中的安全風(fēng)險(xiǎn)評(píng)估考核試卷
- 舟山市普陀山莊有限公司招聘筆試題庫2024
- 中石油云南石化有限公司招聘筆試題庫2024
- 中石油共享運(yùn)營有限公司招聘筆試題庫2024
- 中國空白地圖大全(可直接打印)(共49頁)
- 麥彭仁波切教言
- 冀教版五年級(jí)英語上冊Unit1單元測試卷(含聽力材料及答案)
- 航海英語大副批注
- 小學(xué)二年級(jí)綜合實(shí)踐活動(dòng)課神奇的風(fēng)PPT精品文檔
- 公司員工內(nèi)部調(diào)查問卷_公司員工調(diào)查問卷
- 中考英語聽說之口頭作文
- 貓的介紹-英文PPT_圖文.ppt
- 消化道出血PPT優(yōu)秀課件
- 授 權(quán) 委 托 書天津北方網(wǎng)——權(quán)威媒體 天津門戶
- 消防中控室操作流程
評(píng)論
0/150
提交評(píng)論