Example #1
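A turn-level dialog state tracker for DSTC2: a GRU reads each user turn, the dialog state is carried across turns, and three separate softmax heads predict the three components of the turn label. The snippet comes from a larger repository; below is a minimal sketch of the preamble it assumes, written against the TensorFlow 0.x API the code uses (exact import paths varied across 0.x releases):

import logging

import numpy as np
import tensorflow as tf
from tensorflow.python.ops.rnn_cell import GRUCell        # TF 0.x location
from tensorflow.python.ops.control_flow_ops import cond   # TF 0.x location

# Project-local helpers assumed from the surrounding repository:
#   Dstc2, next_batch, stats, x_entropy (a softmax cross-entropy wrapper),
#   get_labels_with_onehot, get_logits_and_probabilities, get_accuracy
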
def main(c):
    ''' params:
            c: config dictionary
    '''

    # Data ---------------------------------------------------------------------------------------------------
    data_portion = None  # 2 * batch_size
    train_set = Dstc2('data/dstc2/data.dstc2.train.json',
                      sample_unk=0.01,
                      first_n=data_portion)
    valid_set = Dstc2('data/dstc2/data.dstc2.dev.json',
                      first_n=data_portion,
                      sample_unk=0,
                      max_dial_len=train_set.max_dial_len,
                      words_vocab=train_set.words_vocab,
                      labels_vocab=train_set.labels_vocab,
                      labels_vocab_separate=train_set.labels_vocab_separate)
    test_set = Dstc2('data/dstc2/data.dstc2.test.json',
                     first_n=data_portion,
                     sample_unk=0,
                     max_dial_len=train_set.max_dial_len,
                     words_vocab=train_set.words_vocab,
                     labels_vocab=train_set.labels_vocab,
                     labels_vocab_separate=train_set.labels_vocab_separate)

    stats(train_set, valid_set, test_set)

    vocab_size = len(train_set.words_vocab)
    output_dim = max(np.unique(train_set.labels)) + 1
    n_train_batches = len(train_set.dialogs) // c.batch_size

    # output dimensions for each separate label
    output_dims = []
    for i in range(3):
        o_d = max(np.unique(train_set.labels_separate[:, :, i])) + 1
        output_dims.append(o_d)

    # Model -----------------------------------------------------------------------------------------------------
    logging.info('Creating model')
    input_bt = tf.placeholder('int32', [c.batch_size, train_set.max_turn_len],
                              name='input')
    turn_lens_b = tf.placeholder('int32', [c.batch_size], name='turn_lens')
    mask_b = tf.placeholder('int32', [c.batch_size], name='dial_mask')
    # labels_b = tf.placeholder('int64', [c.batch_size], name='labels')
    # onehot_labels_bo = tf.one_hot(indices=labels_b,
    #                               depth=output_dim,
    #                               on_value=1.0,
    #                               off_value=0.0,
    #                               axis=-1)

    # separate labels and their onehots
    labels0_b, onehot_labels0_bo0 = get_labels_with_onehot(
        c.batch_size, output_dims[0], 'labels0')
    labels1_b, onehot_labels1_bo1 = get_labels_with_onehot(
        c.batch_size, output_dims[1], 'labels1')
    labels2_b, onehot_labels2_bo2 = get_labels_with_onehot(
        c.batch_size, output_dims[2], 'labels2')

    is_first_turn = tf.placeholder(tf.bool)
    gru = GRUCell(c.hidden_state_dim)

    embeddings_we = tf.get_variable(
        'word_embeddings',
        initializer=tf.random_uniform([vocab_size, c.embedding_dim], -1.0,
                                      1.0),
        trainable=False)
    embedded_input_bte = tf.nn.embedding_lookup(embeddings_we, input_bt)
    dialog_state_before_turn = tf.get_variable(
        'dialog_state_before_turn',
        initializer=tf.zeros([c.batch_size, c.hidden_state_dim],
                             dtype='float32'),
        trainable=False)

    # On a dialog's first turn the GRU starts from its zero state; on later
    # turns it is meant to resume from the state saved after the previous
    # turn (but see the note at the assign below).
    before_state_bh = cond(
        is_first_turn, lambda: gru.zero_state(c.batch_size, dtype='float32'),
        lambda: dialog_state_before_turn)

    inputs = [
        tf.squeeze(i, squeeze_dims=[1])
        for i in tf.split(1, train_set.max_turn_len, embedded_input_bte)
    ]

    outputs, state_bh = tf.nn.rnn(cell=gru,
                                  inputs=inputs,
                                  initial_state=before_state_bh,
                                  sequence_length=turn_lens_b,
                                  dtype=tf.float32)

    # NOTE: Variable.assign() only builds an update op in the graph; nothing
    # ever fetches it in sess.run, so the saved dialog state is never actually
    # updated and the non-first-turn branch of the cond above always sees the
    # variable's initial zeros.
    dialog_state_before_turn.assign(state_bh)

    # projection_ho = tf.get_variable('project2labels',
    #                                 initializer=tf.random_uniform([c.hidden_state_dim, output_dim], -1.0, 1.0))

    # logits_bo = tf.matmul(state_bh, projection_ho)
    # tf.histogram_summary('logits', logits_bo)

    # probabilities_bo = tf.nn.softmax(logits_bo)
    # tf.histogram_summary('probabilities', probabilities_bo)

    # logits, probabilities, and predictions from the hidden state, one set per label head
    logits_bo0, probabilities_bo0, predict_b0 = get_logits_and_probabilities(
        state_bh, c.hidden_state_dim, output_dims[0], 'labels0')
    logits_bo1, probabilities_bo1, predict_b1 = get_logits_and_probabilities(
        state_bh, c.hidden_state_dim, output_dims[1], 'labels1')
    logits_bo2, probabilities_bo2, predict_b2 = get_logits_and_probabilities(
        state_bh, c.hidden_state_dim, output_dims[2], 'labels2')

    float_mask_b = tf.cast(mask_b, 'float32')

    # loss = tf.reduce_sum(tf.mul(float_mask_b, x_entropy(logits_bo, onehot_labels_bo))) / tf.reduce_sum(float_mask_b)
    # tf.scalar_summary('CCE loss', loss)

    # losses: each head's cross-entropy is masked by the dialog mask and
    # averaged over dialogs still active at this turn, then the three are summed
    loss_0 = tf.reduce_sum(
        tf.mul(float_mask_b, x_entropy(
            logits_bo0, onehot_labels0_bo0))) / tf.reduce_sum(float_mask_b)
    loss_1 = tf.reduce_sum(
        tf.mul(float_mask_b, x_entropy(
            logits_bo1, onehot_labels1_bo1))) / tf.reduce_sum(float_mask_b)
    loss_2 = tf.reduce_sum(
        tf.mul(float_mask_b, x_entropy(
            logits_bo2, onehot_labels2_bo2))) / tf.reduce_sum(float_mask_b)
    loss = loss_0 + loss_1 + loss_2
    tf.scalar_summary('CCE loss', loss)

    # predict_b = tf.argmax(logits_bo, 1)
    # correct = tf.cast(tf.equal(predict_b, labels_b), 'float32')
    # accuracy = tf.reduce_sum(tf.mul(correct, float_mask_b)) / tf.reduce_sum(float_mask_b)
    # tf.scalar_summary('Accuracy', accuracy)

    # per-head correctness; correct_all requires all three heads to match
    correct_0 = tf.cast(tf.equal(predict_b0, labels0_b), 'float32')
    correct_1 = tf.cast(tf.equal(predict_b1, labels1_b), 'float32')
    correct_2 = tf.cast(tf.equal(predict_b2, labels2_b), 'float32')
    correct_all = tf.mul(tf.mul(correct_0, correct_1), correct_2)

    # accuracies
    accuracy_0 = get_accuracy(correct_0, float_mask_b)
    accuracy_1 = get_accuracy(correct_1, float_mask_b)
    accuracy_2 = get_accuracy(correct_2, float_mask_b)
    accuracy_all = get_accuracy(correct_all, float_mask_b)
    tf.scalar_summary('Accuracy all', accuracy_all)
    tf.scalar_summary('Accuracy label 0', accuracy_0)
    tf.scalar_summary('Accuracy label 1', accuracy_1)
    tf.scalar_summary('Accuracy label 2', accuracy_2)

    tb_info = tf.merge_all_summaries()

    # Optimizer  -----------------------------------------------------------------------------------------------------
    logging.info('Creating optimizer')
    optimizer = tf.train.AdamOptimizer(c.learning_rate)
    logging.info('Creating train_op')
    train_op = optimizer.minimize(loss)

    # Session  -----------------------------------------------------------------------------------------------------
    logging.info('Creating session')
    sess = tf.Session()
    logging.info('Initializing variables')
    init = tf.initialize_all_variables()
    logging.info('Running session')
    sess.run(init)

    # TB ---------------------------------------------------------------------------------------------------------
    logging.info('See stats via tensorboard: $ tensorboard --logdir %s',
                 c.log_dir)
    train_writer = tf.train.SummaryWriter(c.log_dir, sess.graph)

    # Train ---------------------------------------------------------------------------------------------------------
    train_summary = None
    for e in range(c.epochs):
        logging.info('------------------------------')
        logging.info('Epoch %d', e)

        total_loss = 0
        total_acc = 0
        batch_count = 0
        for bid, (dialogs_bTt, lengths_bT, labels0_bT, labels1_bT, labels2_bT,
                  masks_bT) in enumerate(next_batch(train_set, c.batch_size)):
            turn_loss = 0
            turn_acc = 0
            n_turns = 0
            first_run = True
            for (turn_bt, label0_b, label1_b, label2_b, lengths_b,
                 masks_b) in zip(dialogs_bTt.transpose([1, 0, 2]),
                                 labels0_bT.transpose([1, 0]),
                                 labels1_bT.transpose([1, 0]),
                                 labels2_bT.transpose([1, 0]),
                                 lengths_bT.transpose([1, 0]),
                                 masks_bT.transpose([1, 0])):
                if sum(masks_b) == 0:
                    break  # every dialog in the batch has finished; the rest is padding

                _, batch_loss, batch_accuracy, train_summary = sess.run(
                    [train_op, loss, accuracy_all, tb_info],
                    feed_dict={
                        input_bt: turn_bt,
                        turn_lens_b: lengths_b,
                        mask_b: masks_b,
                        labels0_b: label0_b,
                        labels1_b: label1_b,
                        labels2_b: label2_b,
                        is_first_turn: first_run
                    })
                first_run = False
                turn_loss += batch_loss
                turn_acc += batch_accuracy
                n_turns += 1

            total_loss += turn_loss / n_turns
            total_acc += turn_acc / n_turns
            batch_count += 1

            logging.info('Batch %d/%d\r', bid, n_train_batches)

        train_writer.add_summary(train_summary, e)
        logging.info('Train cost %f', total_loss / batch_count)
        logging.info('Train accuracy: %f', total_acc / batch_count)

        def monitor_stream(work_set, name):
            total_loss = 0
            total_acc = 0
            n_valid_batches = 0
            for bid, (dialogs_bTt, lengths_bT, labels0_bT, labels1_bT,
                      labels2_bT, masks_bT) in enumerate(
                          next_batch(work_set, c.batch_size)):
                turn_loss = 0
                turn_acc = 0
                n_turns = 0
                first_run = True
                for (turn_bt, label0_b, label1_b, label2_b, lengths_b,
                     masks_b) in zip(dialogs_bTt.transpose([1, 0, 2]),
                                     labels0_bT.transpose([1, 0]),
                                     labels1_bT.transpose([1, 0]),
                                     labels2_bT.transpose([1, 0]),
                                     lengths_bT.transpose([1, 0]),
                                     masks_bT.transpose([1, 0])):
                    if sum(masks_b) == 0:
                        break

                    # Dev/test turns can be shorter than the training maximum,
                    # so right-pad them with zeros up to train_set.max_turn_len.
                    input = np.pad(turn_bt, ((0, 0), (0, train_set.max_turn_len-turn_bt.shape[1])),
                                   'constant', constant_values=0) if train_set.max_turn_len > turn_bt.shape[1]\
                        else turn_bt

                    batch_loss, batch_acc, valid_summary = sess.run(
                        [loss, accuracy_all, tb_info],
                        feed_dict={
                            input_bt: input,
                            turn_lens_b: lengths_b,
                            labels0_b: label0_b,
                            labels1_b: label1_b,
                            labels2_b: label2_b,
                            mask_b: masks_b,
                            is_first_turn: first_run
                        })
                    turn_loss += batch_loss
                    turn_acc += batch_acc
                    first_run = False
                    n_turns += 1

                total_loss += turn_loss / n_turns
                total_acc += turn_acc / n_turns
                n_valid_batches += 1

            logging.info('%s cost: %f', name, total_loss / n_valid_batches)
            logging.info('%s accuracy: %f', name, total_acc / n_valid_batches)

        monitor_stream(valid_set, 'Valid')
        monitor_stream(test_set, 'Test')
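
Every tracker example on this page uses the same masked mean: per-dialog cross-entropy is multiplied by a 0/1 dialog mask and divided by the number of dialogs still active at the turn, so padded turns of finished dialogs contribute nothing. A minimal NumPy sketch of that computation (the helper name is mine, not from the source):

import numpy as np

def masked_mean(per_dialog_loss, mask_b):
    """Mean loss over dialogs still active at this turn (mask 1 = live)."""
    mask = mask_b.astype('float32')
    return float(np.sum(per_dialog_loss * mask) / np.sum(mask))

# Three dialogs, the third already finished: its loss is ignored.
print(masked_mean(np.array([0.5, 1.0, 9.9]), np.array([1, 1, 0])))  # 0.75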
Example #2
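A fragment: the dangling "return xs, ys" below is the tail of a batching helper whose body was cut off. What remains builds the static graph of a GRU language model: int32 input/target placeholders of shape [batch, SEQ_LEN - 1], a learning-rate placeholder, and embedding plus softmax parameters under the "RNN" variable scope. The upper-case constants were defined in the truncated part of the script; an illustrative stand-in (values are placeholders, not from the source):

import tensorflow as tf
from tensorflow.python.ops.rnn_cell import GRUCell  # TF 0.x location

# Assumed hyper-parameters; the values below are illustrative only.
HIDDEN_SIZE = 128
VOCABULARY_SIZE = 10000
SEQ_LEN = 50
BATCH_SIZE = 32
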
    return xs, ys


rnn_cell = GRUCell(HIDDEN_SIZE)

input_placeholder = tf.placeholder(dtype=tf.int32,
                                   shape=[None, SEQ_LEN - 1],
                                   name="input")

target_placeholder = tf.placeholder(dtype=tf.int32,
                                    shape=[None, SEQ_LEN - 1],
                                    name="target")
learning_rate_placeholder = tf.placeholder(dtype=tf.float32, shape=None)

outputs = []
state = rnn_cell.zero_state(BATCH_SIZE, dtype=tf.float32)
states = []

with tf.variable_scope("RNN"):
    with tf.variable_scope("embedding"):
        embedding_matrix = tf.get_variable(
            "weights",
            shape=[VOCABULARY_SIZE, HIDDEN_SIZE],
            initializer=tf.random_normal_initializer())
    with tf.variable_scope("softmax"):
        softmax_w = tf.get_variable("weight",
                                    shape=[HIDDEN_SIZE, VOCABULARY_SIZE],
                                    initializer=tf.random_normal_initializer())
        softmax_b = tf.get_variable("bias",
                                    shape=[VOCABULARY_SIZE],
                                    initializer=tf.constant_initializer(0.1))
Example #3
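A self-contained variant of the tracker from Example #1: the configuration is inlined in main(), a single joint label head replaces the three separate ones, and two-layer MLP weights are created but the MLP logits are left commented out in favor of a plain linear projection.
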
def main():
    # Config -----------------------------------------------------------------------------------------------------
    learning_rate = 0.005
    batch_size = 16
    epochs = 50
    hidden_state_dim = 200
    embedding_dim = 300
    log_dir = 'log'

    # Data ---------------------------------------------------------------------------------------------------
    data_portion = 2 * batch_size
    train_set = Dstc2('../data/dstc2/data.dstc2.train.json', sample_unk=0.01, first_n=data_portion)
    valid_set = Dstc2('../data/dstc2/data.dstc2.dev.json', first_n=data_portion, sample_unk=0, max_dial_len=train_set.max_dial_len, words_vocab=train_set.words_vocab, labels_vocab=train_set.labels_vocab)
    test_set = Dstc2('../data/dstc2/data.dstc2.test.json', first_n=data_portion, sample_unk=0, max_dial_len=train_set.max_dial_len, words_vocab=train_set.words_vocab, labels_vocab=train_set.labels_vocab)

    vocab_size = len(train_set.words_vocab)
    output_dim = max(np.unique(train_set.labels)) + 1
    n_train_batches = len(train_set.dialogs) // batch_size

    # Model -----------------------------------------------------------------------------------------------------
    logging.info('Creating model')
    input_bt = tf.placeholder('int32', [batch_size, train_set.max_turn_len], name='input')
    turn_lens_b = tf.placeholder('int32', [batch_size], name='turn_lens')
    mask_b = tf.placeholder('int32', [batch_size], name='dial_mask')
    # mask_bT = lengths2mask2d(dial_lens_b, train_set.max_dial_len)
    labels_b = tf.placeholder('int64', [batch_size], name='labels')
    onehot_labels_bo = tf.one_hot(indices=labels_b,
                                  depth=output_dim,
                                  on_value=1.0,
                                  off_value=0.0,
                                  axis=-1)
    is_first_turn = tf.placeholder(tf.bool)
    gru = GRUCell(hidden_state_dim)
    mlp_hidden_layer_dim = 50
    mlp_input2hidden_W = tf.get_variable('in2hid', initializer=tf.random_normal([hidden_state_dim, mlp_hidden_layer_dim]))
    mlp_input2hidden_B = tf.Variable(tf.random_normal([mlp_hidden_layer_dim]))
    mlp_hidden2output_W = tf.get_variable('hid2out', initializer=tf.random_normal([mlp_hidden_layer_dim, output_dim]))
    mlp_hidden2output_B = tf.Variable(tf.random_normal([output_dim]))

    embeddings_we = tf.get_variable('word_embeddings', initializer=tf.random_uniform([vocab_size, embedding_dim], -1.0, 1.0))
    embedded_input_bte = tf.nn.embedding_lookup(embeddings_we, input_bt)
    dialog_state_before_turn = tf.get_variable('dialog_state_before_turn', initializer=tf.zeros([batch_size, hidden_state_dim], dtype='float32'), trainable=False)

    before_state_bh = cond(is_first_turn,
        lambda: gru.zero_state(batch_size, dtype='float32'),
        lambda: dialog_state_before_turn)

    inputs = [tf.squeeze(i, squeeze_dims=[1]) for i in tf.split(1, train_set.max_turn_len, embedded_input_bte)]

    outputs, state_bh = tf.nn.rnn(cell=gru,
            inputs=inputs,
            initial_state=before_state_bh,
            sequence_length=turn_lens_b,
            dtype=tf.float32)

    # state_tbh = scan(fn=lambda last_state_bh, curr_input_bte: gru(curr_input_bte, last_state_bh)[1],
    #                 elems=tf.transpose(embedded_input_bte, perm=[1, 0, 2]),
    #                 initializer=before_state_bh)

    # state_bh = state_tbh[state_tbh.get_shape()[0]-1, :, :]
    # NOTE: this assign op is never fetched in sess.run, so the stored state is
    # never actually updated between turns.
    dialog_state_before_turn.assign(state_bh)

    projection_ho = tf.get_variable('project2labels',
                                    initializer=tf.random_uniform([hidden_state_dim, output_dim], -1.0, 1.0))


    logits_bo = tf.matmul(state_bh, projection_ho)
    # hidden = tf.add(tf.matmul(state_bh, mlp_input2hidden_W), mlp_input2hidden_B)
    # logits_bo = tf.add(tf.matmul(hidden, mlp_hidden2output_W), mlp_hidden2output_B)
    tf.histogram_summary('logits', logits_bo)

    probabilities_bo = tf.nn.softmax(logits_bo)
    tf.histogram_summary('probabilities', probabilities_bo)

    float_mask_b = tf.cast(mask_b, 'float32')
    # loss = tf.matmul(tf.expand_dims(tf.cast(mask_b, 'float32'), 0), tf.nn.softmax_cross_entropy_with_logits(logits_bo, onehot_labels_bo)) / tf.reduce_sum(mask_b)
    loss = tf.reduce_sum(tf.mul(float_mask_b, tf.nn.softmax_cross_entropy_with_logits(logits_bo, onehot_labels_bo))) / tf.reduce_sum(float_mask_b)


    tf.scalar_summary('CCE loss', loss)

    predict_b = tf.argmax(logits_bo, 1)
    correct = tf.cast(tf.equal(predict_b, labels_b), 'float32')
    accuracy = tf.reduce_sum(tf.mul(correct, float_mask_b)) / tf.reduce_sum(float_mask_b)

    tf.scalar_summary('Accuracy', accuracy)
    tb_info = tf.merge_all_summaries()

    # Optimizer  -----------------------------------------------------------------------------------------------------
    logging.info('Creating optimizer')
    optimizer = tf.train.AdamOptimizer(learning_rate)
    logging.info('Creating train_op')
    train_op = optimizer.minimize(loss)
    # Session  -----------------------------------------------------------------------------------------------------
    logging.info('Creating session')
    sess = tf.Session()
    logging.info('Initializing variables')
    init = tf.initialize_all_variables()
    logging.info('Running session')
    sess.run(init)

    # TB ---------------------------------------------------------------------------------------------------------
    logging.info('See stats via tensorboard: $ tensorboard --logdir %s', log_dir)
    train_writer = tf.train.SummaryWriter(log_dir, sess.graph)

    # Train ---------------------------------------------------------------------------------------------------------
    train_summary = None
    for e in range(epochs):
        logging.info('------------------------------')
        logging.info('Epoch %d', e)

        total_loss = 0
        total_acc = 0
        batch_count = 0
        for bid, (dialogs_bTt, lengths_bT, labels_bT, masks_bT) in enumerate(next_batch(train_set, batch_size)):
            turn_loss = 0
            turn_acc = 0
            n_turns = 0
            first_run = True
            for (turn_bt, label_b, lengths_b, masks_b) in zip(dialogs_bTt.transpose([1,0,2]), labels_bT.transpose([1,0]), lengths_bT.transpose([1,0]), masks_bT.transpose([1,0])):
                if sum(masks_b) == 0:
                    break
                _, batch_loss, batch_accuracy, train_summary = sess.run([train_op, loss, accuracy, tb_info], feed_dict={input_bt: turn_bt,
                                                                                              turn_lens_b: lengths_b,
                                                                                              mask_b: masks_b,
                                                                                              labels_b: label_b,
                                                                                              is_first_turn:first_run})
                first_run = False
                turn_loss += batch_loss
                turn_acc += batch_accuracy
                n_turns += 1
            total_loss += turn_loss / n_turns
            total_acc += turn_acc / n_turns
            batch_count += 1
            logging.info('Batch %d/%d\r', bid, n_train_batches)

        train_writer.add_summary(train_summary, e)
        logging.info('Average train cost %f', total_loss / batch_count)
        logging.info('Average train accuracy: %f', total_acc / batch_count)

        def monitor_stream(work_set, name):
            total_loss = 0
            total_acc = 0
            n_valid_batches = 0
            for bid, (dialogs_bTt, lengths_bT, labels_bT, masks_bT) in enumerate(next_batch(work_set, batch_size)):
                turn_loss = 0
                turn_acc = 0
                n_turns = 0
                first_run = True
                for (turn_bt, label_b, lengths_b, masks_b) in zip(dialogs_bTt.transpose([1,0,2]), labels_bT.transpose([1,0]), lengths_bT.transpose([1,0]), masks_bT.transpose([1,0])):
                    if sum(masks_b) == 0:
                        break
                    input = np.pad(turn_bt, ((0,0), (0, train_set.max_turn_len-turn_bt.shape[1])), 'constant', constant_values=0) if train_set.max_turn_len > turn_bt.shape[1] else turn_bt
                    predictions, batch_loss, batch_acc, valid_summary = sess.run([predict_b, loss, accuracy, tb_info], feed_dict={input_bt: input,
                                                                    turn_lens_b: lengths_b,
                                                                    labels_b: label_b,
                                                                    mask_b: masks_b,
                                                                    is_first_turn:first_run})
                    turn_loss += batch_loss
                    turn_acc += batch_acc
                    first_run = False
                    n_turns += 1
                total_loss += turn_loss / n_turns
                total_acc += turn_acc / n_turns
                n_valid_batches += 1

            logging.info('%s cost: %f', name, total_loss/n_valid_batches)
            logging.info('%s accuracy: %f', name, total_acc/n_valid_batches)

        monitor_stream(valid_set, 'Valid')
        monitor_stream(test_set, 'Test')
Example #4
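The single-head tracker again, now with early stopping: training runs inside try/finally, the best checkpoints are kept by dev-set reward, a RuntimeError aborts training once the dev set stops improving, and the finally block saves the final state either way.
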
def main(c):
    ''' params:
            c: config dictionary
    '''

    # Data ---------------------------------------------------------------------------------------------------
    data_portion = None  # 2 * batch_size
    train_set = Dstc2('data/dstc2/data.dstc2.train.json',
                      sample_unk=0.01,
                      first_n=data_portion)
    valid_set = Dstc2('data/dstc2/data.dstc2.dev.json',
                      first_n=data_portion,
                      sample_unk=0,
                      max_dial_len=train_set.max_dial_len,
                      words_vocab=train_set.words_vocab,
                      labels_vocab=train_set.labels_vocab)
    test_set = Dstc2('data/dstc2/data.dstc2.test.json',
                     first_n=data_portion,
                     sample_unk=0,
                     max_dial_len=train_set.max_dial_len,
                     words_vocab=train_set.words_vocab,
                     labels_vocab=train_set.labels_vocab)

    stats(train_set, valid_set, test_set)

    vocab_size = len(train_set.words_vocab)
    output_dim = max(np.unique(train_set.labels)) + 1
    n_train_batches = len(train_set.dialogs) // c.batch_size

    # Model -----------------------------------------------------------------------------------------------------
    logging.info('Creating model')
    input_bt = tf.placeholder('int32', [c.batch_size, train_set.max_turn_len],
                              name='input')
    turn_lens_b = tf.placeholder('int32', [c.batch_size], name='turn_lens')
    mask_b = tf.placeholder('int32', [c.batch_size], name='dial_mask')
    labels_b = tf.placeholder('int64', [c.batch_size], name='labels')
    onehot_labels_bo = tf.one_hot(indices=labels_b,
                                  depth=output_dim,
                                  on_value=1.0,
                                  off_value=0.0,
                                  axis=-1)
    is_first_turn = tf.placeholder(tf.bool)
    gru = GRUCell(c.hidden_state_dim)

    embeddings_we = tf.get_variable(
        'word_embeddings',
        initializer=tf.random_uniform([vocab_size, c.embedding_dim], -1.0,
                                      1.0))
    embedded_input_bte = tf.nn.embedding_lookup(embeddings_we, input_bt)
    dialog_state_before_turn = tf.get_variable(
        'dialog_state_before_turn',
        initializer=tf.zeros([c.batch_size, c.hidden_state_dim],
                             dtype='float32'),
        trainable=False)

    before_state_bh = cond(
        is_first_turn, lambda: gru.zero_state(c.batch_size, dtype='float32'),
        lambda: dialog_state_before_turn)

    inputs = [
        tf.squeeze(i, squeeze_dims=[1])
        for i in tf.split(1, train_set.max_turn_len, embedded_input_bte)
    ]

    outputs, state_bh = tf.nn.rnn(cell=gru,
                                  inputs=inputs,
                                  initial_state=before_state_bh,
                                  sequence_length=turn_lens_b,
                                  dtype=tf.float32)

    # NOTE: this assign op is never fetched in sess.run, so the stored state is
    # never actually updated between turns.
    dialog_state_before_turn.assign(state_bh)
    projection_ho = tf.get_variable(
        'project2labels',
        initializer=tf.random_uniform([c.hidden_state_dim, output_dim], -1.0,
                                      1.0))

    logits_bo = tf.matmul(state_bh, projection_ho)
    tf.histogram_summary('logits', logits_bo)

    probabilities_bo = tf.nn.softmax(logits_bo)
    tf.histogram_summary('probabilities', probabilities_bo)

    float_mask_b = tf.cast(mask_b, 'float32')
    loss = tf.reduce_sum(
        tf.mul(float_mask_b, x_entropy(
            logits_bo, onehot_labels_bo))) / tf.reduce_sum(float_mask_b)
    tf.scalar_summary('CCE loss', loss)

    predict_b = tf.argmax(logits_bo, 1)
    correct = tf.cast(tf.equal(predict_b, labels_b), 'float32')
    accuracy = tf.reduce_sum(tf.mul(
        correct, float_mask_b)) / tf.reduce_sum(float_mask_b)
    tf.scalar_summary('Accuracy', accuracy)

    tb_info = tf.merge_all_summaries()

    # Optimizer  -----------------------------------------------------------------------------------------------------
    logging.info('Creating optimizer')
    optimizer = tf.train.AdamOptimizer(c.learning_rate)
    logging.info('Creating train_op')
    train_op = optimizer.minimize(loss)

    # Session  -----------------------------------------------------------------------------------------------------
    logging.info('Creating session')
    sess = tf.Session()
    logging.info('Initializing variables')
    init = tf.initialize_all_variables()
    logging.info('Running session')
    sess.run(init)

    # TB ---------------------------------------------------------------------------------------------------------
    logging.info('See stats via tensorboard: $ tensorboard --logdir %s',
                 c.log_dir)
    train_writer = tf.train.SummaryWriter(c.log_dir, sess.graph)

    # Train ---------------------------------------------------------------------------------------------------------
    train_summary = None
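    # EarlyStopper is project-local; judging by its use here it keeps the
    # c.nbest_models best checkpoints and reports failure once the dev-set
    # reward has not improved for c.not_change_limit consecutive checks.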
    step, stopper = 0, EarlyStopper(c.nbest_models, c.not_change_limit, c.name)
    try:
        for e in range(c.epochs):
            logging.info('------------------------------')
            logging.info('Epoch %d', e)

            total_loss = 0
            total_acc = 0
            batch_count = 0
            for bid, (dialogs_bTt, lengths_bT, labels_bT,
                      masks_bT) in enumerate(
                          next_batch(train_set, c.batch_size)):
                turn_loss = 0
                turn_acc = 0
                n_turns = 0
                first_run = True
                for (turn_bt, label_b, lengths_b,
                     masks_b) in zip(dialogs_bTt.transpose([1, 0, 2]),
                                     labels_bT.transpose([1, 0]),
                                     lengths_bT.transpose([1, 0]),
                                     masks_bT.transpose([1, 0])):
                    if sum(masks_b) == 0:
                        break

                    _, batch_loss, batch_accuracy, train_summary = sess.run(
                        [train_op, loss, accuracy, tb_info],
                        feed_dict={
                            input_bt: turn_bt,
                            turn_lens_b: lengths_b,
                            mask_b: masks_b,
                            labels_b: label_b,
                            is_first_turn: first_run
                        })
                    first_run = False
                    turn_loss += batch_loss
                    turn_acc += batch_accuracy
                    n_turns += 1
                    step += 1

                total_loss += turn_loss / n_turns
                total_acc += turn_acc / n_turns
                batch_count += 1

                logging.info('Batch %d/%d\r', bid, n_train_batches)

            train_writer.add_summary(train_summary, e)
            logging.info('Train cost %f', total_loss / batch_count)
            logging.info('Train accuracy: %f', total_acc / batch_count)

            def monitor_stream(work_set, name):
                total_loss = 0
                total_acc = 0
                n_valid_batches = 0
                for bid, (dialogs_bTt, lengths_bT, labels_bT,
                          masks_bT) in enumerate(
                              next_batch(work_set, c.batch_size)):
                    turn_loss = 0
                    turn_acc = 0
                    n_turns = 0
                    first_run = True
                    for (turn_bt, label_b, lengths_b,
                         masks_b) in zip(dialogs_bTt.transpose([1, 0, 2]),
                                         labels_bT.transpose([1, 0]),
                                         lengths_bT.transpose([1, 0]),
                                         masks_bT.transpose([1, 0])):
                        if sum(masks_b) == 0:
                            break

                        input = np.pad(turn_bt, ((0, 0), (0, train_set.max_turn_len-turn_bt.shape[1])),
                                       'constant', constant_values=0) if train_set.max_turn_len > turn_bt.shape[1]\
                            else turn_bt

                        predictions, batch_loss, batch_acc, valid_summary = sess.run(
                            [predict_b, loss, accuracy, tb_info],
                            feed_dict={
                                input_bt: input,
                                turn_lens_b: lengths_b,
                                labels_b: label_b,
                                mask_b: masks_b,
                                is_first_turn: first_run
                            })
                        turn_loss += batch_loss
                        turn_acc += batch_acc
                        first_run = False
                        n_turns += 1

                    total_loss += turn_loss / n_turns
                    total_acc += turn_acc / n_turns
                    n_valid_batches += 1

                logging.info('%s cost: %f', name, total_loss / n_valid_batches)
                logging.info('%s accuracy: %f', name,
                             total_acc / n_valid_batches)
                return total_loss / n_valid_batches

            stopper_reward = monitor_stream(valid_set, 'Valid')
            monitor_stream(test_set, 'Test')
            if not stopper.save_and_check(stopper_reward, step, sess):
                raise RuntimeError('Training not improving on dev set')
    finally:
        logging.info(
            'Training stopped after %7d steps and %7.2f epochs. See logs for %s',
            step, step / len(train_set), c.log_name)
        logging.info(
            'Saving current state. Please wait!\nBest model has reward %7.2f from step %7d is %s'
            % stopper.highest_reward())
        stopper.saver.save(sess=sess,
                           save_path='%s-FINAL-%.4f-step-%07d' %
                           (stopper.saver_prefix, stopper_reward, step))
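
Example #5
The separate-label tracker of Example #1 combined with the early-stopping loop of Example #4, shown in its original, unformatted layout.
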
def main(c):
    ''' params:
            c: config dictionary
    '''

    # Data ---------------------------------------------------------------------------------------------------
    data_portion = None  # 2 * batch_size
    train_set = Dstc2('data/dstc2/data.dstc2.train.json', sample_unk=0.01, first_n=data_portion)
    valid_set = Dstc2('data/dstc2/data.dstc2.dev.json', first_n=data_portion, sample_unk=0,
                      max_dial_len=train_set.max_dial_len, words_vocab=train_set.words_vocab,
                      labels_vocab=train_set.labels_vocab, labels_vocab_separate=train_set.labels_vocab_separate)
    test_set = Dstc2('data/dstc2/data.dstc2.test.json', first_n=data_portion, sample_unk=0,
                     max_dial_len=train_set.max_dial_len, words_vocab=train_set.words_vocab,
                     labels_vocab=train_set.labels_vocab, labels_vocab_separate=train_set.labels_vocab_separate)

    stats(train_set, valid_set, test_set)

    vocab_size = len(train_set.words_vocab)
    output_dim = max(np.unique(train_set.labels)) + 1
    n_train_batches = len(train_set.dialogs) // c.batch_size

    # output dimensions for each separate label
    output_dims = []
    for i in range(3):
        o_d = max(np.unique(train_set.labels_separate[:,:,i])) + 1
        output_dims.append(o_d)


    # Model -----------------------------------------------------------------------------------------------------
    logging.info('Creating model')
    input_bt = tf.placeholder('int32', [c.batch_size, train_set.max_turn_len], name='input')
    turn_lens_b = tf.placeholder('int32', [c.batch_size], name='turn_lens')
    mask_b = tf.placeholder('int32', [c.batch_size], name='dial_mask')
    # labels_b = tf.placeholder('int64', [c.batch_size], name='labels')
    # onehot_labels_bo = tf.one_hot(indices=labels_b,
    #                               depth=output_dim,
    #                               on_value=1.0,
    #                               off_value=0.0,
    #                               axis=-1)

    # separate labels and their onehots
    labels0_b, onehot_labels0_bo0 = get_labels_with_onehot(c.batch_size, output_dims[0], 'labels0')
    labels1_b, onehot_labels1_bo1 = get_labels_with_onehot(c.batch_size, output_dims[1], 'labels1')
    labels2_b, onehot_labels2_bo2 = get_labels_with_onehot(c.batch_size, output_dims[2], 'labels2')

    is_first_turn = tf.placeholder(tf.bool)
    gru = GRUCell(c.hidden_state_dim)

    embeddings_we = tf.get_variable('word_embeddings',
                                    initializer=tf.random_uniform([vocab_size, c.embedding_dim], -1.0, 1.0))
    embedded_input_bte = tf.nn.embedding_lookup(embeddings_we, input_bt)
    dialog_state_before_turn = tf.get_variable('dialog_state_before_turn',
                                               initializer=tf.zeros([c.batch_size, c.hidden_state_dim], dtype='float32'),
                                               trainable=False)

    before_state_bh = cond(is_first_turn,
                           lambda: gru.zero_state(c.batch_size, dtype='float32'),
                           lambda: dialog_state_before_turn)

    inputs = [tf.squeeze(i, squeeze_dims=[1]) for i in tf.split(1, train_set.max_turn_len, embedded_input_bte)]

    outputs, state_bh = tf.nn.rnn(cell=gru,
                                  inputs=inputs,
                                  initial_state=before_state_bh,
                                  sequence_length=turn_lens_b,
                                  dtype=tf.float32)

    # NOTE: this assign op is never fetched in sess.run, so the stored state is
    # never actually updated between turns.
    dialog_state_before_turn.assign(state_bh)


    # projection_ho = tf.get_variable('project2labels',
    #                                 initializer=tf.random_uniform([c.hidden_state_dim, output_dim], -1.0, 1.0))



    # logits_bo = tf.matmul(state_bh, projection_ho)
    # tf.histogram_summary('logits', logits_bo)

    # probabilities_bo = tf.nn.softmax(logits_bo)
    # tf.histogram_summary('probabilities', probabilities_bo)


    # logits, probabilities, and predictions from the hidden state, one set per label head
    logits_bo0, probabilities_bo0, predict_b0 = get_logits_and_probabilities(state_bh, c.hidden_state_dim, output_dims[0], 'labels0')
    logits_bo1, probabilities_bo1, predict_b1 = get_logits_and_probabilities(state_bh, c.hidden_state_dim, output_dims[1], 'labels1')
    logits_bo2, probabilities_bo2, predict_b2 = get_logits_and_probabilities(state_bh, c.hidden_state_dim, output_dims[2], 'labels2')



    float_mask_b = tf.cast(mask_b, 'float32')
    
    # loss = tf.reduce_sum(tf.mul(float_mask_b, x_entropy(logits_bo, onehot_labels_bo))) / tf.reduce_sum(float_mask_b)
    # tf.scalar_summary('CCE loss', loss)

    # losses
    loss_0 = tf.reduce_sum(tf.mul(float_mask_b, x_entropy(logits_bo0, onehot_labels0_bo0))) / tf.reduce_sum(float_mask_b)
    loss_1 = tf.reduce_sum(tf.mul(float_mask_b, x_entropy(logits_bo1, onehot_labels1_bo1))) / tf.reduce_sum(float_mask_b)
    loss_2 = tf.reduce_sum(tf.mul(float_mask_b, x_entropy(logits_bo2, onehot_labels2_bo2))) / tf.reduce_sum(float_mask_b)
    loss = loss_0 + loss_1 + loss_2
    tf.scalar_summary('CCE loss', loss)


    # predict_b = tf.argmax(logits_bo, 1)
    # correct = tf.cast(tf.equal(predict_b, labels_b), 'float32')
    # accuracy = tf.reduce_sum(tf.mul(correct, float_mask_b)) / tf.reduce_sum(float_mask_b)
    # tf.scalar_summary('Accuracy', accuracy)


    # per-head correctness; correct_all requires all three heads to match
    correct_0 = tf.cast(tf.equal(predict_b0, labels0_b), 'float32')
    correct_1 = tf.cast(tf.equal(predict_b1, labels1_b), 'float32')
    correct_2 = tf.cast(tf.equal(predict_b2, labels2_b), 'float32')
    correct_all = tf.mul(tf.mul(correct_0, correct_1), correct_2)

    # accuracies
    accuracy_0 = get_accuracy(correct_0, float_mask_b)
    accuracy_1 = get_accuracy(correct_1, float_mask_b)
    accuracy_2 = get_accuracy(correct_2, float_mask_b)
    accuracy_all = get_accuracy(correct_all, float_mask_b)
    tf.scalar_summary('Accuracy all', accuracy_all)    
    tf.scalar_summary('Accuracy label 0', accuracy_0)
    tf.scalar_summary('Accuracy label 1', accuracy_1)
    tf.scalar_summary('Accuracy label 2', accuracy_2)



    tb_info = tf.merge_all_summaries()

    # Optimizer  -----------------------------------------------------------------------------------------------------
    logging.info('Creating optimizer')
    optimizer = tf.train.AdamOptimizer(c.learning_rate)
    logging.info('Creating train_op')
    train_op = optimizer.minimize(loss)

    # Session  -----------------------------------------------------------------------------------------------------
    logging.info('Creating session')
    sess = tf.Session()
    logging.info('Initializing variables')
    init = tf.initialize_all_variables()
    logging.info('Running session')
    sess.run(init)

    # TB ---------------------------------------------------------------------------------------------------------
    logging.info('See stats via tensorboard: $ tensorboard --logdir %s', c.log_dir)
    train_writer = tf.train.SummaryWriter(c.log_dir, sess.graph)

    # Train ---------------------------------------------------------------------------------------------------------
    train_summary = None
    step, stopper = 0, EarlyStopper(c.nbest_models, c.not_change_limit, c.name)
    try:
        for e in range(c.epochs):
            logging.info('------------------------------')
            logging.info('Epoch %d', e)

            total_loss = 0
            total_acc = 0
            batch_count = 0
            for bid, (dialogs_bTt, lengths_bT, labels0_bT, labels1_bT, labels2_bT, masks_bT) in enumerate(next_batch(train_set, c.batch_size)):
                turn_loss = 0
                turn_acc = 0
                n_turns = 0
                first_run = True
                for (turn_bt, label0_b, label1_b, label2_b, lengths_b, masks_b) in zip(dialogs_bTt.transpose([1, 0, 2]),
                                                                  labels0_bT.transpose([1, 0]),
                                                                  labels1_bT.transpose([1, 0]),
                                                                  labels2_bT.transpose([1, 0]),
                                                                  lengths_bT.transpose([1, 0]),
                                                                  masks_bT.transpose([1,0])):
                    if sum(masks_b) == 0:
                        break 

                    _, batch_loss, batch_accuracy, train_summary = sess.run([train_op, loss, accuracy_all, tb_info],
                                                                            feed_dict={input_bt: turn_bt,
                                                                                       turn_lens_b: lengths_b,
                                                                                       mask_b: masks_b,
                                                                                       labels0_b: label0_b,
                                                                                       labels1_b: label1_b,
                                                                                       labels2_b: label2_b,
                                                                                       is_first_turn: first_run})
                    first_run = False
                    turn_loss += batch_loss
                    turn_acc += batch_accuracy
                    n_turns += 1
                    step += 1

                total_loss += turn_loss / n_turns
                total_acc += turn_acc / n_turns
                batch_count += 1

                logging.info('Batch %d/%d\r', bid, n_train_batches)

            train_writer.add_summary(train_summary, e)
            logging.info('Train cost %f', total_loss / batch_count)
            logging.info('Train accuracy: %f', total_acc / batch_count)

            def monitor_stream(work_set, name):
                total_loss = 0
                total_acc = 0
                n_valid_batches = 0
                for bid, (dialogs_bTt, lengths_bT, labels0_bT, labels1_bT, labels2_bT, masks_bT) in enumerate(next_batch(work_set, c.batch_size)):
                    turn_loss = 0
                    turn_acc = 0
                    n_turns = 0
                    first_run = True
                    for (turn_bt, label0_b, label1_b, label2_b, lengths_b, masks_b) in zip(dialogs_bTt.transpose([1, 0, 2]),
                                                                  labels0_bT.transpose([1, 0]),
                                                                  labels1_bT.transpose([1, 0]),
                                                                  labels2_bT.transpose([1, 0]),
                                                                  lengths_bT.transpose([1, 0]),
                                                                  masks_bT.transpose([1,0])):
                        if sum(masks_b) == 0:
                            break

                        input = np.pad(turn_bt, ((0, 0), (0, train_set.max_turn_len-turn_bt.shape[1])),
                                       'constant', constant_values=0) if train_set.max_turn_len > turn_bt.shape[1]\
                            else turn_bt

                        batch_loss, batch_acc, valid_summary = sess.run([loss, accuracy_all, tb_info],
                                                                        feed_dict={input_bt: input,
                                                                                   turn_lens_b: lengths_b,
                                                                                   labels0_b: label0_b,
                                                                                   labels1_b: label1_b,
                                                                                   labels2_b: label2_b,
                                                                                   mask_b: masks_b,
                                                                                   is_first_turn: first_run})
                        turn_loss += batch_loss
                        turn_acc += batch_acc
                        first_run = False
                        n_turns += 1

                    total_loss += turn_loss / n_turns
                    total_acc += turn_acc / n_turns
                    n_valid_batches += 1

                logging.info('%s cost: %f', name, total_loss/n_valid_batches)
                logging.info('%s accuracy: %f', name, total_acc/n_valid_batches)

                return total_loss/n_valid_batches

            stopper_reward = monitor_stream(valid_set, 'Valid')
            monitor_stream(test_set, 'Test')
            if not stopper.save_and_check(stopper_reward, step, sess):
                raise RuntimeError('Training not improving on dev set')
    finally:
        logging.info('Training stopped after %7d steps and %7.2f epochs. See logs for %s', step, step / len(train_set), c.log_name)
        logging.info('Saving current state. Please wait!\nBest model has reward %7.2f from step %7d is %s' % stopper.highest_reward())
        stopper.saver.save(sess=sess, save_path='%s-FINAL-%.4f-step-%07d' % (stopper.saver_prefix, stopper_reward, step))