PyTorch - Tekrarlayan Sinir Ağı

Tekrarlayan sinir ağları, sıralı bir yaklaşımı izleyen bir tür derin öğrenme odaklı algoritmadır. Sinir ağlarında, her zaman her bir girişin ve çıkışın diğer tüm katmanlardan bağımsız olduğunu varsayıyoruz. Bu tür sinir ağlarına yinelenen adı verilir, çünkü matematiksel hesaplamaları birbiri ardına tamamlayarak sıralı bir şekilde gerçekleştirirler.

Aşağıdaki diyagram, tekrarlayan sinir ağlarının tam yaklaşımını ve çalışmasını göstermektedir -

Yukarıdaki şekilde, c1, c2, c3 ve x1, o1'in ilgili çıkışını sağlayan h1, h2 ve h3 gibi bazı gizli giriş değerlerini içeren girişler olarak kabul edilir. Şimdi tekrarlayan sinir ağlarının yardımıyla bir sinüs dalgası oluşturmak için PyTorch'u uygulamaya odaklanacağız.

Eğitim sırasında, modelimize bir seferde bir veri noktası ile bir eğitim yaklaşımı izleyeceğiz. X giriş dizisi 20 veri noktasından oluşur ve hedef sıranın giriş dizisi ile aynı olduğu kabul edilir.

Aşama 1

Aşağıdaki kodu kullanarak tekrarlayan sinir ağlarını uygulamak için gerekli paketleri içe aktarın -

import torch
from torch.autograd import Variable
import numpy as np
import pylab as pl
import torch.nn.init as init

Adım 2

Model hiper parametrelerini, giriş katmanı boyutu 7 olarak ayarlayacağız. Hedef sekansı oluşturmak için 6 bağlam nöronu ve 1 giriş nöronu olacak.

dtype = torch.FloatTensor
input_size, hidden_size, output_size = 7, 6, 1
epochs = 300
seq_length = 20
lr = 0.1
data_time_steps = np.linspace(2, 10, seq_length + 1)
data = np.sin(data_time_steps)
data.resize((seq_length + 1, 1))

x = Variable(torch.Tensor(data[:-1]).type(dtype), requires_grad=False)
y = Variable(torch.Tensor(data[1:]).type(dtype), requires_grad=False)

Eğitim verilerini oluşturacağız, burada x girdi veri dizisi ve y gerekli hedef sıra.

Aşama 3

Ağırlıklar, sıfır ortalama ile normal dağılım kullanılarak tekrarlayan sinir ağında başlatılır. W1, giriş değişkenlerinin kabulünü temsil edecek ve w2, aşağıda gösterildiği gibi üretilen çıktıyı temsil edecektir -

w1 = torch.FloatTensor(input_size, 
hidden_size).type(dtype)
init.normal(w1, 0.0, 0.4)
w1 = Variable(w1, requires_grad = True)
w2 = torch.FloatTensor(hidden_size, output_size).type(dtype)
init.normal(w2, 0.0, 0.3)
w2 = Variable(w2, requires_grad = True)

4. adım

Şimdi, sinir ağını benzersiz bir şekilde tanımlayan ileri besleme için bir işlev oluşturmak önemlidir.

def forward(input, context_state, w1, w2):
   xh = torch.cat((input, context_state), 1)
   context_state = torch.tanh(xh.mm(w1))
   out = context_state.mm(w2)
   return (out, context_state)

Adım 5

Bir sonraki adım, tekrarlayan sinir ağının sinüs dalgası uygulamasının eğitim prosedürüne başlamaktır. Dış döngü, her döngü üzerinde yinelenir ve iç döngü, dizi öğesi boyunca yinelenir. Burada, sürekli değişkenlerin tahminine yardımcı olan Ortalama Kare Hatasını (MSE) de hesaplayacağız.

for i in range(epochs):
   total_loss = 0
   context_state = Variable(torch.zeros((1, hidden_size)).type(dtype), requires_grad = True)
   for j in range(x.size(0)):
      input = x[j:(j+1)]
      target = y[j:(j+1)]
      (pred, context_state) = forward(input, context_state, w1, w2)
      loss = (pred - target).pow(2).sum()/2
      total_loss += loss
      loss.backward()
      w1.data -= lr * w1.grad.data
      w2.data -= lr * w2.grad.data
      w1.grad.data.zero_()
      w2.grad.data.zero_()
      context_state = Variable(context_state.data)
   if i % 10 == 0:
      print("Epoch: {} loss {}".format(i, total_loss.data[0]))

context_state = Variable(torch.zeros((1, hidden_size)).type(dtype), requires_grad = False)
predictions = []

for i in range(x.size(0)):
   input = x[i:i+1]
   (pred, context_state) = forward(input, context_state, w1, w2)
   context_state = context_state
   predictions.append(pred.data.numpy().ravel()[0])

6. Adım

Şimdi, sinüs dalgasını ihtiyaç duyulan şekilde çizme zamanı.

pl.scatter(data_time_steps[:-1], x.data.numpy(), s = 90, label = "Actual")
pl.scatter(data_time_steps[1:], predictions, label = "Predicted")
pl.legend()
pl.show()

Çıktı

Yukarıdaki işlemin çıktısı aşağıdaki gibidir -