Bài 11: Recurrent Neural Network ( RNN ) - Pytorch Cơ bản

Đăng bởi: Admin | Lượt xem: 3954 | Chuyên mục: AI


Recurrent neural networks ( RNN ) là một loại thuật toán định hướng học sâu theo cách tiếp cận tuần tự. Trong mạng nơ-ron, ta luôn giả định rằng mỗi đầu vào và đầu ra là độc lập với tất cả các lớp khác. Loại mạng nơ-ron này được gọi là mạng lặp lại vì chúng thực hiện các phép tính toán học một cách tuần tự hoàn thành nhiệm vụ này đến tác vụ khác.
Sơ đồ dưới đây chỉ rõ cách tiếp cận hoàn chỉnh và hoạt động của mạng RNN
Trong hình trên, c1, c2, c3 và x1 được coi là đầu vào bao gồm một số giá trị đầu vào ẩn cụ thể là h1, h2 và h3 cung cấp đầu ra tương ứng là o1. Bây giờ ta sẽ tập trung vào việc triển khai PyTorch để tạo ra một sóng sin với sự trợ giúp của các mạng RNN.
Trong quá trình đào tạo, tôi sẽ thực hiện theo cách tiếp cận đào tạo mô hình với một điểm dữ liệu tại một thời điểm. Chuỗi đầu vào x bao gồm 20 điểm dữ liệu và chuỗi đích được coi là giống với chuỗi đầu vào.

Bước 1 :

Import các package cần thiết
import torch
from torch.autograd import Variable
import numpy as np
import pylab as pl
import torch.nn.init as init

Bước 2

Ta sẽ thiết lập các tham số siêu mô hình với kích thước của lớp đầu vào được đặt là 7. Sẽ có 6 neuron context và 1 neuron đầu vào để tạo chuỗi đích.
dtype = torch.FloatTensor
input_size, hidden_size, output_size = 7, 6, 1
epochs = 300
seq_length = 20
lr = 0.1
data_time_steps = np.linspace(2, 10, seq_length + 1)
data = np.sin(data_time_steps)
data.resize((seq_length + 1, 1))

x = Variable(torch.Tensor(data[:-1]).type(dtype), requires_grad=False)
y = Variable(torch.Tensor(data[1:]).type(dtype), requires_grad=False)
Tạo dữ liệu đào tạo, trong đó x là chuỗi dữ liệu đầu vào và y là chuỗi đích bắt buộc.

Bước 3

Trọng số được khởi tạo trong mạng nơ-ron tuần hoàn bằng cách sử dụng phân phối chuẩn với giá trị trung bình bằng không. W1 sẽ đại diện cho việc chấp nhận các biến đầu vào và w2 sẽ đại diện cho đầu ra được tạo như sau:
w1 = torch.FloatTensor(input_size, 
hidden_size).type(dtype)
init.normal(w1, 0.0, 0.4)
w1 = Variable(w1, requires_grad = True)
w2 = torch.FloatTensor(hidden_size, output_size).type(dtype)
init.normal(w2, 0.0, 0.3)
w2 = Variable(w2, requires_grad = True)

Bước 4 :

Bây giờ, điều quan trọng là phải tạo một hàm cho nguồn cấp dữ liệu chuyển tiếp để xác định duy nhất mạng nơ-ron.
def forward(input, context_state, w1, w2):
   xh = torch.cat((input, context_state), 1)
   context_state = torch.tanh(xh.mm(w1))
   out = context_state.mm(w2)
   return (out, context_state)

Bước 5

Bước tiếp theo là bắt đầu đào tạo quy trình triển khai sóng sin của mạng RNN. Vòng lặp bên ngoài lặp qua mỗi vòng lặp và vòng lặp bên trong lặp lại qua phần tử theo trình tự. Ở đây, mình cũng sẽ tính toán Lỗi bình phương trung bình (MSE) giúp dự đoán các biến liên tục.
for i in range(epochs):
   total_loss = 0
   context_state = Variable(torch.zeros((1, hidden_size)).type(dtype), requires_grad = True)
   for j in range(x.size(0)):
      input = x[j:(j+1)]
      target = y[j:(j+1)]
      (pred, context_state) = forward(input, context_state, w1, w2)
      loss = (pred - target).pow(2).sum()/2
      total_loss += loss
      loss.backward()
      w1.data -= lr * w1.grad.data
      w2.data -= lr * w2.grad.data
      w1.grad.data.zero_()
      w2.grad.data.zero_()
      context_state = Variable(context_state.data)
   if i % 10 == 0:
      print("Epoch: {} loss {}".format(i, total_loss.data[0]))

context_state = Variable(torch.zeros((1, hidden_size)).type(dtype), requires_grad = False)
predictions = []

for i in range(x.size(0)):
   input = x[i:i+1]
   (pred, context_state) = forward(input, context_state, w1, w2)
   context_state = context_state
   predictions.append(pred.data.numpy().ravel()[0])

Bước 6

Bây giờ, vẽ biểu đồ sóng sin theo cách sau
pl.scatter(data_time_steps[:-1], x.data.numpy(), s = 90, label = "Actual")
pl.scatter(data_time_steps[1:], predictions, label = "Predicted")
pl.legend()
pl.show()
Kết quả :
Bài tiếp theo: Tập dữ liệu ( Dataset ) >>
vncoder logo

Theo dõi VnCoder trên Facebook, để cập nhật những bài viết, tin tức và khoá học mới nhất!