def main(_):
    """Prepare FLAGS.output_dir, redirect stdout into it, then run training.

    The single positional argument is unused (presumably argv handed over by
    the app runner -- confirm against the entry point).

    If the output directory is missing it is created; if it exists and
    FLAGS.reset_output_dir is set it is deleted and recreated. Afterwards
    sys.stdout is replaced by a Logger writing to <output_dir>/stdout.
    """
    out_dir = FLAGS.output_dir
    rule = "-" * 80

    print(rule)
    if os.path.isdir(out_dir):
        if FLAGS.reset_output_dir:
            print("Path {} exists. Remove and remake.".format(out_dir))
            shutil.rmtree(out_dir)
            os.makedirs(out_dir)
    else:
        print("Path {} does not exist. Creating.".format(out_dir))
        os.makedirs(out_dir)
    print(rule)

    log_path = os.path.join(out_dir, "stdout")
    print("Logging to {}".format(log_path))
    # Everything printed from here on also goes through the Logger.
    sys.stdout = Logger(log_path)

    print_user_flags()
    train()
def main(_):
    """Print the user flags and run training, logging UTC timestamps around it.

    The single positional argument is unused.
    """
    def log_utc_timestamp():
        # Wall-clock time in UTC (time.gmtime), same format at start and end.
        logger.info(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))

    log_utc_timestamp()
    logger.info("-" * 80)

    utils.print_user_flags()
    train()

    logger.info('End.')
    log_utc_timestamp()
