PyTorch - Jaringan Neural Berulang
Jaringan saraf rekuren adalah salah satu jenis algoritma berorientasi pembelajaran dalam yang mengikuti pendekatan sekuensial. Dalam jaringan neural, kami selalu menganggap bahwa setiap input dan output tidak bergantung pada semua lapisan lainnya. Jenis jaringan saraf ini disebut berulang karena mereka melakukan perhitungan matematika secara berurutan menyelesaikan satu tugas demi tugas.
Diagram di bawah menentukan pendekatan lengkap dan cara kerja jaringan saraf berulang -
Pada gambar di atas, c1, c2, c3 dan x1 dianggap sebagai input yang mencakup beberapa nilai input tersembunyi yaitu h1, h2 dan h3 yang memberikan output masing-masing o1. Kami sekarang akan fokus pada penerapan PyTorch untuk membuat gelombang sinus dengan bantuan jaringan saraf berulang.
Selama pelatihan, kita akan mengikuti pendekatan pelatihan untuk model kita dengan satu titik data pada satu waktu. Urutan masukan x terdiri dari 20 titik data, dan urutan target dianggap sama dengan urutan masukan.
Langkah 1
Impor paket yang diperlukan untuk mengimplementasikan jaringan neural berulang menggunakan kode di bawah ini -
import torch
from torch.autograd import Variable
import numpy as np
import pylab as pl
import torch.nn.init as init
Langkah 2
Kita akan mengatur parameter hyper model dengan ukuran input layer menjadi 7. Akan ada 6 neuron konteks dan 1 neuron input untuk membuat urutan target.
dtype = torch.FloatTensor
input_size, hidden_size, output_size = 7, 6, 1
epochs = 300
seq_length = 20
lr = 0.1
data_time_steps = np.linspace(2, 10, seq_length + 1)
data = np.sin(data_time_steps)
data.resize((seq_length + 1, 1))
x = Variable(torch.Tensor(data[:-1]).type(dtype), requires_grad=False)
y = Variable(torch.Tensor(data[1:]).type(dtype), requires_grad=False)
Kami akan menghasilkan data pelatihan, di mana x adalah urutan data masukan dan y diperlukan urutan target.
LANGKAH 3
Bobot diinisialisasi di jaringan saraf berulang menggunakan distribusi normal dengan rata-rata nol. W1 akan mewakili penerimaan variabel input dan w2 akan mewakili output yang dihasilkan seperti yang ditunjukkan di bawah ini -
w1 = torch.FloatTensor(input_size,
hidden_size).type(dtype)
init.normal(w1, 0.0, 0.4)
w1 = Variable(w1, requires_grad = True)
w2 = torch.FloatTensor(hidden_size, output_size).type(dtype)
init.normal(w2, 0.0, 0.3)
w2 = Variable(w2, requires_grad = True)
LANGKAH 4
Sekarang, penting untuk membuat fungsi feed forward yang secara unik mendefinisikan jaringan saraf.
def forward(input, context_state, w1, w2):
xh = torch.cat((input, context_state), 1)
context_state = torch.tanh(xh.mm(w1))
out = context_state.mm(w2)
return (out, context_state)
LANGKAH 5
Langkah selanjutnya adalah memulai prosedur pelatihan implementasi gelombang sinus jaringan saraf rekuren. Loop luar melakukan iterasi pada setiap loop dan loop dalam melakukan iterasi melalui elemen urutan. Di sini, kami juga akan menghitung Mean Square Error (MSE) yang membantu dalam prediksi variabel kontinu.
for i in range(epochs):
total_loss = 0
context_state = Variable(torch.zeros((1, hidden_size)).type(dtype), requires_grad = True)
for j in range(x.size(0)):
input = x[j:(j+1)]
target = y[j:(j+1)]
(pred, context_state) = forward(input, context_state, w1, w2)
loss = (pred - target).pow(2).sum()/2
total_loss += loss
loss.backward()
w1.data -= lr * w1.grad.data
w2.data -= lr * w2.grad.data
w1.grad.data.zero_()
w2.grad.data.zero_()
context_state = Variable(context_state.data)
if i % 10 == 0:
print("Epoch: {} loss {}".format(i, total_loss.data[0]))
context_state = Variable(torch.zeros((1, hidden_size)).type(dtype), requires_grad = False)
predictions = []
for i in range(x.size(0)):
input = x[i:i+1]
(pred, context_state) = forward(input, context_state, w1, w2)
context_state = context_state
predictions.append(pred.data.numpy().ravel()[0])
LANGKAH 6
Sekarang, saatnya memplot gelombang sinus sesuai kebutuhan.
pl.scatter(data_time_steps[:-1], x.data.numpy(), s = 90, label = "Actual")
pl.scatter(data_time_steps[1:], predictions, label = "Predicted")
pl.legend()
pl.show()
Keluaran
Output dari proses di atas adalah sebagai berikut -