PyBrain - Tập dữ liệu đào tạo trên mạng

Cho đến nay, chúng ta đã thấy cách tạo một mạng và một tập dữ liệu. Để làm việc với bộ dữ liệu và mạng cùng nhau, chúng ta phải làm điều đó với sự trợ giúp của giảng viên.

Dưới đây là một ví dụ hoạt động để xem cách thêm tập dữ liệu vào mạng được tạo và sau đó được đào tạo và thử nghiệm bằng cách sử dụng trình đào tạo.

testnetwork.py

from pybrain.tools.shortcuts import buildNetwork
from pybrain.structure import TanhLayer
from pybrain.datasets import SupervisedDataSet
from pybrain.supervised.trainers import BackpropTrainer

# Create a network with two inputs, three hidden, and one output
nn = buildNetwork(2, 3, 1, bias=True, hiddenclass=TanhLayer)

# Create a dataset that matches network input and output sizes:
norgate = SupervisedDataSet(2, 1)

# Create a dataset to be used for testing.
nortrain = SupervisedDataSet(2, 1)

# Add input and target values to dataset
# Values for NOR truth table
norgate.addSample((0, 0), (1,))
norgate.addSample((0, 1), (0,))
norgate.addSample((1, 0), (0,))
norgate.addSample((1, 1), (0,))

# Add input and target values to dataset
# Values for NOR truth table
nortrain.addSample((0, 0), (1,))
nortrain.addSample((0, 1), (0,))
nortrain.addSample((1, 0), (0,))
nortrain.addSample((1, 1), (0,))

#Training the network with dataset norgate.
trainer = BackpropTrainer(nn, norgate)

# will run the loop 1000 times to train it.
for epoch in range(1000):
trainer.train()
trainer.testOnData(dataset=nortrain, verbose = True)

Để kiểm tra mạng và tập dữ liệu, chúng ta cần BackpropTrainer. BackpropTrainer là trình đào tạo huấn luyện các tham số của mô-đun theo tập dữ liệu được giám sát (có khả năng tuần tự) bằng cách gắn thẻ ngược các lỗi (theo thời gian).

Chúng tôi đã tạo 2 tập dữ liệu của lớp - SupervisedDataSet. Chúng tôi đang sử dụng mô hình dữ liệu NOR như sau:

A B A NOR B
0 0 1
0 1 0
1 0 0
1 1 0

Mô hình dữ liệu trên được sử dụng để huấn luyện mạng.

norgate = SupervisedDataSet(2, 1)
# Add input and target values to dataset
# Values for NOR truth table
norgate.addSample((0, 0), (1,))
norgate.addSample((0, 1), (0,))
norgate.addSample((1, 0), (0,))
norgate.addSample((1, 1), (0,))

Sau đây là tập dữ liệu được sử dụng để kiểm tra:

# Create a dataset to be used for testing.
nortrain = SupervisedDataSet(2, 1)

# Add input and target values to dataset
# Values for NOR truth table
norgate.addSample((0, 0), (1,))
norgate.addSample((0, 1), (0,))
norgate.addSample((1, 0), (0,))
norgate.addSample((1, 1), (0,))

Trình huấn luyện được sử dụng như sau:

#Training the network with dataset norgate.
trainer = BackpropTrainer(nn, norgate)

# will run the loop 1000 times to train it.
for epoch in range(1000):
   trainer.train()

Để kiểm tra trên tập dữ liệu, chúng ta có thể sử dụng đoạn mã dưới đây:

trainer.testOnData(dataset=nortrain, verbose = True)

Đầu ra

python testnetwork.py

C:\pybrain\pybrain\src>python testnetwork.py
Testing on data:
('out: ', '[0.887 ]')
('correct:', '[1 ]')
error: 0.00637334
('out: ', '[0.149 ]')
('correct:', '[0 ]')
error: 0.01110338
('out: ', '[0.102 ]')
('correct:', '[0 ]')
error: 0.00522736
('out: ', '[-0.163]')
('correct:', '[0 ]')
error: 0.01328650
('All errors:', [0.006373344564625953, 0.01110338071737218, 0.005227359234093431
, 0.01328649974219942])
('Average error:', 0.008997646064572746)
('Max error:', 0.01328649974219942, 'Median error:', 0.01110338071737218)

Nếu bạn kiểm tra đầu ra, dữ liệu kiểm tra gần như khớp với tập dữ liệu chúng tôi đã cung cấp và do đó lỗi là 0,008.

Bây giờ chúng ta hãy thay đổi dữ liệu thử nghiệm và xem một lỗi trung bình. Chúng tôi đã thay đổi đầu ra như hình dưới đây -

Sau đây là tập dữ liệu được sử dụng để kiểm tra:

# Create a dataset to be used for testing.
nortrain = SupervisedDataSet(2, 1)

# Add input and target values to dataset
# Values for NOR truth table
norgate.addSample((0, 0), (0,))
norgate.addSample((0, 1), (1,))
norgate.addSample((1, 0), (1,))
norgate.addSample((1, 1), (0,))

Bây giờ hãy để chúng tôi kiểm tra nó.

Đầu ra

python testnework.py

C:\pybrain\pybrain\src>python testnetwork.py
Testing on data:
('out: ', '[0.988 ]')
('correct:', '[0 ]')
error: 0.48842978
('out: ', '[0.027 ]')
('correct:', '[1 ]')
error: 0.47382097
('out: ', '[0.021 ]')
('correct:', '[1 ]')
error: 0.47876379
('out: ', '[-0.04 ]')
('correct:', '[0 ]')
error: 0.00079160
('All errors:', [0.4884297811030845, 0.47382096780393873, 0.47876378995939756, 0
.0007915982149002194])
('Average error:', 0.3604515342703303)
('Max error:', 0.4884297811030845, 'Median error:', 0.47876378995939756)

Chúng tôi nhận được lỗi là 0,36, điều này cho thấy dữ liệu thử nghiệm của chúng tôi không hoàn toàn khớp với mạng được đào tạo.