def train():
    """Train driver for the encoder/decoder NMT model.

    Everything is driven by the module-level FLAGS:

    1. Ensure FLAGS.output_dir exists (wiping and recreating it when
       FLAGS.reset_output_dir is set) and print the user flags.
    2. Load the dataset and build the batched datasets.
    3. Build the model and, if FLAGS.restore_checkpoint is set, try to
       restore the latest checkpoint from FLAGS.output_dir (best effort).
    4. Train for FLAGS.n_epochs epochs with teacher forcing, printing loss and
       perplexity, translating two random validation sentences every
       FLAGS.eval_every batches, checkpointing every FLAGS.save_every batches
       and again at the end of each epoch.
    5. After the final epoch, translate every line of
       <data_path>/target.txt and write the results to
       <data_path>/translated_<epoch>.txt.
    """
    print("-" * 80)
    if not os.path.isdir(FLAGS.output_dir):
        print("Path {0} does not exist. Creating.".format(FLAGS.output_dir))
        os.makedirs(FLAGS.output_dir)
    elif FLAGS.reset_output_dir:
        print("Path {0} exists. Remove and remake.".format(FLAGS.output_dir))
        shutil.rmtree(FLAGS.output_dir)
        os.makedirs(FLAGS.output_dir)

    print_user_flags()

    # --------------------------------
    # DATASET
    # --------------------------------
    data_dict = load_dataset(FLAGS.data_path, is_toy=FLAGS.is_toy,
                             num_examples=FLAGS.n_loaded_sentences)
    batched_datasets = create_batched_dataset(data_dict,
                                              batch_size=FLAGS.batch_size)

    # --------------------------------
    # MODEL
    # --------------------------------
    nmt_model = model(data_dict, checkpoint_dir=FLAGS.output_dir,
                      emb_dim=FLAGS.emb_dim, n_hidden=FLAGS.n_hidden,
                      type=FLAGS.cell_type)
    encoder, decoder = nmt_model['encoder'], nmt_model['decoder']

    # --------------------------------
    # RESTORE
    # --------------------------------
    if FLAGS.restore_checkpoint:
        # Best effort: on any failure we keep the freshly built weights.
        try:
            nmt_model['train_ckpt'].restore(
                tf.train.latest_checkpoint(FLAGS.output_dir))
            print("\n\nTRYING TO RESTORE THE LATEST MODEL FROM: "
                  "'{}'\n\n".format(FLAGS.output_dir))
        except Exception as e:
            print("Error in restoring checkpoint: {}. Training from "
                  "scratch".format(e))

    # --------------------------------
    # TRAIN
    # --------------------------------
    for epoch in range(1, FLAGS.n_epochs + 1):
        start = time.time()
        total_loss = 0
        total_ppl = 0
        n_batches = 0
        # reset encoder's hidden state at the beginning of every epoch to ZEROs
        hidden = tf.zeros((FLAGS.batch_size, FLAGS.n_hidden))

        for batch_num, (src, tgt) in enumerate(batched_datasets['train']):
            n_batches = batch_num + 1

            # forward
            loss = 0
            with tf.GradientTape() as tape:
                loss = teacher_forcing(data_dict, encoder, decoder, hidden,
                                       loss, src, tgt)
            # Normalise by tgt.shape[1] -- presumably the target sequence
            # length, i.e. tgt is (batch, time); TODO confirm.
            batch_loss = loss / int(tgt.shape[1])
            batch_ppl = np.exp(batch_loss)
            total_loss += batch_loss
            total_ppl += batch_ppl

            # backward
            variables = encoder.variables + decoder.variables
            gradients = tape.gradient(loss, variables)
            nmt_model['optim'].apply_gradients(zip(gradients, variables))

            if batch_num % 100 == 0:
                print('Epoch {} Batch {} Loss {:.4f} '
                      'PPL {:.2f}'.format(epoch, batch_num,
                                          batch_loss.numpy(), batch_ppl))

            # evaluate some sentences
            if n_batches % FLAGS.eval_every == 0:
                sentences = draw_random_sentences(
                    os.path.join(FLAGS.data_path, 'val.txt'), size=2)
                for sentence in sentences:
                    # change False to True for Attention Plots
                    translate(sentence, encoder, decoder, data_dict,
                              FLAGS.is_toy, FLAGS.is_attention, False)

            # checkpoint the model every FLAGS.save_every batches
            if n_batches % FLAGS.save_every == 0:
                nmt_model['train_ckpt'].save(
                    file_prefix=nmt_model['ckpt_prefix'])
                print("Model saved to '{}'".format(FLAGS.output_dir))

        # Average over the number of batches seen this epoch; max() guards
        # against division by zero when the training dataset is empty.
        denom = max(n_batches, 1)
        print('Epoch {} Loss {:.4f} PPL {:.2f}'.format(
            epoch, float(total_loss) / denom, float(total_ppl) / denom))
        print("Saving for this epoch")
        nmt_model['train_ckpt'].save(file_prefix=nmt_model['ckpt_prefix'])
        print('Time taken for this epoch {:.2f} sec\n'.format(
            time.time() - start))
        print("-" * 80)
        sys.stdout.flush()

        # Final test after training; drop the condition to run it at the end
        # of every epoch instead.
        if epoch == FLAGS.n_epochs:
            print("Testing the whole test set for epoch {}".format(epoch))
            with open(os.path.join(FLAGS.data_path, 'target.txt')) as in_f:
                test_lines = in_f.read().strip().split('\n')
            out_path = os.path.join(FLAGS.data_path,
                                    'translated_{}.txt'.format(epoch))
            with open(out_path, 'w') as f:
                for line in test_lines:
                    result, _, _ = evaluate(line, encoder, decoder, data_dict,
                                            FLAGS.is_toy, FLAGS.is_attention)
                    # Drop the last whitespace-separated token of each
                    # translation (presumably an end-of-sentence marker --
                    # TODO confirm).
                    f.write(' '.join(result.split()[:-1]) + '\n')
