Apache MXNet - KVStore et visualisation
Ce chapitre traite des packages python KVStore et de la visualisation.
Paquet KVStore
KV stores signifie magasin de valeurs clés. C'est un composant essentiel utilisé pour la formation multi-appareils. Cela est important car la communication des paramètres entre les appareils sur une seule ou plusieurs machines est transmise via un ou plusieurs serveurs avec un KVStore pour les paramètres.
Comprenons le fonctionnement de KVStore à l'aide des points suivants:
Chaque valeur dans KVStore est représentée par un key et un value.
Chaque tableau de paramètres du réseau se voit attribuer un key et les poids de ce tableau de paramètres sont référencés par value.
Après cela, les nœuds de travail pushdégradés après le traitement d'un lot. Ils aussipull poids mis à jour avant de traiter un nouveau lot.
En termes simples, nous pouvons dire que KVStore est un lieu de partage de données où chaque appareil peut insérer des données et extraire des données.
Données Push-In et Pull-Out
KVStore peut être considéré comme un objet unique partagé sur différents appareils tels que des GPU et des ordinateurs, où chaque appareil est capable de pousser des données et d'extraire des données.
Voici les étapes de mise en œuvre qui doivent être suivies par les appareils pour pousser les données et les extraire:
Étapes de mise en œuvre
Initialisation- La première étape consiste à initialiser les valeurs. Ici, pour notre exemple, nous allons initialiser une paire (int, NDArray) paire dans KVStrore et ensuite extraire les valeurs -
import mxnet as mx
kv = mx.kv.create('local') # create a local KVStore.
shape = (3,3)
kv.init(3, mx.nd.ones(shape)*2)
a = mx.nd.zeros(shape)
kv.pull(3, out = a)
print(a.asnumpy())
Output
Cela produit la sortie suivante -
[[2. 2. 2.]
[2. 2. 2.]
[2. 2. 2.]]
Push, Aggregate, and Update - Une fois initialisé, nous pouvons pousser une nouvelle valeur dans KVStore avec la même forme à la clé -
kv.push(3, mx.nd.ones(shape)*8)
kv.pull(3, out = a)
print(a.asnumpy())
Output
La sortie est donnée ci-dessous -
[[8. 8. 8.]
[8. 8. 8.]
[8. 8. 8.]]
Les données utilisées pour pousser peuvent être stockées sur n'importe quel appareil tel que des GPU ou des ordinateurs. Nous pouvons également pousser plusieurs valeurs dans la même clé. Dans ce cas, le KVStore additionnera d'abord toutes ces valeurs, puis transmettra la valeur agrégée comme suit -
contexts = [mx.cpu(i) for i in range(4)]
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.push(3, b)
kv.pull(3, out = a)
print(a.asnumpy())
Output
Vous verrez la sortie suivante -
[[4. 4. 4.]
[4. 4. 4.]
[4. 4. 4.]]
Pour chaque push que vous avez appliqué, KVStore combinera la valeur poussée avec la valeur déjà stockée. Cela se fera à l'aide d'un programme de mise à jour. Ici, le programme de mise à jour par défaut est ASSIGN.
def update(key, input, stored):
print("update on key: %d" % key)
stored += input * 2
kv.set_updater(update)
kv.pull(3, out=a)
print(a.asnumpy())
Output
Lorsque vous exécutez le code ci-dessus, vous devriez voir la sortie suivante -
[[4. 4. 4.]
[4. 4. 4.]
[4. 4. 4.]]
Example
kv.push(3, mx.nd.ones(shape))
kv.pull(3, out=a)
print(a.asnumpy())
Output
Ci-dessous est la sortie du code -
update on key: 3
[[6. 6. 6.]
[6. 6. 6.]
[6. 6. 6.]]
Pull - Comme pour Push, nous pouvons également tirer la valeur sur plusieurs appareils avec un seul appel comme suit -
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.pull(3, out = b)
print(b[1].asnumpy())
Output
La sortie est indiquée ci-dessous -
[[6. 6. 6.]
[6. 6. 6.]
[6. 6. 6.]]
Exemple d'implémentation complet
Vous trouverez ci-dessous l'exemple complet de mise en œuvre -
import mxnet as mx
kv = mx.kv.create('local')
shape = (3,3)
kv.init(3, mx.nd.ones(shape)*2)
a = mx.nd.zeros(shape)
kv.pull(3, out = a)
print(a.asnumpy())
kv.push(3, mx.nd.ones(shape)*8)
kv.pull(3, out = a) # pull out the value
print(a.asnumpy())
contexts = [mx.cpu(i) for i in range(4)]
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.push(3, b)
kv.pull(3, out = a)
print(a.asnumpy())
def update(key, input, stored):
print("update on key: %d" % key)
stored += input * 2
kv._set_updater(update)
kv.pull(3, out=a)
print(a.asnumpy())
kv.push(3, mx.nd.ones(shape))
kv.pull(3, out=a)
print(a.asnumpy())
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.pull(3, out = b)
print(b[1].asnumpy())
Gestion des paires clé-valeur
Toutes les opérations que nous avons implémentées ci-dessus impliquent une seule clé, mais KVStore fournit également une interface pour a list of key-value pairs -
Pour un seul appareil
Voici un exemple pour montrer une interface KVStore pour une liste de paires clé-valeur pour un seul appareil -
keys = [5, 7, 9]
kv.init(keys, [mx.nd.ones(shape)]*len(keys))
kv.push(keys, [mx.nd.ones(shape)]*len(keys))
b = [mx.nd.zeros(shape)]*len(keys)
kv.pull(keys, out = b)
print(b[1].asnumpy())
Output
Vous recevrez la sortie suivante -
update on key: 5
update on key: 7
update on key: 9
[[3. 3. 3.]
[3. 3. 3.]
[3. 3. 3.]]
Pour plusieurs appareils
Voici un exemple pour montrer une interface KVStore pour une liste de paires clé-valeur pour plusieurs appareils -
b = [[mx.nd.ones(shape, ctx) for ctx in contexts]] * len(keys)
kv.push(keys, b)
kv.pull(keys, out = b)
print(b[1][1].asnumpy())
Output
Vous verrez la sortie suivante -
update on key: 5
update on key: 7
update on key: 9
[[11. 11. 11.]
[11. 11. 11.]
[11. 11. 11.]]
Package de visualisation
Le package de visualisation est le package Apache MXNet utilisé pour représenter le réseau neuronal (NN) sous la forme d'un graphe de calcul composé de nœuds et d'arêtes.
Visualisez le réseau neuronal
Dans l'exemple ci-dessous, nous utiliserons mx.viz.plot_networkpour visualiser le réseau neuronal. Les éléments suivants sont les prérequis pour cela -
Prerequisites
Cahier Jupyter
Bibliothèque Graphviz
Exemple d'implémentation
Dans l'exemple ci-dessous, nous allons visualiser un échantillon NN pour la factorisation matricielle linéaire -
import mxnet as mx
user = mx.symbol.Variable('user')
item = mx.symbol.Variable('item')
score = mx.symbol.Variable('score')
# Set the dummy dimensions
k = 64
max_user = 100
max_item = 50
# The user feature lookup
user = mx.symbol.Embedding(data = user, input_dim = max_user, output_dim = k)
# The item feature lookup
item = mx.symbol.Embedding(data = item, input_dim = max_item, output_dim = k)
# predict by the inner product and then do sum
N_net = user * item
N_net = mx.symbol.sum_axis(data = N_net, axis = 1)
N_net = mx.symbol.Flatten(data = N_net)
# Defining the loss layer
N_net = mx.symbol.LinearRegressionOutput(data = N_net, label = score)
# Visualize the network
mx.viz.plot_network(N_net)