CNTK-시퀀스 분류
이 장에서는 CNTK의 시퀀스와 분류에 대해 자세히 알아 봅니다.
텐서
CNTK가 작동하는 개념은 tensor. 기본적으로 CNTK 입력, 출력 및 매개 변수는 다음과 같이 구성됩니다.tensors, 이는 종종 일반화 된 행렬로 간주됩니다. 모든 텐서에는rank −
순위 0의 텐서는 스칼라입니다.
순위 1의 텐서는 벡터입니다.
랭크 2의 텐서는 행렬입니다.
여기에서 이러한 다양한 차원을 axes.
정적 축 및 동적 축
이름에서 알 수 있듯이 정적 축의 길이는 네트워크 수명 내내 동일합니다. 반면에 동적 축의 길이는 인스턴스마다 다를 수 있습니다. 사실, 그들의 길이는 일반적으로 각 미니 배치가 제공되기 전에 알려지지 않았습니다.
동적 축은 텐서에 포함 된 숫자의 의미있는 그룹화도 정의하기 때문에 정적 축과 같습니다.
예
더 명확하게하기 위해 짧은 비디오 클립의 미니 배치가 CNTK에서 어떻게 표현되는지 살펴 보겠습니다. 비디오 클립의 해상도가 모두 640 * 480이라고 가정합니다. 또한 클립은 일반적으로 3 개의 채널로 인코딩되는 컬러로 촬영됩니다. 또한 우리의 미니 배치는 다음과 같은 것을 의미합니다.
길이가 각각 640, 480 및 3 인 3 개의 정적 축.
두 개의 동적 축; 비디오와 미니 배치 축의 길이.
즉, 미니 배치에 각각 240 프레임 길이의 16 개 동영상이있는 경우 다음과 같이 표시됩니다. 16*240*3*640*480 텐서.
CNTK에서 시퀀스 작업
Long-Short Term Memory Network에 대해 먼저 배움으로써 CNTK의 시퀀스를 이해합시다.
장단기 기억 네트워크 (LSTM)
Hochreiter & Schmidhuber는 장단기 기억 (LSTM) 네트워크를 도입했습니다. 오랫동안 사물을 기억하기 위해 기본 반복 레이어를 얻는 문제를 해결했습니다. LSTM의 아키텍처는 위의 다이어그램에 나와 있습니다. 보시다시피 입력 뉴런, 기억 세포 및 출력 뉴런이 있습니다. 소실 기울기 문제를 해결하기 위해 장기 단기 메모리 네트워크는 명시 적 메모리 셀 (이전 값 저장)과 다음 게이트를 사용합니다.
Forget gate− 이름에서 알 수 있듯이 이전 값을 잊어 버리도록 메모리 셀에 지시합니다. 메모리 셀은 게이트 즉 'forget gate'가 값을 잊으라고 지시 할 때까지 값을 저장합니다.
Input gate − 이름에서 알 수 있듯이 셀에 새로운 항목을 추가합니다.
Output gate − 이름에서 알 수 있듯이 출력 게이트는 벡터를 따라 셀에서 다음 은닉 상태로 전달할시기를 결정합니다.
CNTK에서 시퀀스 작업은 매우 쉽습니다. 다음 예제의 도움을 받아 보겠습니다.
import sys
import os
from cntk import Trainer, Axis
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs,\
INFINITELY_REPEAT
from cntk.learners import sgd, learning_parameter_schedule_per_sample
from cntk import input_variable, cross_entropy_with_softmax, \
classification_error, sequence
from cntk.logging import ProgressPrinter
from cntk.layers import Sequential, Embedding, Recurrence, LSTM, Dense
def create_reader(path, is_training, input_dim, label_dim):
return MinibatchSource(CTFDeserializer(path, StreamDefs(
features=StreamDef(field='x', shape=input_dim, is_sparse=True),
labels=StreamDef(field='y', shape=label_dim, is_sparse=False)
)), randomize=is_training,
max_sweeps=INFINITELY_REPEAT if is_training else 1)
def LSTM_sequence_classifier_net(input, num_output_classes, embedding_dim,
LSTM_dim, cell_dim):
lstm_classifier = Sequential([Embedding(embedding_dim),
Recurrence(LSTM(LSTM_dim, cell_dim)),
sequence.last,
Dense(num_output_classes)])
return lstm_classifier(input)
def train_sequence_classifier():
input_dim = 2000
cell_dim = 25
hidden_dim = 25
embedding_dim = 50
num_output_classes = 5
features = sequence.input_variable(shape=input_dim, is_sparse=True)
label = input_variable(num_output_classes)
classifier_output = LSTM_sequence_classifier_net(
features, num_output_classes, embedding_dim, hidden_dim, cell_dim)
ce = cross_entropy_with_softmax(classifier_output, label)
pe = classification_error(classifier_output, label)
rel_path = ("../../../Tests/EndToEndTests/Text/" +
"SequenceClassification/Data/Train.ctf")
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), rel_path)
reader = create_reader(path, True, input_dim, num_output_classes)
input_map = {
features: reader.streams.features,
label: reader.streams.labels
}
lr_per_sample = learning_parameter_schedule_per_sample(0.0005)
progress_printer = ProgressPrinter(0)
trainer = Trainer(classifier_output, (ce, pe),
sgd(classifier_output.parameters, lr=lr_per_sample),progress_printer)
minibatch_size = 200
for i in range(255):
mb = reader.next_minibatch(minibatch_size, input_map=input_map)
trainer.train_minibatch(mb)
evaluation_average = float(trainer.previous_minibatch_evaluation_average)
loss_average = float(trainer.previous_minibatch_loss_average)
return evaluation_average, loss_average
if __name__ == '__main__':
error, _ = train_sequence_classifier()
print(" error: %f" % error)
average since average since examples
loss last metric last
------------------------------------------------------
1.61 1.61 0.886 0.886 44
1.61 1.6 0.714 0.629 133
1.6 1.59 0.56 0.448 316
1.57 1.55 0.479 0.41 682
1.53 1.5 0.464 0.449 1379
1.46 1.4 0.453 0.441 2813
1.37 1.28 0.45 0.447 5679
1.3 1.23 0.448 0.447 11365
error: 0.333333
위 프로그램에 대한 자세한 설명은 특히 반복 신경망을 구성 할 때 다음 섹션에서 다룰 것입니다.