def main(_):
    """Train, evaluate and/or predict with a BERT classifier, driven by FLAGS.

    Only the "ar" task (ARProcessor) is registered. Phases:

    * FLAGS.do_train: convert train/dev examples to TFRecords in
      FLAGS.output_dir (skipped when FLAGS.use_record is set, in which case
      the existing files are reused) and run tf.estimator.train_and_evaluate.
    * FLAGS.do_eval: evaluate on processor.get_mytest_examples and write the
      metrics to <output_dir>/eval_results.txt.
    * FLAGS.do_predict: predict on processor.get_test_examples and write one
      tab-separated line per example to <output_dir>/test_results.tsv.

    Raises:
        ValueError: if none of do_train/do_eval/do_predict is set, if
            max_seq_length exceeds the model's max_position_embeddings, if the
            task name is unknown, or if prediction is requested on TPU.
    """
    print_user_flags()
    tf.logging.set_verbosity(tf.logging.INFO)

    processors = {
        "ar": ARProcessor,
    }

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.gfile.MakeDirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    print('before train', time.localtime())
    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    model_fn = model_fn_builder(
        bert_config=bert_config,
        num_labels=len(label_list),
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.predict_batch_size)

    if FLAGS.do_train:
        train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
        if not FLAGS.use_record:
            file_based_convert_examples_to_features(
                train_examples, label_list, FLAGS.max_seq_length, tokenizer,
                train_file)
        tf.logging.info("***** Running training *****")
        tf.logging.info(" Num examples = %d", len(train_examples))
        tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info(" Num steps = %d", num_train_steps)
        train_input_fn = file_based_input_fn_builder(
            input_file=train_file,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)

        # The dev split is evaluated periodically during training.
        val_file = os.path.join(FLAGS.output_dir, "val.tf_record")
        val_examples = processor.get_dev_examples(FLAGS.data_dir)
        if not FLAGS.use_record:
            file_based_convert_examples_to_features(
                val_examples, label_list, FLAGS.max_seq_length, tokenizer,
                val_file)
        tf.logging.info(" Num val examples = %d", len(val_examples))
        tf.logging.info(" Batch size val = %d", FLAGS.eval_batch_size)
        val_drop_remainder = bool(FLAGS.use_tpu)
        val_input_fn = file_based_input_fn_builder(
            input_file=val_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=val_drop_remainder)

        train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                            max_steps=num_train_steps)
        val_spec = tf.estimator.EvalSpec(input_fn=val_input_fn)
        tf.estimator.train_and_evaluate(estimator, train_spec, val_spec)

    print('before eval', time.localtime())
    if FLAGS.do_eval:
        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
        eval_examples = processor.get_mytest_examples(FLAGS.data_dir)
        if not FLAGS.use_record:
            file_based_convert_examples_to_features(
                eval_examples, label_list, FLAGS.max_seq_length, tokenizer,
                eval_file)

        tf.logging.info("***** Running evaluation *****")
        tf.logging.info(" Num examples = %d", len(eval_examples))
        tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)

        # This tells the estimator to run through the entire set.
        eval_steps = None
        # However, if running eval on the TPU, you will need to specify the
        # number of steps.
        if FLAGS.use_tpu:
            # Eval will be slightly WRONG on the TPU because it will truncate
            # the last batch.
            eval_steps = len(eval_examples) // FLAGS.eval_batch_size

        eval_drop_remainder = bool(FLAGS.use_tpu)
        eval_input_fn = file_based_input_fn_builder(
            input_file=eval_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=eval_drop_remainder)

        result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info(" %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    print('before predict', time.localtime())
    if FLAGS.do_predict:
        if FLAGS.use_tpu:
            # Warning: According to tpu_estimator.py Prediction on TPU is an
            # experimental feature and hence not supported here. Checked
            # before any conversion work is done.
            raise ValueError("Prediction in TPU not supported")

        # Examples and file path are needed even when reusing an existing
        # TFRecord (use_record): the count is logged and the file is read.
        predict_examples = processor.get_test_examples(FLAGS.data_dir)
        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        if not FLAGS.use_record:
            file_based_convert_examples_to_features(
                predict_examples, label_list, FLAGS.max_seq_length, tokenizer,
                predict_file)

        tf.logging.info("***** Running prediction*****")
        tf.logging.info(" Num examples = %d", len(predict_examples))
        tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)

        # TPU was rejected above, so the final partial batch is always kept.
        predict_input_fn = file_based_input_fn_builder(
            input_file=predict_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=False)

        result = estimator.predict(input_fn=predict_input_fn)

        output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
        with tf.gfile.GFile(output_predict_file, "w") as writer:
            tf.logging.info("***** Predict results *****")
            for prediction in result:
                # One tab-separated line per example; each element of
                # `prediction` is stringified (presumably per-class
                # probabilities -- depends on model_fn_builder, TODO confirm).
                output_line = "\t".join(
                    str(class_probability)
                    for class_probability in prediction) + "\n"
                writer.write(output_line)

    print('End', time.localtime())