def criterion(input, labels):
    """Return the training criterion pair ``(ce, errs)`` for one minibatch.

    The leading sentence-start token is removed from ``labels`` before they
    are handed to the model and to both criterion nodes.

    NOTE(review): ``model`` is not defined in this block; it is presumably
    bound in an enclosing or module scope -- confirm against the caller.

    Args:
        input: source sequence passed through to ``model`` unchanged.
        labels: target sequence that still carries its leading ``<s>``.

    Returns:
        Tuple ``(ce, errs)``: the cross-entropy-with-softmax node and the
        classification-error node, both computed against the trimmed labels.
    """
    # <s> A B C </s> --> A B C </s>
    targets = sequence.slice(labels, 1, 0)
    logits = model(input, targets)
    return (
        cross_entropy_with_softmax(logits, targets),
        classification_error(logits, targets),
    )
def create_network(input_vocab_dim, label_vocab_dim, hidden_dim=256, num_layers=1):
    """Build the sequence-to-sequence translation network.

    An encoder made of ``num_layers`` recurrent layers (wired with
    ``future_value``) reads the source sequence; the first step of its final
    hidden and cell outputs is taken as the "thought vector" and broadcast
    along the label sequence. The decoder's first layer is seeded with that
    thought vector on the first label step and otherwise recurs through
    ``past_value``. During training the decoder consumes the ground-truth
    history; the ``'ng'`` clone consumes the network's own hardmax output.

    Args:
        input_vocab_dim: dimension of the source-side input.
        label_vocab_dim: dimension of the target-side labels and of the
            output layer.
        hidden_dim: size passed as both dimension arguments of every
            ``LSTMP_component_with_self_stabilization`` layer. Defaults to
            256 (low, for faster testing).
        num_layers: number of stacked layers in the encoder and in the
            decoder. Defaults to 1.

    Returns:
        dict with the keys:
            ``'raw_input'``  -- source sequence input variable,
            ``'raw_labels'`` -- label sequence input variable (with ``<s>``),
            ``'ce'``         -- cross-entropy-with-softmax criterion node,
            ``'pe'``         -- classification-error node,
            ``'ng'``         -- clone of the output where the decoder history
                                is the network's own hardmax output,
            ``'output'``     -- the output layer ``z``.

    Raises:
        ValueError: if ``hidden_dim`` or ``num_layers`` is smaller than 1.
    """
    # Validate before touching the graph: with zero layers the encoder loop
    # would never bind its cell output and fail later with a NameError.
    if hidden_dim < 1:
        raise ValueError("hidden_dim must be >= 1, got %r" % (hidden_dim,))
    if num_layers < 1:
        raise ValueError("num_layers must be >= 1, got %r" % (num_layers,))

    # Source and target inputs to the model
    input_seq_axis = Axis('inputAxis')
    label_seq_axis = Axis('labelAxis')

    raw_input = sequence.input(
        shape=(input_vocab_dim), sequence_axis=input_seq_axis, name='raw_input')
    raw_labels = sequence.input(
        shape=(label_vocab_dim), sequence_axis=label_seq_axis, name='raw_labels')

    # Instantiate the sequence to sequence translation model
    input_sequence = raw_input

    # Drop the sentence start token from the label, for decoder training
    label_sequence = sequence.slice(raw_labels, 1, 0)  # <s> A B C </s> --> A B C </s>
    label_sentence_start = sequence.first(raw_labels)  # <s>

    # Place <s> on the first step of the label sequence: <s> 0 0 0 ...
    is_first_label = sequence.is_first(label_sequence)
    label_sentence_start_scattered = sequence.scatter(
        label_sentence_start, is_first_label)

    # Encoder
    encoder_outputH = stabilize(input_sequence)
    for _ in range(num_layers):
        (encoder_outputH, encoder_outputC) = LSTMP_component_with_self_stabilization(
            encoder_outputH.output, hidden_dim, hidden_dim, future_value, future_value)

    # The encoder recurs via future_value, so its first step is the one
    # taken as the summary of the source sequence.
    thought_vectorH = sequence.first(encoder_outputH)
    thought_vectorC = sequence.first(encoder_outputC)

    thought_vector_broadcastH = sequence.broadcast_as(
        thought_vectorH, label_sequence)
    thought_vector_broadcastC = sequence.broadcast_as(
        thought_vectorC, label_sequence)

    # Decoder
    # Alias of label_sequence; this is the node replaced in the 'ng' clone.
    decoder_history_hook = alias(label_sequence, name='decoder_history_hook')

    # First step sees <s>; later steps see the previous history element.
    decoder_input = element_select(
        is_first_label,
        label_sentence_start_scattered,
        past_value(decoder_history_hook))

    decoder_outputH = stabilize(decoder_input)
    for i in range(num_layers):
        if i > 0:
            recurrence_hookH = past_value
            recurrence_hookC = past_value
        else:
            # Only the first decoder layer is seeded with the thought vector.
            isFirst = sequence.is_first(label_sequence)
            recurrence_hookH = lambda operand: element_select(
                isFirst, thought_vector_broadcastH, past_value(operand))
            recurrence_hookC = lambda operand: element_select(
                isFirst, thought_vector_broadcastC, past_value(operand))

        # The decoder cell output is not consumed; keep it out of the
        # encoder's variable so the two stages are not confused.
        (decoder_outputH, _decoder_outputC) = LSTMP_component_with_self_stabilization(
            decoder_outputH.output, hidden_dim, hidden_dim,
            recurrence_hookH, recurrence_hookC)

    decoder_output = decoder_outputH

    # Softmax output layer
    z = linear_layer(stabilize(decoder_output), label_vocab_dim)

    # Criterion nodes
    ce = cross_entropy_with_softmax(z, label_sequence)
    errs = classification_error(z, label_sequence)

    # network output for decoder history
    net_output = hardmax(z)

    # make a clone of the graph where the ground truth is replaced by the
    # network output
    ng = z.clone(
        CloneMethod.share,
        {decoder_history_hook.output: net_output.output})

    return {
        'raw_input': raw_input,
        'raw_labels': raw_labels,
        'ce': ce,
        'pe': errs,
        'ng': ng,
        'output': z,
    }