def autoencode_procedure(option, path, z_size, examples, fold_inputs,
                         encoder_graph, decoder_graph, mus_and_log_sigs,
                         look_behind, token_probs_t, hidden_state_t, sess):
    # Autoencode each example and write the inputs and reconstructions to disk
    examples_dir = join(BASEDIR, option + '_examples')
    mkdir_p(examples_dir)

    autoencoded_examples_path = join(examples_dir, 'autoencoded')
    message = 'Autoencoding {} examples'.format(len(examples))
    for i, (fold_batch, input_sequence) in enumerate(
            tqdm(zip(td.group_by_batches(fold_inputs, 1), examples),
                 desc=message,
                 total=len(examples))):
        dir_for_example = join(autoencoded_examples_path, str(i))
        mkdir_p(dir_for_example)

        p_of_z = sess.run(
            mus_and_log_sigs,
            feed_dict={encoder_graph.loom_input_tensor: fold_batch})
        mus = p_of_z[:, :z_size]

        input_code = example_to_code(input_sequence[look_behind:])
        write_to_file(join(dir_for_example, 'input.hs'), input_code)
        imsave(join(dir_for_example, 'input.png'), input_sequence.T)
        imsave(join(dir_for_example, 'z.png'), mus.reshape(
            (mus.size // 32, 32)))

        # Decoder: seed the first step with the latent code, then feed the
        # hidden state back one token at a time.
        if look_behind > 0:
            decoder_input = padded_look_behind_generator(
                input_sequence, look_behind, TOKEN_EMB_SIZE)
        else:
            decoder_input = zeros_generator()  # NOQA

        fd = decoder_graph.build_feed_dict([(mus.squeeze(axis=0),
                                             [next(decoder_input)])])

        token_probs, hidden_state = sess.run([token_probs_t, hidden_state_t],
                                             feed_dict=fd)

        tokens_probs = [token_probs]
        text_so_far = [one_hot(TOKEN_EMB_SIZE, token_probs.squeeze().argmax())]
        for _ in range(len(input_sequence) - 1):
            fd = decoder_graph.build_feed_dict([(hidden_state.squeeze(),
                                                 [next(decoder_input)])])

            token_probs, hidden_state = sess.run(
                [token_probs_t, hidden_state_t], feed_dict=fd)
            tokens_probs += [token_probs]
            text_so_far += [
                one_hot(TOKEN_EMB_SIZE,
                        token_probs.squeeze().argmax())
            ]

        decoder_output = np.concatenate(tokens_probs)
        imsave(join(dir_for_example, 'decoder_output.png'), decoder_output.T)
        write_to_file(join(dir_for_example, 'autoencoded_code.hs'),
                      example_to_code(text_so_far))
Example #2
    def train_epoch(self, sess, train_set):
        if not self.is_compiled:
            return 0, 0
        t = time.time()
        print("start train epoch")
        loss = sum(self.train_step(sess, ba)
                   for ba in td.group_by_batches(train_set, self.BATCH_SIZE))
        t = time.time() - t
        return loss, t
Example #3
    def train_epoch(self, train_set, batch_size):
        loss = 0
        for batch in td.group_by_batches(train_set, batch_size):
            train_feed_dict = {
                self.keep_prob_ph: self.ModelConfig.keep_prob,
                self.compiler.loom_input_tensor: batch
            }
            loss += self.train_step(train_feed_dict)
        return loss
Example #4
def test_eval():
    test_loss = tf.reduce_sum(compiler.metric_tensors['root_loss'])
    _test_logits = compiler.metric_tensors['root_logits']
    test_loss_whole = 0.
    test_pred_whole = []
    test_labels_whole = []
    test_logits_whole = []
    # f1 = open("logTmp", "w")
    for batch in td.group_by_batches(test_set, BATCH_SIZE):
        test_feed_dict[compiler.loom_input_tensor] = batch
        test_loss_batch, test_pred_batch, test_labels_batch, test_logits_batch = sess.run(
            [test_loss, pred, labels, _test_logits], test_feed_dict)
        test_loss_whole += test_loss_batch
        test_pred_whole += test_pred_batch.tolist()
        test_labels_whole += test_labels_batch.tolist()
        test_logits_whole += test_logits_batch.tolist()
    f1score = metrics.f1_score(test_labels_whole,
                               test_pred_whole,
                               average=None)
    print('test_loss_total: %.3e, test_f1score:\n  [%s]' %
          (test_loss_whole, f1score))
    return test_labels_whole, test_pred_whole, test_logits_whole
Example #5
def train_epoch(train_set):
    return sum(
        train_step(batch)
        for batch in td.group_by_batches(train_set, BATCH_SIZE))
Example #6
    def get_batches(self, dataset=None, batch_size=100):
        if dataset is None:
            dataset = self.get_dataset()
        import tensorflow_fold as td
        return td.group_by_batches(dataset, batch_size)
Example #7
def main():
    apputil.initialize(variable_scope='embedding')

    # load data early so we can initialize hyperparameters accordingly
    ds = data.load_dataset('../data/statements')
    hyper.node_type_num = len(ds.word2int)

    hyper.dump()

    # create model variables
    param.initialize_embedding_weights()

    # Compile the block
    tree_sum = td.GetItem(0) >> tree_sum_blk(l2loss_blk)
    compiler = td.Compiler.create(tree_sum)
    (batched_loss, ) = compiler.output_tensors
    loss = tf.reduce_mean(batched_loss)
    opt = tf.train.AdamOptimizer(learning_rate=hyper.learning_rate)

    global_step = tf.Variable(0, trainable=False, name='global_step')
    train_step = opt.minimize(loss, global_step=global_step)

    # Attach summaries
    tf.summary.histogram('Wl', param.get('Wl'))
    tf.summary.histogram('Wr', param.get('Wr'))
    tf.summary.histogram('B', param.get('B'))
    tf.summary.histogram('Embedding', param.get('We'))
    tf.summary.scalar('loss', loss)

    summary_op = tf.summary.merge_all()

    # create missing dir
    if not os.path.exists(hyper.train_dir):
        os.makedirs(hyper.train_dir)

    # train loop
    saver = tf.train.Saver()
    train_set = compiler.build_loom_inputs(ds.get_split('all')[1])
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        summary_writer = tf.summary.FileWriter(hyper.log_dir, graph=sess.graph)
        write_embedding_metadata(summary_writer, ds.word2int)

        for epoch, shuffled in enumerate(
                td.epochs(train_set, hyper.num_epochs), 1):
            for step, batch in enumerate(
                    td.group_by_batches(shuffled, hyper.batch_size), 1):
                train_feed_dict = {compiler.loom_input_tensor: batch}

                start_time = default_timer()
                _, loss_value, summary, gstep = sess.run(
                    [train_step, loss, summary_op, global_step],
                    train_feed_dict)
                duration = default_timer() - start_time

                logger.info(
                    'global %d epoch %d step %d loss = %.2f (%.1f samples/sec; %.3f sec/batch)',
                    gstep, epoch, step, loss_value,
                    hyper.batch_size / duration, duration)
                if gstep % 10 == 0:
                    summary_writer.add_summary(summary, gstep)
                    saver.save(sess,
                               os.path.join(hyper.train_dir, "model.ckpt"),
                               global_step=gstep)
Example #8
def do_evaluation():
    # load data early to get node_type_num
    ds = data.load_dataset('data/statements')
    hyper.node_type_num = len(ds.word2int)

    (compiler, _, _, _, raw_accuracy, batch_size_op) = build_model()

    # locate the embedding-matrix checkpoint
    embedding_path = tf.train.latest_checkpoint(hyper.embedding_dir)
    if embedding_path is None:
        raise ValueError('Path to embedding checkpoint is incorrect: ' +
                         hyper.embedding_dir)

    # locate the tbcnn checkpoint for the remaining variables
    checkpoint_path = tf.train.latest_checkpoint(hyper.train_dir)
    if checkpoint_path is None:
        raise ValueError('Path to tbcnn checkpoint is incorrect: ' +
                         hyper.train_dir)

    restored_vars = tf.get_collection_ref('restored')
    restored_vars.append(param.get('We'))
    restored_vars.extend(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
    embeddingRestorer = tf.train.Saver({'embedding/We': param.get('We')})
    restorer = tf.train.Saver(
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))

    # evaluation loop
    total_size, test_gen = ds.get_split('test')
    test_set = compiler.build_loom_inputs(test_gen)
    with tf.Session() as sess:
        # Restore embedding matrix first
        embeddingRestorer.restore(sess, embedding_path)
        # Restore others
        restorer.restore(sess, checkpoint_path)
        # Initialize other variables
        gvariables = [
            v for v in tf.global_variables()
            if v not in tf.get_collection('restored')
        ]
        sess.run(tf.variables_initializer(gvariables))

        num_epochs = 1 if not hyper.warm_up else 3
        for shuffled in td.epochs(test_set, num_epochs):
            logger.info('')
            logger.info(
                '======================= Evaluation ===================================='
            )
            accumulated_accuracy = 0.
            start_time = default_timer()
            for step, batch in enumerate(
                    td.group_by_batches(shuffled, hyper.batch_size), 1):
                feed_dict = {compiler.loom_input_tensor: batch}
                accuracy_value, actual_bsize = sess.run(
                    [raw_accuracy, batch_size_op], feed_dict)
                accumulated_accuracy += accuracy_value * actual_bsize
                logger.info(
                    'evaluation in progress: running accuracy = %.2f, processed = %d / %d',
                    accuracy_value,
                    (step - 1) * hyper.batch_size + actual_bsize, total_size)
            duration = default_timer() - start_time
            total_accuracy = accumulated_accuracy / total_size
            logger.info(
                'evaluation accumulated accuracy = %.2f%% (%.1f samples/sec; %.2f seconds)',
                total_accuracy * 100, total_size / duration, duration)
            logger.info(
                '======================= Evaluation End ================================='
            )
            logger.info('')
Example #9
def do_train():
    # load data early to get node_type_num
    ds = data.load_dataset('data/statements')
    hyper.node_type_num = len(ds.word2int)

    hyper.dump()

    (compiler, unscaled_logits, logits, batched_labels, raw_accuracy,
     batch_size_op) = build_model()

    (loss, global_step, train_step,
     summary_op) = train_with_val(unscaled_logits, batched_labels,
                                  raw_accuracy)

    val_summary_op = tf.summary.scalar('val_accuracy', raw_accuracy)

    # create missing dir
    if not os.path.exists(hyper.train_dir):
        os.makedirs(hyper.train_dir)

    # restorer for embedding matrix
    restorer = tf.train.Saver({'embedding/We': param.get('We')})
    embedding_path = tf.train.latest_checkpoint(hyper.embedding_dir)
    if embedding_path is None:
        raise ValueError('Path to embedding checkpoint is incorrect: ' +
                         hyper.embedding_dir)

    # train loop
    saver = tf.train.Saver()
    train_set = compiler.build_loom_inputs(ds.get_split('train')[1])
    val_set = compiler.build_loom_inputs(ds.get_split('val')[1])
    with tf.Session() as sess:
        # Restore embedding matrix first
        restorer.restore(sess, embedding_path)
        # Initialize other variables
        gvariables = tf.global_variables()
        gvariables.remove(param.get('We'))  # exclude We
        sess.run(tf.variables_initializer(gvariables))

        summary_writer = tf.summary.FileWriter(hyper.log_dir, graph=sess.graph)

        val_step_counter = 0
        shuffled = zip(td.epochs(train_set, hyper.num_epochs),
                       td.epochs(val_set, hyper.num_epochs))
        for epoch, (train_shuffled, val_shuffled) in enumerate(shuffled, 1):
            for step, batch in enumerate(
                    td.group_by_batches(train_shuffled, hyper.batch_size), 1):
                train_feed_dict = {compiler.loom_input_tensor: batch}

                start_time = default_timer()
                (_, loss_value, summary, gstep, actual_bsize) = sess.run(
                    [train_step, loss, summary_op, global_step, batch_size_op],
                    train_feed_dict)
                duration = default_timer() - start_time

                logger.info(
                    'global %d epoch %d step %d loss = %.2f (%.1f samples/sec; %.3f sec/batch)',
                    gstep, epoch, step, loss_value, actual_bsize / duration,
                    duration)
                if gstep % 10 == 0:
                    summary_writer.add_summary(summary, gstep)

            # do a validation test
            logger.info('')
            logger.info(
                '======================= Validation ===================================='
            )
            accumulated_accuracy = 0.
            total_size = 0
            start_time = default_timer()
            for batch in td.group_by_batches(val_shuffled, hyper.batch_size):
                feed_dict = {compiler.loom_input_tensor: batch}
                accuracy_value, actual_bsize, val_summary = sess.run(
                    [raw_accuracy, batch_size_op, val_summary_op], feed_dict)
                summary_writer.add_summary(val_summary, val_step_counter)
                accumulated_accuracy += accuracy_value * actual_bsize
                total_size += actual_bsize
                val_step_counter += 1
                logger.info(
                    'validation step, accuracy = %.2f, current batch = %d, processed = %d',
                    accuracy_value, actual_bsize, total_size)
            duration = default_timer() - start_time
            total_accuracy = accumulated_accuracy / total_size
            logger.info(
                'validation acc = %.2f%% (%.1f samples/sec; %.2f seconds)',
                total_accuracy * 100, total_size / duration, duration)
            saved_path = saver.save(sess,
                                    os.path.join(hyper.train_dir,
                                                 "model.ckpt"),
                                    global_step=gstep)
            logger.info('validation saved path: %s', saved_path)
            logger.info(
                '======================= Validation End ================================='
            )
            logger.info('')
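
All of the snippets above share one batching pattern: compile a Fold block with td.Compiler.create, convert raw examples into loom inputs with compiler.build_loom_inputs, chunk those with td.group_by_batches, and feed each chunk through compiler.loom_input_tensor. Below is a minimal, self-contained sketch of that loop; the toy sum_block and examples dataset are assumptions made for illustration and do not come from any of the examples above.

import tensorflow as tf
import tensorflow_fold as td

# Toy block (an assumption for illustration): each example is a
# variable-length list of floats, and the block left-folds tf.add over it
# to produce a per-example sum.
sum_block = (td.Map(td.Scalar()) >>
             td.Fold(td.Function(tf.add), td.FromTensor(tf.zeros([]))))

compiler = td.Compiler.create(sum_block)
(per_example_sum,) = compiler.output_tensors
batch_total = tf.reduce_sum(per_example_sum)

examples = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]] * 8   # toy dataset
dataset = compiler.build_loom_inputs(examples)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    total = 0.0
    # The batching pattern used throughout the examples above.
    for batch in td.group_by_batches(dataset, 4):      # batch size 4
        total += sess.run(batch_total,
                          {compiler.loom_input_tensor: batch})
    print('total over all batches:', total)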