Example #1
def main(_):
    print("-" * 80)
    if not os.path.isdir(FLAGS.output_dir):
        print("Path {} does not exist. Creating.".format(FLAGS.output_dir))
        os.makedirs(FLAGS.output_dir)
    elif FLAGS.reset_output_dir:
        print("Path {} exists. Remove and remake.".format(FLAGS.output_dir))
        shutil.rmtree(FLAGS.output_dir)
        os.makedirs(FLAGS.output_dir)

    print("-" * 80)
    log_file = os.path.join(FLAGS.output_dir, "stdout")
    print("Logging to {}".format(log_file))
    sys.stdout = Logger(log_file)
    print_user_flags()
    train()
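
The `Logger` object assigned to `sys.stdout` is not included in this excerpt. A minimal sketch of a tee-style logger that mirrors console output into the log file (an assumption about its behavior, inferred only from how it is used here) could look like:

import sys

class Logger(object):
    """Tee-style logger: everything written to stdout also goes to a file."""

    def __init__(self, log_path):
        self.terminal = sys.stdout      # keep a handle to the real stdout
        self.log = open(log_path, "a")  # append so restarts do not clobber old logs

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # implementing flush() keeps the object usable as a drop-in stdout replacement
        self.terminal.flush()
        self.log.flush()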
Example #2
File: main.py Project: Doffery/LightNAS
def main(_):
    logger.info(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
    logger.info("-" * 80)
    # if not os.path.isdir(FLAGS.output_dir):
    #     logger.info("Path {} does not exist. Creating.".format(FLAGS.output_dir))
    #     os.makedirs(FLAGS.output_dir)
    # elif FLAGS.reset_output_dir:
    #     logger.info("Path {} exists. Remove and remake.".format(FLAGS.output_dir))
    #     shutil.rmtree(FLAGS.output_dir)
    #     os.makedirs(FLAGS.output_dir)
 
    # logger.info("-" * 80)
    # log_file = os.path.join(FLAGS.output_dir, "stdout")
    # logger.info("Logging to {}".format(log_file))
    # sys.stdout = Logger(log_file)
 
    utils.print_user_flags()
    train()
    logger.info('End.')
    logger.info(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
Example #3
def train():
  """
  Train driver - no need to change.  
  """
  print("-" * 80)
  if not os.path.isdir(FLAGS.output_dir):
    print("Path {0} does not exist. Creating.".format(FLAGS.output_dir))
    os.makedirs(FLAGS.output_dir)
  elif FLAGS.reset_output_dir:
    print("Path {0} exists. Remove and remake.".format(FLAGS.output_dir))
    shutil.rmtree(FLAGS.output_dir)
    os.makedirs(FLAGS.output_dir)

  print_user_flags()

  # --------------------------------
  #             DATASET
  # --------------------------------
  data_dict = load_dataset(FLAGS.data_path,
                           is_toy=FLAGS.is_toy,
                           num_examples=FLAGS.n_loaded_sentences)
  batched_datasets = create_batched_dataset(data_dict,
                                            batch_size=FLAGS.batch_size)

  # --------------------------------
  #             MODEL
  # --------------------------------
  nmt_model = model(data_dict,
                    checkpoint_dir=FLAGS.output_dir,
                    emb_dim=FLAGS.emb_dim,
                    n_hidden=FLAGS.n_hidden,
                    type=FLAGS.cell_type)
  encoder, decoder = nmt_model['encoder'], nmt_model['decoder']

  # --------------------------------
  #             RESTORE
  # --------------------------------
  # try to restore saved model
  if FLAGS.restore_checkpoint:
    try:
      nmt_model['train_ckpt']. \
        restore(tf.train.latest_checkpoint(FLAGS.output_dir))
      print("\n\nTRYING TO RESTORE THE LATEST MODEL FROM: "
            "'{}'\n\n".format(FLAGS.output_dir))
    except Exception as e:
      print("Error in restoring checkpoint: {}. Training from "
            "scratch".format(e))

  # --------------------------------
  #             TRAIN
  # --------------------------------
  # (continue to) train now
  for epoch in range(1, FLAGS.n_epochs + 1):
    start = time.time()
    total_loss = 0
    total_ppl = 0

    # reset encoder's hidden state at the beginning of every epoch to ZEROs
    hidden = tf.zeros((FLAGS.batch_size, FLAGS.n_hidden))

    for (batch_num, (src, tgt)) in enumerate(batched_datasets['train']):
      loss = 0
      with tf.GradientTape() as tape:
        loss = teacher_forcing(data_dict,
                               encoder,
                               decoder,
                               hidden,
                               loss,
                               src,
                               tgt)  # forward
      batch_loss = (loss / int(tgt.shape[1]))
      batch_ppl = np.exp(batch_loss)
      total_loss += batch_loss
      total_ppl += batch_ppl

      # backward
      variables = encoder.variables + decoder.variables
      gradients = tape.gradient(loss, variables)
      nmt_model['optim'].apply_gradients(zip(gradients, variables))

      if batch_num % 100 == 0:
        print('Epoch {} Batch {} Loss {:.4f} '
              'PPL {:.2f}'.format(epoch,
                                  batch_num,
                                  batch_loss.numpy(),
                                  batch_ppl))
      # evaluate some sentences
      if (batch_num + 1) % FLAGS.eval_every == 0:
        sentences = draw_random_sentences(os.path.join(FLAGS.data_path,
                                                       'val.txt'),
                                          size=2)
        for sentence in sentences:
          # change False to True for Attention Plots
          translate(sentence, encoder, decoder, data_dict, FLAGS.is_toy,
                    FLAGS.is_attention, False)

      # checkpoint the model every `save_every` batches
      if (batch_num + 1) % FLAGS.save_every == 0:
        nmt_model['train_ckpt'].save(file_prefix=nmt_model['ckpt_prefix'])
        print("Model saved to '{}'".format(FLAGS.output_dir))

    print('Epoch {} Loss {:.4f} PPL {:.2f}'.format(epoch,
                                                   total_loss / (batch_num + 1),
                                                   total_ppl / (batch_num + 1)))
    print("Saving for this epoch")
    nmt_model['train_ckpt'].save(file_prefix=nmt_model['ckpt_prefix'])

    print('Time taken for this epoch {:.2f} sec\n'.format(time.time() - start))
    print("-" * 80)
    sys.stdout.flush()

    # final test after the last epoch; drop the `epoch == FLAGS.n_epochs` guard
    # below to run it at the end of every epoch instead
    if epoch == FLAGS.n_epochs:
      print("Testing the whole test set for epoch {}".format(epoch))
      test_lines = open(os.path.join(FLAGS.data_path, 'target.txt')).\
        read().strip().split('\n')
      with open(os.path.join(FLAGS.data_path, 'translated_{}.txt'.format(epoch)),
                'w') as f:
        for line in test_lines:
          result, _, _ = evaluate(line, encoder, decoder, data_dict,
                                  FLAGS.is_toy, FLAGS.is_attention)
          f.write(' '.join(result.split()[:-1]) + '\n')
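
The `teacher_forcing` helper called inside the `tf.GradientTape` block is not shown in this excerpt. The sketch below illustrates, under assumed encoder/decoder interfaces (`SimpleEncoder`, `SimpleDecoder`, and `start_id` are illustrative names, not the project's actual classes), what such a teacher-forcing forward pass typically accumulates: at every target step the decoder is fed the ground-truth previous token, and the per-step cross-entropy is summed into the returned loss.

import tensorflow as tf


class SimpleEncoder(tf.keras.Model):
    def __init__(self, vocab_size, emb_dim, n_hidden):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, emb_dim)
        self.gru = tf.keras.layers.GRU(n_hidden, return_sequences=True,
                                       return_state=True)

    def call(self, x, hidden):
        # x: (batch, src_len) token ids -> per-step outputs and final state
        output, state = self.gru(self.embedding(x), initial_state=hidden)
        return output, state


class SimpleDecoder(tf.keras.Model):
    def __init__(self, vocab_size, emb_dim, n_hidden):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, emb_dim)
        self.gru = tf.keras.layers.GRU(n_hidden, return_state=True)
        self.fc = tf.keras.layers.Dense(vocab_size)

    def call(self, x, hidden):
        # x: (batch, 1) previous target token -> logits over the vocabulary
        output, state = self.gru(self.embedding(x), initial_state=hidden)
        return self.fc(output), state


def teacher_forcing_sketch(encoder, decoder, hidden, src, tgt, start_id):
    """Sum cross-entropy over target steps, feeding the gold previous token
    to the decoder at every step (teacher forcing)."""
    xent = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    enc_out, dec_hidden = encoder(src, hidden)        # enc_out would feed attention
    dec_input = tf.fill((tgt.shape[0], 1), start_id)  # <start> token for every sentence
    loss = 0.0
    for t in range(1, tgt.shape[1]):
        logits, dec_hidden = decoder(dec_input, dec_hidden)
        loss += xent(tgt[:, t], logits)               # compare against gold token at t
        dec_input = tf.expand_dims(tgt[:, t], 1)      # next input = ground-truth token
    return loss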
def main(_):
  print_user_flags()
  tf.logging.set_verbosity(tf.logging.INFO)

  processors = {
      "ar": ARProcessor,
  }

  if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
    raise ValueError(
        "At least one of `do_train`, `do_eval` or `do_predict' must be True.")

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  if FLAGS.max_seq_length > bert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, bert_config.max_position_embeddings))

  tf.gfile.MakeDirs(FLAGS.output_dir)

  task_name = FLAGS.task_name.lower()

  if task_name not in processors:
    raise ValueError("Task not found: %s" % (task_name))

  processor = processors[task_name]()

  label_list = processor.get_labels()

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  print('before train', time.localtime())

  train_examples = None
  eval_examples = None
  num_train_steps = None
  num_warmup_steps = None
  if FLAGS.do_train:
    train_examples = processor.get_train_examples(FLAGS.data_dir)
    # if FLAGS.do_eval:
    #     valsize = int(len(train_examples) / 9)  # take 10% as validation
    #     eval_examples = train_examples[:valsize]
    #     train_examples = train_examples[valsize:]
    num_train_steps = int(
        len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

  model_fn = model_fn_builder(
      bert_config=bert_config,
      num_labels=len(label_list),
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

  if FLAGS.do_train:
    train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
    if not FLAGS.use_record:
        file_based_convert_examples_to_features(train_examples, label_list, 
                                                FLAGS.max_seq_length, 
                                                tokenizer, train_file)
    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num examples = %d", len(train_examples))
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", num_train_steps)
    train_input_fn = file_based_input_fn_builder(
        input_file=train_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)

    val_file = os.path.join(FLAGS.output_dir, "val.tf_record")
    val_examples = processor.get_dev_examples(FLAGS.data_dir)
    if not FLAGS.use_record:
        file_based_convert_examples_to_features(val_examples, label_list, 
                                                FLAGS.max_seq_length, 
                                                tokenizer, val_file)

    tf.logging.info("  Num val examples = %d", len(val_examples))
    tf.logging.info("  Batch size val = %d", FLAGS.eval_batch_size)

    val_drop_remainder = True if FLAGS.use_tpu else False
    val_input_fn = file_based_input_fn_builder(
        input_file=val_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=val_drop_remainder)

    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                        max_steps=num_train_steps)
    val_spec = tf.estimator.EvalSpec(input_fn=val_input_fn)


    tf.estimator.train_and_evaluate(estimator, train_spec, val_spec)
    # estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
    # result = estimator.evaluate(input_fn=val_input_fn, steps=val_steps)

  print('before eval', time.localtime())

  if FLAGS.do_eval:
    eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
    if eval_examples is None:
        eval_examples = processor.get_mytest_examples(FLAGS.data_dir)
    if not FLAGS.use_record:
        file_based_convert_examples_to_features(eval_examples, label_list, 
                                                FLAGS.max_seq_length, 
                                                tokenizer, eval_file)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Num examples = %d", len(eval_examples))
    tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

    # This tells the estimator to run through the entire set.
    eval_steps = None
    # However, if running eval on the TPU, you will need to specify the
    # number of steps.
    if FLAGS.use_tpu:
      # Eval will be slightly WRONG on the TPU because it will truncate
      # the last batch.
      eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size)

    eval_drop_remainder = True if FLAGS.use_tpu else False
    eval_input_fn = file_based_input_fn_builder(
        input_file=eval_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=eval_drop_remainder)

    result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)

    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
    with tf.gfile.GFile(output_eval_file, "w") as writer:
      tf.logging.info("***** Eval results *****")
      for key in sorted(result.keys()):
        tf.logging.info("  %s = %s", key, str(result[key]))
        writer.write("%s = %s\n" % (key, str(result[key])))

  print('before predict', time.localtime())

  if FLAGS.do_predict:
    predict_examples = processor.get_test_examples(FLAGS.data_dir)
    predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
    # only regenerate the tf_record file when not reusing an existing one
    if not FLAGS.use_record:
        file_based_convert_examples_to_features(predict_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer,
                                                predict_file)

    tf.logging.info("***** Running prediction*****")
    tf.logging.info("  Num examples = %d", len(predict_examples))
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    if FLAGS.use_tpu:
      # Warning: According to tpu_estimator.py Prediction on TPU is an
      # experimental feature and hence not supported here
      raise ValueError("Prediction in TPU not supported")

    predict_drop_remainder = True if FLAGS.use_tpu else False
    predict_input_fn = file_based_input_fn_builder(
        input_file=predict_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=predict_drop_remainder)

    result = estimator.predict(input_fn=predict_input_fn)

    output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
    with tf.gfile.GFile(output_predict_file, "w") as writer:
      tf.logging.info("***** Predict results *****")
      for prediction in result:
        output_line = "\t".join(
            str(class_probability) for class_probability in prediction) + "\n"
        writer.write(output_line)

  print('End', time.localtime())
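
The flag definitions and the script entry point are outside this excerpt. Scripts in this style usually end with `tf.app.run()`, which parses the flags referenced above and then calls `main()`. A minimal sketch, showing only a subset of the flags used above, with purely illustrative defaults and help strings:

import tensorflow as tf

flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("data_dir", None, "Directory containing the task data.")
flags.DEFINE_string("output_dir", None, "Where checkpoints and results are written.")
flags.DEFINE_string("task_name", "ar", "Task key looked up in `processors`.")
flags.DEFINE_bool("do_train", False, "Whether to run training.")
flags.DEFINE_bool("do_eval", False, "Whether to run evaluation.")
flags.DEFINE_bool("do_predict", False, "Whether to run prediction on the test set.")
flags.DEFINE_bool("use_record", False,
                  "Reuse existing *.tf_record files instead of regenerating them.")
flags.DEFINE_integer("max_seq_length", 128, "Maximum WordPiece sequence length.")

if __name__ == "__main__":
    flags.mark_flag_as_required("data_dir")
    flags.mark_flag_as_required("output_dir")
    tf.app.run()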