import argparse

from config import config

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run Copyright ML Flows')
    parser.add_argument('run_type')
    parser.add_argument('--config', '-c', type=str, required=False)
    parser.add_argument('--background_directory',
                        '-b',
                        type=str,
                        required=True)
    parser.add_argument('--evaluation_directory',
                        '-e',
                        type=str,
                        required=True)
    parser.add_argument('--model_path', '-m', type=str, required=True)

    args = parser.parse_args()

    # --config is optional; only load it when one was supplied.
    if args.config:
        config.load_json(args.config)

    # Imported after the config is loaded so the model is built with it.
    from model import SiameseModel

    siamese_model = SiameseModel()

    siamese_model.train(args.background_directory, args.evaluation_directory,
                        args.model_path)
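
A hypothetical invocation of this script; the filename train_siamese.py, the run_type value, and all paths are illustrative only, not taken from the source:

    python train_siamese.py train -c config.json -b data/background -e data/eval -m models/siamese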
Example #2
import itertools
import logging
import os
import time

import tensorflow as tf
from tensorflow.python.ops import lookup_ops

# The project-local helpers used below (setup_args, build_hparams,
# save_hparams, SiameseModel, ModeKeys, create_labeled_data_iterator and
# create_labeled_data_iterator_with_context) are assumed to be importable
# from the surrounding package; their module paths are not shown here.


def main():
    args = setup_args()
    hparams = build_hparams(args)
    logging.info(hparams)

    # Create validation graph and session
    valid_graph = tf.Graph()

    with valid_graph.as_default():
        # Set random seed
        tf.set_random_seed(args.seed)
        vocab_table = lookup_ops.index_table_from_file(hparams.vocab,
                                                       default_value=0)
        if hparams.train_context:
            valid_iterator = create_labeled_data_iterator_with_context(
                hparams.valid_context, hparams.valid_txt1, hparams.valid_txt2,
                hparams.valid_labels, vocab_table, hparams.size_valid_batch)
        else:
            valid_iterator = create_labeled_data_iterator(
                hparams.valid_txt1, hparams.valid_txt2, hparams.valid_labels,
                vocab_table, hparams.size_valid_batch)

        valid_model = SiameseModel(hparams, valid_iterator, ModeKeys.EVAL)

        # Create validation session and init its variables, tables and iterator.
        valid_sess = tf.Session()
        valid_sess.run(tf.global_variables_initializer())
        valid_sess.run(tf.tables_initializer())
        valid_sess.run(valid_iterator.init)

        eval_loss, time_taken, _ = valid_model.eval(valid_sess)
        logging.info('Init Val Loss: %.4f Time: %ds' % (eval_loss, time_taken))

    # Create the model dir if required
    if not tf.gfile.Exists(hparams.model_dir):
        logging.info('Creating Model dir: %s' % hparams.model_dir)
        tf.gfile.MkDir(hparams.model_dir)
    save_hparams(hparams)

    # Create training graph and session
    train_graph = tf.Graph()

    with train_graph.as_default():
        # Set random seed
        tf.set_random_seed(args.seed)

        # First word in the vocab file is UNK (see prep_data/create_vocab.py)
        vocab_table = lookup_ops.index_table_from_file(hparams.vocab,
                                                       default_value=0)

        if hparams.train_context:
            train_iterator = create_labeled_data_iterator_with_context(
                hparams.train_context, hparams.train_txt1, hparams.train_txt2,
                hparams.train_labels, vocab_table, hparams.size_train_batch)
        else:
            train_iterator = create_labeled_data_iterator(
                hparams.train_txt1, hparams.train_txt2, hparams.train_labels,
                vocab_table, hparams.size_train_batch)

        train_model = SiameseModel(hparams, train_iterator, ModeKeys.TRAIN)

        # Create training session and init its variables, tables and iterator.
        train_sess = tf.Session()
        train_sess.run(tf.global_variables_initializer())
        train_sess.run(tf.tables_initializer())
        train_sess.run(train_iterator.init)

    # Training loop
    summary_writer = tf.summary.FileWriter(
        os.path.join(hparams.model_dir, 'train_log'))
    epoch_num = 0
    epoch_start_time = time.time()
    best_eval_loss = float('inf')  # any real validation loss beats this

    # Step at which we last evaluated on the validation data
    last_eval_step = 0

    # Step at which we last logged training stats and saved a checkpoint
    last_stats_step = 0

    train_saver_path = os.path.join(hparams.model_dir, 'sm')
    valid_saver_path = os.path.join(hparams.model_dir, 'best_eval')
    tf.gfile.MakeDirs(valid_saver_path)
    valid_saver_path = os.path.join(valid_saver_path, 'sm')

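    # Train until interrupted: tf.data signals the end of each epoch by
    # raising OutOfRangeError, which is caught below to start the next epoch.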
    for step in itertools.count():
        try:
            _, loss, train_summary = train_model.train(train_sess)

            # Periodically log training loss and save a checkpoint
            if step - last_stats_step >= hparams.steps_per_stats:
                logging.info('Epoch: %d Step %d: Train_Loss: %.4f' %
                             (epoch_num, step, loss))
                train_model.saver.save(train_sess, train_saver_path, step)
                summary_writer.add_summary(train_summary, step)
                last_stats_step = step

            # Eval model and print stats
            if step - last_eval_step >= hparams.steps_per_eval:
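                # Restore the newest training checkpoint into the separate
                # eval session before scoring the validation set.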
                latest_ckpt = tf.train.latest_checkpoint(hparams.model_dir)
                valid_model.saver.restore(valid_sess, latest_ckpt)
                eval_loss, time_taken, eval_summary = valid_model.eval(
                    valid_sess)
                summary_writer.add_summary(eval_summary, step)

                if eval_loss < best_eval_loss:
                    valid_model.saver.save(valid_sess, valid_saver_path, step)
                    logging.info(
                        'Epoch: %d Step: %d Valid_Loss Improved New: %.4f Old: %.4f'
                        % (epoch_num, step, eval_loss, best_eval_loss))
                    best_eval_loss = eval_loss
                else:
                    logging.info(
                        'Epoch: %d Step: %d Valid_Loss Worse New: %.4f Old: %.4f'
                        % (epoch_num, step, eval_loss, best_eval_loss))
                last_eval_step = step

        except tf.errors.OutOfRangeError:
            logging.info('Epoch %d END Time: %ds' %
                         (epoch_num, time.time() - epoch_start_time))
            epoch_num += 1

            with train_graph.as_default():
                train_sess.run(train_iterator.init)
            epoch_start_time = time.time()
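
The snippet defines main() but ends before the entry point that calls it. A minimal sketch, assuming console logging is wanted (the original entry point is not shown in the source):

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    main()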