import argparse

from config import config

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run Copyright ML Flows')
    # Positional run type (e.g. 'train'); not consumed in the snippet below.
    parser.add_argument('run_type')
    parser.add_argument('--config', '-c', type=str, required=False)
    parser.add_argument('--background_directory', '-b', type=str, required=True)
    parser.add_argument('--evaluation_directory', '-e', type=str, required=True)
    parser.add_argument('--model_path', '-m', type=str, required=True)
    args = parser.parse_args()

    # NOTE: --config is optional, so args.config may be None here;
    # config.load_json is expected to handle that case or fall back to a default.
    config.load_json(args.config)

    # Imported after the config is loaded so the model sees the populated config.
    from model import SiameseModel

    siamese_model = SiameseModel()
    siamese_model.train(args.background_directory,
                        args.evaluation_directory,
                        args.model_path)
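# Example invocation (a hedged sketch; the script name and the paths below are
# hypothetical, not taken from the original source):
#
#   python run.py train -c config.json \
#       -b data/background -e data/evaluation -m models/siamese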
import itertools
import logging
import os
import time

import tensorflow as tf
from tensorflow.python.ops import lookup_ops

# Local helpers; the exact module paths below are assumptions inferred from
# usage in this script, not confirmed by the original source.
from model import SiameseModel
from data_utils import (create_labeled_data_iterator,
                        create_labeled_data_iterator_with_context)
from params_utils import setup_args, build_hparams, save_hparams

# Assumption: TF 1.x estimator mode keys (TRAIN / EVAL).
ModeKeys = tf.estimator.ModeKeys


def main():
    args = setup_args()
    hparams = build_hparams(args)
    logging.info(hparams)

    # Create validation graph and session.
    valid_graph = tf.Graph()
    with valid_graph.as_default():
        # Set random seed.
        tf.set_random_seed(args.seed)
        vocab_table = lookup_ops.index_table_from_file(hparams.vocab,
                                                       default_value=0)
        if hparams.train_context:
            valid_iterator = create_labeled_data_iterator_with_context(
                hparams.valid_context, hparams.valid_txt1, hparams.valid_txt2,
                hparams.valid_labels, vocab_table, hparams.size_valid_batch)
        else:
            valid_iterator = create_labeled_data_iterator(
                hparams.valid_txt1, hparams.valid_txt2, hparams.valid_labels,
                vocab_table, hparams.size_valid_batch)
        valid_model = SiameseModel(hparams, valid_iterator, ModeKeys.EVAL)

        # Create validation session and init its variables, tables and iterator.
        valid_sess = tf.Session()
        valid_sess.run(valid_iterator.init)
        valid_sess.run(tf.global_variables_initializer())
        valid_sess.run(tf.tables_initializer())

        # Baseline validation loss before any training.
        eval_loss, time_taken, _ = valid_model.eval(valid_sess)
        logging.info('Init Val Loss: %.4f Time: %ds' % (eval_loss, time_taken))

    # Create model dir if required.
    if not tf.gfile.Exists(hparams.model_dir):
        logging.info('Creating Model dir: %s' % hparams.model_dir)
        tf.gfile.MkDir(hparams.model_dir)
    save_hparams(hparams)

    # Create training graph and session.
    train_graph = tf.Graph()
    with train_graph.as_default():
        # Set random seed.
        tf.set_random_seed(args.seed)

        # First word in vocab file is UNK (see prep_data/create_vocab.py).
        vocab_table = lookup_ops.index_table_from_file(hparams.vocab,
                                                       default_value=0)
        if hparams.train_context:
            train_iterator = create_labeled_data_iterator_with_context(
                hparams.train_context, hparams.train_txt1, hparams.train_txt2,
                hparams.train_labels, vocab_table, hparams.size_train_batch)
        else:
            train_iterator = create_labeled_data_iterator(
                hparams.train_txt1, hparams.train_txt2, hparams.train_labels,
                vocab_table, hparams.size_train_batch)
        train_model = SiameseModel(hparams, train_iterator, ModeKeys.TRAIN)

        # Create training session and init its variables, tables and iterator.
        train_sess = tf.Session()
        train_sess.run(tf.global_variables_initializer())
        train_sess.run(tf.tables_initializer())
        train_sess.run(train_iterator.init)

    # Training loop.
    summary_writer = tf.summary.FileWriter(
        os.path.join(hparams.model_dir, 'train_log'))
    epoch_num = 0
    epoch_start_time = time.time()
    # Sentinel start value; any real validation loss should come in lower.
    best_eval_loss = 100.0
    # When did we last check validation data.
    last_eval_step = 0
    # When did we last save training stats and checkpoint.
    last_stats_step = 0

    train_saver_path = os.path.join(hparams.model_dir, 'sm')
    valid_saver_path = os.path.join(hparams.model_dir, 'best_eval')
    tf.gfile.MakeDirs(valid_saver_path)
    valid_saver_path = os.path.join(valid_saver_path, 'sm')

    for step in itertools.count():
        try:
            _, loss, train_summary = train_model.train(train_sess)

            # Periodically log stats and checkpoint the training model.
            if step - last_stats_step >= hparams.steps_per_stats:
                logging.info('Epoch: %d Step %d: Train_Loss: %.4f' %
                             (epoch_num, step, loss))
                train_model.saver.save(train_sess, train_saver_path, step)
                summary_writer.add_summary(train_summary, step)
                last_stats_step = step

            # Periodically evaluate the latest checkpoint on validation data.
            if step - last_eval_step >= hparams.steps_per_eval:
                latest_ckpt = tf.train.latest_checkpoint(hparams.model_dir)
                valid_model.saver.restore(valid_sess, latest_ckpt)
                eval_loss, time_taken, eval_summary = valid_model.eval(valid_sess)
                summary_writer.add_summary(eval_summary, step)
                if eval_loss < best_eval_loss:
                    # Keep a separate checkpoint of the best model so far.
                    valid_model.saver.save(valid_sess, valid_saver_path, step)
                    logging.info(
                        'Epoch: %d Step: %d Valid_Loss Improved New: %.4f Old: %.4f'
                        % (epoch_num, step, eval_loss, best_eval_loss))
                    best_eval_loss = eval_loss
                else:
                    logging.info(
                        'Epoch: %d Step: %d Valid_Loss Worse New: %.4f Old: %.4f'
                        % (epoch_num, step, eval_loss, best_eval_loss))
                last_eval_step = step
        except tf.errors.OutOfRangeError:
            # End of epoch: log timing and re-initialize the training iterator.
            logging.info('Epoch %d END Time: %ds' %
                         (epoch_num, time.time() - epoch_start_time))
            epoch_num += 1
            with train_graph.as_default():
                train_sess.run(train_iterator.init)
            epoch_start_time = time.time()
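# Hedged sketch of an entry point: the fragment above defines main() but its
# call site is not shown, so this is an assumption about how the script runs.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    main()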