Apache MXNet - KVStore và Visualization
Chương này đề cập đến các gói python KVStore và trực quan hóa.
Gói KVStore
Cửa hàng KV là viết tắt của Key-Value store. Nó là thành phần quan trọng được sử dụng để đào tạo đa thiết bị. Điều quan trọng là bởi vì, việc truyền thông số giữa các thiết bị trên một máy cũng như trên nhiều máy được truyền qua một hoặc nhiều máy chủ có KVStore cho các tham số.
Hãy để chúng tôi hiểu hoạt động của KVStore với sự trợ giúp của các điểm sau:
Mỗi giá trị trong KVStore được đại diện bởi một key và một value.
Mỗi mảng tham số trong mạng được gán một key và trọng số của mảng tham số đó được tham chiếu bởi value.
Sau đó, các nút công nhân pushgradient sau khi xử lý một mẻ. Họ cũngpull trọng lượng cập nhật trước khi xử lý một lô mới.
Nói một cách dễ hiểu, chúng ta có thể nói rằng KVStore là nơi chia sẻ dữ liệu, nơi mỗi thiết bị có thể đẩy dữ liệu vào và kéo dữ liệu ra.
Dữ liệu đẩy vào và kéo ra
KVStore có thể được coi là một đối tượng duy nhất được chia sẻ trên các thiết bị khác nhau như GPU & máy tính, nơi mỗi thiết bị có thể đẩy dữ liệu vào và lấy dữ liệu ra.
Sau đây là các bước triển khai mà thiết bị cần phải tuân theo để đẩy dữ liệu vào và kéo dữ liệu ra:
Các bước thực hiện
Initialisation- Bước đầu tiên là khởi tạo các giá trị. Ở đây cho ví dụ của chúng tôi, chúng tôi sẽ khởi tạo một cặp (int, NDArray) vào KVStrore và sau đó kéo các giá trị ra -
import mxnet as mx
kv = mx.kv.create('local') # create a local KVStore.
shape = (3,3)
kv.init(3, mx.nd.ones(shape)*2)
a = mx.nd.zeros(shape)
kv.pull(3, out = a)
print(a.asnumpy())
Output
Điều này tạo ra kết quả sau:
[[2. 2. 2.]
[2. 2. 2.]
[2. 2. 2.]]
Push, Aggregate, and Update - Sau khi khởi tạo, chúng ta có thể đẩy một giá trị mới vào KVStore có cùng hình dạng với khóa -
kv.push(3, mx.nd.ones(shape)*8)
kv.pull(3, out = a)
print(a.asnumpy())
Output
Đầu ra được đưa ra dưới đây -
[[8. 8. 8.]
[8. 8. 8.]
[8. 8. 8.]]
Dữ liệu được sử dụng để đẩy có thể được lưu trữ trên bất kỳ thiết bị nào như GPU hoặc máy tính. Chúng tôi cũng có thể đẩy nhiều giá trị vào cùng một khóa. Trong trường hợp này, KVStore trước tiên sẽ tính tổng tất cả các giá trị này và sau đó đẩy giá trị tổng hợp như sau:
contexts = [mx.cpu(i) for i in range(4)]
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.push(3, b)
kv.pull(3, out = a)
print(a.asnumpy())
Output
Bạn sẽ thấy kết quả sau:
[[4. 4. 4.]
[4. 4. 4.]
[4. 4. 4.]]
Đối với mỗi lần đẩy bạn đã áp dụng, KVStore sẽ kết hợp giá trị được đẩy với giá trị đã được lưu trữ. Nó sẽ được thực hiện với sự trợ giúp của một trình cập nhật. Ở đây, trình cập nhật mặc định là ASSIGN.
def update(key, input, stored):
print("update on key: %d" % key)
stored += input * 2
kv.set_updater(update)
kv.pull(3, out=a)
print(a.asnumpy())
Output
Khi bạn thực thi đoạn mã trên, bạn sẽ thấy kết quả sau:
[[4. 4. 4.]
[4. 4. 4.]
[4. 4. 4.]]
Example
kv.push(3, mx.nd.ones(shape))
kv.pull(3, out=a)
print(a.asnumpy())
Output
Dưới đây là đầu ra của mã:
update on key: 3
[[6. 6. 6.]
[6. 6. 6.]
[6. 6. 6.]]
Pull - Giống như Push, chúng ta cũng có thể kéo giá trị lên một số thiết bị bằng một lệnh gọi như sau:
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.pull(3, out = b)
print(b[1].asnumpy())
Output
Đầu ra được nêu dưới đây -
[[6. 6. 6.]
[6. 6. 6.]
[6. 6. 6.]]
Hoàn thành ví dụ triển khai
Dưới đây là ví dụ triển khai đầy đủ -
import mxnet as mx
kv = mx.kv.create('local')
shape = (3,3)
kv.init(3, mx.nd.ones(shape)*2)
a = mx.nd.zeros(shape)
kv.pull(3, out = a)
print(a.asnumpy())
kv.push(3, mx.nd.ones(shape)*8)
kv.pull(3, out = a) # pull out the value
print(a.asnumpy())
contexts = [mx.cpu(i) for i in range(4)]
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.push(3, b)
kv.pull(3, out = a)
print(a.asnumpy())
def update(key, input, stored):
print("update on key: %d" % key)
stored += input * 2
kv._set_updater(update)
kv.pull(3, out=a)
print(a.asnumpy())
kv.push(3, mx.nd.ones(shape))
kv.pull(3, out=a)
print(a.asnumpy())
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.pull(3, out = b)
print(b[1].asnumpy())
Xử lý các cặp khóa-giá trị
Tất cả các hoạt động chúng tôi đã triển khai ở trên liên quan đến một khóa duy nhất, nhưng KVStore cũng cung cấp một giao diện cho a list of key-value pairs -
Đối với một thiết bị duy nhất
Sau đây là một ví dụ để hiển thị giao diện KVStore cho danh sách các cặp khóa-giá trị cho một thiết bị:
keys = [5, 7, 9]
kv.init(keys, [mx.nd.ones(shape)]*len(keys))
kv.push(keys, [mx.nd.ones(shape)]*len(keys))
b = [mx.nd.zeros(shape)]*len(keys)
kv.pull(keys, out = b)
print(b[1].asnumpy())
Output
Bạn sẽ nhận được kết quả sau:
update on key: 5
update on key: 7
update on key: 9
[[3. 3. 3.]
[3. 3. 3.]
[3. 3. 3.]]
Đối với nhiều thiết bị
Sau đây là một ví dụ để hiển thị giao diện KVStore cho danh sách các cặp khóa-giá trị cho nhiều thiết bị -
b = [[mx.nd.ones(shape, ctx) for ctx in contexts]] * len(keys)
kv.push(keys, b)
kv.pull(keys, out = b)
print(b[1][1].asnumpy())
Output
Bạn sẽ thấy kết quả sau:
update on key: 5
update on key: 7
update on key: 9
[[11. 11. 11.]
[11. 11. 11.]
[11. 11. 11.]]
Gói hình ảnh hóa
Gói trực quan hóa là gói Apache MXNet được sử dụng để biểu diễn mạng nơ-ron (NN) dưới dạng đồ thị tính toán bao gồm các nút và các cạnh.
Trực quan hóa mạng thần kinh
Trong ví dụ dưới đây, chúng tôi sẽ sử dụng mx.viz.plot_networkđể hình dung mạng nơ-ron. Theo dõi là điều kiện tiên quyết cho điều này -
Prerequisites
Sổ ghi chép Jupyter
Thư viện Graphviz
Ví dụ triển khai
Trong ví dụ dưới đây, chúng ta sẽ trực quan hóa một NN mẫu cho ma trận tuyến tính thừa kế -
import mxnet as mx
user = mx.symbol.Variable('user')
item = mx.symbol.Variable('item')
score = mx.symbol.Variable('score')
# Set the dummy dimensions
k = 64
max_user = 100
max_item = 50
# The user feature lookup
user = mx.symbol.Embedding(data = user, input_dim = max_user, output_dim = k)
# The item feature lookup
item = mx.symbol.Embedding(data = item, input_dim = max_item, output_dim = k)
# predict by the inner product and then do sum
N_net = user * item
N_net = mx.symbol.sum_axis(data = N_net, axis = 1)
N_net = mx.symbol.Flatten(data = N_net)
# Defining the loss layer
N_net = mx.symbol.LinearRegressionOutput(data = N_net, label = score)
# Visualize the network
mx.viz.plot_network(N_net)