Apache MXNet - KVStore и визуализация

В этой главе рассматриваются пакеты Python KVStore и визуализация.

Пакет KVStore

KV store означает магазин Key-Value. Это важный компонент, используемый для обучения с использованием нескольких устройств. Это важно, потому что передача параметров между устройствами на одной, а также на нескольких машинах передается через один или несколько серверов с KVStore для параметров.

Давайте разберемся в работе KVStore с помощью следующих пунктов:

  • Каждое значение в KVStore представлено key и value.

  • Каждому массиву параметров в сети назначается key и веса этого массива параметров обозначаются как value.

  • После этого рабочие узлы pushградиенты после обработки партии. Они такжеpull обновленные веса перед обработкой новой партии.

Проще говоря, мы можем сказать, что KVStore - это место для обмена данными, где каждое устройство может загружать и извлекать данные.

Загрузка и извлечение данных

KVStore можно рассматривать как единый объект, совместно используемый различными устройствами, такими как графические процессоры и компьютеры, где каждое устройство может загружать и извлекать данные.

Ниже приведены шаги реализации, которые должны выполняться устройствами для передачи и извлечения данных:

Этапы реализации

Initialisation- Первый шаг - инициализировать значения. В нашем примере мы инициализируем пару pair (int, NDArray) в KVStrore и после этого извлекаем значения -

import mxnet as mx
kv = mx.kv.create('local') # create a local KVStore.
shape = (3,3)
kv.init(3, mx.nd.ones(shape)*2)
a = mx.nd.zeros(shape)
kv.pull(3, out = a)
print(a.asnumpy())

Output

Это дает следующий результат -

[[2. 2. 2.]
[2. 2. 2.]
[2. 2. 2.]]

Push, Aggregate, and Update - После инициализации мы можем отправить новое значение в KVStore с той же формой для ключа -

kv.push(3, mx.nd.ones(shape)*8)
kv.pull(3, out = a)
print(a.asnumpy())

Output

Результат приведен ниже -

[[8. 8. 8.]
 [8. 8. 8.]
 [8. 8. 8.]]

Данные, используемые для отправки, могут храниться на любом устройстве, например графическом процессоре или компьютере. Мы также можем поместить несколько значений в один ключ. В этом случае KVStore сначала суммирует все эти значения, а затем отправит агрегированное значение следующим образом:

contexts = [mx.cpu(i) for i in range(4)]
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.push(3, b)
kv.pull(3, out = a)
print(a.asnumpy())

Output

Вы увидите следующий вывод -

[[4. 4. 4.]
 [4. 4. 4.]
 [4. 4. 4.]]

Для каждого примененного вами push-уведомления KVStore объединит отправленное значение с уже сохраненным значением. Это будет сделано с помощью апдейтера. Здесь средство обновления по умолчанию - ASSIGN.

def update(key, input, stored):
   print("update on key: %d" % key)
   
   stored += input * 2
kv.set_updater(update)
kv.pull(3, out=a)
print(a.asnumpy())

Output

Когда вы выполните приведенный выше код, вы должны увидеть следующий результат -

[[4. 4. 4.]
 [4. 4. 4.]
 [4. 4. 4.]]

Example

kv.push(3, mx.nd.ones(shape))
kv.pull(3, out=a)
print(a.asnumpy())

Output

Ниже приведен вывод кода -

update on key: 3
[[6. 6. 6.]
 [6. 6. 6.]
 [6. 6. 6.]]

Pull - Как и в случае с Push, мы также можем передать значение на несколько устройств одним вызовом следующим образом:

b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.pull(3, out = b)
print(b[1].asnumpy())

Output

Результат указан ниже -

[[6. 6. 6.]
 [6. 6. 6.]
 [6. 6. 6.]]

Полный пример реализации

Ниже приведен полный пример реализации -

import mxnet as mx
kv = mx.kv.create('local')
shape = (3,3)
kv.init(3, mx.nd.ones(shape)*2)
a = mx.nd.zeros(shape)
kv.pull(3, out = a)
print(a.asnumpy())
kv.push(3, mx.nd.ones(shape)*8)
kv.pull(3, out = a) # pull out the value
print(a.asnumpy())
contexts = [mx.cpu(i) for i in range(4)]
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.push(3, b)
kv.pull(3, out = a)
print(a.asnumpy())
def update(key, input, stored):
   print("update on key: %d" % key)
   stored += input * 2
kv._set_updater(update)
kv.pull(3, out=a)
print(a.asnumpy())
kv.push(3, mx.nd.ones(shape))
kv.pull(3, out=a)
print(a.asnumpy())
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.pull(3, out = b)
print(b[1].asnumpy())

Обработка пар ключ-значение

Все операции, которые мы реализовали выше, включают один ключ, но KVStore также предоставляет интерфейс для a list of key-value pairs -

Для одного устройства

Ниже приведен пример интерфейса KVStore для списка пар ключ-значение для одного устройства.

keys = [5, 7, 9]
kv.init(keys, [mx.nd.ones(shape)]*len(keys))
kv.push(keys, [mx.nd.ones(shape)]*len(keys))
b = [mx.nd.zeros(shape)]*len(keys)
kv.pull(keys, out = b)
print(b[1].asnumpy())

Output

Вы получите следующий вывод -

update on key: 5
update on key: 7
update on key: 9
[[3. 3. 3.]
 [3. 3. 3.]
 [3. 3. 3.]]

Для нескольких устройств

Ниже приведен пример интерфейса KVStore для списка пар ключ-значение для нескольких устройств.

b = [[mx.nd.ones(shape, ctx) for ctx in contexts]] * len(keys)
kv.push(keys, b)
kv.pull(keys, out = b)
print(b[1][1].asnumpy())

Output

Вы увидите следующий вывод -

update on key: 5
update on key: 7
update on key: 9
[[11. 11. 11.]
 [11. 11. 11.]
 [11. 11. 11.]]

Пакет визуализации

Пакет визуализации - это пакет Apache MXNet, используемый для представления нейронной сети (NN) в виде графа вычислений, состоящего из узлов и ребер.

Визуализируйте нейронную сеть

В приведенном ниже примере мы будем использовать mx.viz.plot_networkдля визуализации нейронной сети. Следующие условия являются предпосылками для этого -

Prerequisites

  • Блокнот Jupyter

  • Библиотека Graphviz

Пример реализации

В приведенном ниже примере мы визуализируем образец NN для линейной матричной факторизации -

import mxnet as mx
user = mx.symbol.Variable('user')
item = mx.symbol.Variable('item')
score = mx.symbol.Variable('score')

# Set the dummy dimensions
k = 64
max_user = 100
max_item = 50

# The user feature lookup
user = mx.symbol.Embedding(data = user, input_dim = max_user, output_dim = k)

# The item feature lookup
item = mx.symbol.Embedding(data = item, input_dim = max_item, output_dim = k)

# predict by the inner product and then do sum
N_net = user * item
N_net = mx.symbol.sum_axis(data = N_net, axis = 1)
N_net = mx.symbol.Flatten(data = N_net)

# Defining the loss layer
N_net = mx.symbol.LinearRegressionOutput(data = N_net, label = score)

# Visualize the network
mx.viz.plot_network(N_net)