Пример #1
0
def main(_):
    """Run UDA training, evaluation and/or prediction as selected by FLAGS.

    The phases are selected by ``FLAGS.do_train``, ``FLAGS.do_eval`` and
    ``FLAGS.do_predict``. Before anything runs, every flag value is dumped to
    ``<model_dir>/FLAGS.json``.

    * train + eval: train in chunks of ``save_checkpoints_steps`` up to
      exactly ``FLAGS.num_train_steps``, evaluating on the dev set after each
      chunk and logging the best ``eval_classify_accuracy``.
    * train only: train up to ``FLAGS.num_train_steps``.
    * eval only: evaluate every checkpoint listed in ``FLAGS.model_dir``.
    * predict: write the model's ``probabilities`` output, one tab-separated
      line per (non-padding) example, to ``<model_dir>/test_results.tsv``.

    Args:
      _: Unused (presumably the argv list handed over by ``tf.app.run``).

    Raises:
      ValueError: If evaluation-only mode finds no checkpoint in
        ``FLAGS.model_dir``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    processor = raw_data_utils.get_processor(FLAGS.task_name)
    label_list = processor.get_labels()

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file,
                                                     FLAGS.model_dropout)

    tf.gfile.MakeDirs(FLAGS.model_dir)

    # Keep a record of this run's configuration next to its checkpoints.
    flags_dict = tf.app.flags.FLAGS.flag_values_dict()
    with tf.gfile.Open(os.path.join(FLAGS.model_dir, "FLAGS.json"),
                       "w") as ouf:
        json.dump(flags_dict, ouf)

    tf.logging.info("warmup steps {}/{}".format(FLAGS.num_warmup_steps,
                                                FLAGS.num_train_steps))

    # Clamp to >= 1: when num_train_steps < save_checkpoints_num the integer
    # division yields 0, and range(..., 0) in the train/eval loop would raise.
    save_checkpoints_steps = max(
        1, FLAGS.num_train_steps // FLAGS.save_checkpoints_num)
    tf.logging.info("setting save checkpoints steps to {:d}".format(
        save_checkpoints_steps))

    # Never run more device iterations per loop than there are steps between
    # two checkpoints.
    FLAGS.iterations_per_loop = min(save_checkpoints_steps,
                                    FLAGS.iterations_per_loop)
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
        tpu_cluster_resolver = None

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        # Large limit, presumably so eval-only mode can score every saved
        # checkpoint.
        keep_checkpoint_max=1000,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            per_host_input_for_training=is_per_host))

    model_fn = uda.model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        clip_norm=FLAGS.clip_norm,
        num_train_steps=FLAGS.num_train_steps,
        num_warmup_steps=FLAGS.num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_one_hot_embeddings,
        num_labels=len(label_list),
        unsup_ratio=FLAGS.unsup_ratio,
        uda_coeff=FLAGS.uda_coeff,
        tsa=FLAGS.tsa,
        print_feature=False,
        print_structure=False,
    )

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        params={"model_dir": FLAGS.model_dir},
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        # TPUEstimator.predict on TPU needs an explicit batch size. The flag
        # is only read when prediction is requested, as in the predict
        # section below.
        predict_batch_size=(FLAGS.predict_batch_size
                            if FLAGS.do_predict else None))

    def _log_and_unwrap(result):
        """Log every metric of an evaluate() result.

        Each value is converted in place to a plain Python scalar with
        ``.item()``, and the same dict is returned.
        """
        tf.logging.info(">> Results:")
        for key in result.keys():
            tf.logging.info("  %s = %s", key, str(result[key]))
            result[key] = result[key].item()
        return result

    if FLAGS.do_train:
        tf.logging.info("  >>> sup data dir : {}".format(
            FLAGS.sup_train_data_dir))
        if FLAGS.unsup_ratio > 0:
            tf.logging.info("  >>> unsup data dir : {}".format(
                FLAGS.unsup_data_dir))

        train_input_fn = proc_data_utils.training_input_fn_builder(
            FLAGS.sup_train_data_dir, FLAGS.unsup_data_dir, FLAGS.aug_ops,
            FLAGS.aug_copy, FLAGS.unsup_ratio)

    if FLAGS.do_eval:
        tf.logging.info("  >>> dev data dir : {}".format(FLAGS.eval_data_dir))
        eval_input_fn = proc_data_utils.evaluation_input_fn_builder(
            FLAGS.eval_data_dir, "clas")

        eval_size = processor.get_dev_size()
        eval_steps = int(eval_size / FLAGS.eval_batch_size)

    if FLAGS.do_train and FLAGS.do_eval:
        tf.logging.info("***** Running training & evaluation *****")
        tf.logging.info("  Supervised batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Unsupervised batch size = %d",
                        FLAGS.train_batch_size * FLAGS.unsup_ratio)
        tf.logging.info("  Num steps = %d", FLAGS.num_train_steps)
        tf.logging.info("  Base evaluation batch size = %d",
                        FLAGS.eval_batch_size)
        tf.logging.info("  Num steps = %d", eval_steps)
        best_acc = 0
        for chunk_start in range(0, FLAGS.num_train_steps,
                                 save_checkpoints_steps):
            tf.logging.info("*** Running training ***")
            # max_steps is an absolute global-step target. The last chunk
            # therefore stops exactly at num_train_steps, even when that is
            # not a multiple of the chunk size. A relative steps= count would
            # overshoot.
            estimator.train(input_fn=train_input_fn,
                            max_steps=min(chunk_start + save_checkpoints_steps,
                                          FLAGS.num_train_steps))
            tf.logging.info("*** Running evaluation ***")
            dev_result = _log_and_unwrap(
                estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps))
            best_acc = max(best_acc, dev_result["eval_classify_accuracy"])
        tf.logging.info("***** Final evaluation result *****")
        tf.logging.info("Best acc: {:.3f}\n\n".format(best_acc))
    elif FLAGS.do_train:
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Supervised batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Unsupervised batch size = %d",
                        FLAGS.train_batch_size * FLAGS.unsup_ratio)
        tf.logging.info("  Num steps = %d", FLAGS.num_train_steps)
        estimator.train(input_fn=train_input_fn,
                        max_steps=FLAGS.num_train_steps)
    elif FLAGS.do_eval:
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Base evaluation batch size = %d",
                        FLAGS.eval_batch_size)
        tf.logging.info("  Num steps = %d", eval_steps)
        checkpoint_state = tf.train.get_checkpoint_state(FLAGS.model_dir)
        # get_checkpoint_state returns None when model_dir has no checkpoint
        # file; fail with a clear message instead of an AttributeError.
        if checkpoint_state is None:
            raise ValueError(
                "No checkpoint found in {}".format(FLAGS.model_dir))

        best_acc = 0
        for ckpt_path in checkpoint_state.all_model_checkpoint_paths:
            if not tf.gfile.Exists(ckpt_path + ".data-00000-of-00001"):
                tf.logging.info(
                    "Warning: checkpoint {:s} does not exist".format(
                        ckpt_path))
                continue
            tf.logging.info("Evaluating {:s}".format(ckpt_path))
            dev_result = _log_and_unwrap(estimator.evaluate(
                input_fn=eval_input_fn,
                steps=eval_steps,
                checkpoint_path=ckpt_path,
            ))
            best_acc = max(best_acc, dev_result["eval_classify_accuracy"])
        tf.logging.info("***** Final evaluation result *****")
        tf.logging.info("Best acc: {:.3f}\n\n".format(best_acc))

    if FLAGS.do_predict:
        # The tokenizer was previously referenced without ever being defined;
        # build it here the same way the Horovod variant of this script does.
        from utils import tokenization
        tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                               do_lower_case=True)

        # NOTE(review): test examples are read from model_dir rather than a
        # raw-data dir -- confirm this is intended.
        predict_examples = processor.get_test_examples(FLAGS.model_dir)
        num_actual_predict_examples = len(predict_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the
            # number of examples must be a multiple of the batch size, or else
            # examples will get dropped. So we pad with fake examples which
            # are ignored later on.
            while len(predict_examples) % FLAGS.predict_batch_size != 0:
                predict_examples.append(PaddingInputExample())

        predict_file = os.path.join(FLAGS.model_dir, "predict.tf_record")
        file_based_convert_examples_to_features(predict_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer, predict_file)

        tf.logging.info("***** Running prediction*****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(predict_examples), num_actual_predict_examples,
                        len(predict_examples) - num_actual_predict_examples)
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_input_fn = file_based_input_fn_builder(
            input_file=predict_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=bool(FLAGS.use_tpu))

        result = estimator.predict(input_fn=predict_input_fn)

        output_predict_file = os.path.join(FLAGS.model_dir, "test_results.tsv")
        with tf.gfile.GFile(output_predict_file, "w") as writer:
            num_written_lines = 0
            tf.logging.info("***** Predict results *****")
            for (i, prediction) in enumerate(result):
                # Everything past the real examples is TPU padding.
                if i >= num_actual_predict_examples:
                    break
                probabilities = prediction["probabilities"]
                output_line = "\t".join(
                    str(class_probability)
                    for class_probability in probabilities) + "\n"
                writer.write(output_line)
                num_written_lines += 1
        assert num_written_lines == num_actual_predict_examples
Пример #2
0
def main(_):
    """Horovod-distributed UDA training/evaluation with a token-level dump.

    Each Horovod process sees only the GPU of its local rank. Rank 0 uses
    ``FLAGS.model_dir`` directly; every other rank uses the sub-directory
    ``<model_dir>/<rank>``. Phases are selected by ``FLAGS.do_train`` and
    ``FLAGS.do_eval``:

    * train + eval: train in chunks of ``save_checkpoints_steps`` up to
      exactly ``FLAGS.num_train_steps``. After each chunk, evaluate on the
      supervised training data and on the dev set, and track the best dev
      ``eval_precision``.
    * train only: train up to ``FLAGS.num_train_steps``.
    * eval only: evaluate every checkpoint listed in the model directory.
      Rank 0 then writes ``label_test.txt`` to the working directory, with
      one ``token gold_label predicted_label`` line per non-``[PAD]`` token
      and a blank line after each example.

    Args:
      _: Unused (presumably the argv list handed over by ``tf.app.run``).

    Raises:
      ValueError: If evaluation-only mode finds no checkpoint in the model
        directory.
    """
    hvd.init()
    # Rank 0 owns FLAGS.model_dir; other ranks write into their own
    # sub-directory so they do not overwrite rank 0's files.
    if hvd.rank() != 0:
        FLAGS.model_dir = os.path.join(FLAGS.model_dir, str(hvd.rank()))
    config = tf.ConfigProto()
    # Expose only this process's GPU to its session.
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    # NOTE: num_train_steps / num_warmup_steps are used as-is on every rank;
    # they are not divided by hvd.size().

    tf.logging.set_verbosity(tf.logging.INFO)

    processor = raw_data_utils.get_processor(FLAGS.task_name)
    label_list = processor.get_labels()

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file,
                                                     FLAGS.model_dropout)

    tf.gfile.MakeDirs(FLAGS.model_dir)

    # Keep a record of this run's configuration next to its checkpoints.
    flags_dict = tf.app.flags.FLAGS.flag_values_dict()
    with tf.gfile.Open(os.path.join(FLAGS.model_dir, "FLAGS.json"),
                       "w") as ouf:
        json.dump(flags_dict, ouf)

    tf.logging.info("warmup steps {}/{}".format(FLAGS.num_warmup_steps,
                                                FLAGS.num_train_steps))

    # Fixed checkpoint/evaluation interval. It replaces the
    # num_train_steps // save_checkpoints_num computation used elsewhere.
    save_checkpoints_steps = 500
    tf.logging.info("setting save checkpoints steps to {:d}".format(
        save_checkpoints_steps))

    FLAGS.iterations_per_loop = min(save_checkpoints_steps,
                                    FLAGS.iterations_per_loop)
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
        tpu_cluster_resolver = None

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        keep_checkpoint_max=1,
        session_config=config,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            per_host_input_for_training=is_per_host))

    model_fn = uda.model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        clip_norm=FLAGS.clip_norm,
        num_train_steps=FLAGS.num_train_steps,
        num_warmup_steps=FLAGS.num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_one_hot_embeddings,
        num_labels=len(label_list),
        unsup_ratio=FLAGS.unsup_ratio,
        uda_coeff=FLAGS.uda_coeff,
        tsa=FLAGS.tsa,
        print_feature=False,
        print_structure=False,
    )

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        params={"model_dir": FLAGS.model_dir},
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.eval_batch_size)

    def _log_and_unwrap(result, header):
        """Log ``header`` and every metric of an evaluate() result.

        Each value is converted in place to a plain Python scalar with
        ``.item()``, and the same dict is returned.
        """
        tf.logging.info(header)
        for key in result.keys():
            tf.logging.info("  %s = %s", key, str(result[key]))
            result[key] = result[key].item()
        return result

    def _write_token_label_dump(predictions, output_path):
        """Write ``token gold predicted`` lines for every non-[PAD] token.

        ``predictions`` is the iterable returned by ``estimator.predict``.
        Each item is read for its ``input_ids``, ``label_ids`` and
        ``predict`` entries. A blank line follows each example.
        """
        from utils import tokenization
        tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                               do_lower_case=True)
        id2label = dict(enumerate(label_list))
        # Label ids >= 34 are folded onto label 0 (presumably 34 is the number
        # of real task labels -- TODO confirm). Also cap at len(label_list) so
        # the lookup can never raise KeyError. This is now applied to the
        # predicted ids too; the previous loop mistakenly clamped `gold` at
        # the predicted positions.
        num_kept_labels = min(34, len(label_list))

        def to_label(label_id):
            return id2label[label_id if 0 <= label_id < num_kept_labels else 0]

        lines = []
        for pred in predictions:
            tokens = tokenizer.convert_ids_to_tokens(
                [int(s) for s in pred["input_ids"]])
            for token, gold_id, pred_id in zip(tokens, pred["label_ids"],
                                               pred["predict"]):
                if token == "[PAD]":
                    continue
                lines.append(token + " " + to_label(gold_id) + " " +
                             to_label(pred_id) + "\n")
            lines.append("\n")
        # Join once instead of repeated string concatenation (quadratic).
        with tf.gfile.GFile(output_path, "w") as writer:
            writer.write("".join(lines))

    if FLAGS.do_train:
        # Broadcast rank 0's initial variables so all ranks start from the
        # same weights. Used by every training branch below.
        hooks = [hvd.BroadcastGlobalVariablesHook(0)]

        tf.logging.info("  >>> sup data dir : {}".format(
            FLAGS.sup_train_data_dir))
        if FLAGS.unsup_ratio > 0:
            tf.logging.info("  >>> unsup data dir : {}".format(
                FLAGS.unsup_data_dir))

        train_input_fn = proc_data_utils.training_input_fn_builder(
            FLAGS.sup_train_data_dir, FLAGS.unsup_data_dir, FLAGS.aug_ops,
            FLAGS.aug_copy, FLAGS.unsup_ratio)
        train_size = processor.get_train_size(FLAGS.raw_data_dir)
        train_steps = int(train_size / FLAGS.train_batch_size)

    if FLAGS.do_eval:
        tf.logging.info("  >>> dev data dir : {}".format(FLAGS.eval_data_dir))
        eval_input_fn = proc_data_utils.evaluation_input_fn_builder(
            FLAGS.eval_data_dir, "clas")

        eval_size = processor.get_dev_size(FLAGS.raw_data_dir)
        eval_steps = int(eval_size / FLAGS.eval_batch_size)

        train_eval_input_fn = proc_data_utils.evaluation_input_fn_builder(
            FLAGS.sup_train_data_dir, "clas")

    if FLAGS.do_train and FLAGS.do_eval:
        # The training-set evaluation is batched like any other evaluation
        # (eval_batch_size), so derive its step count the same way as
        # eval_steps, not from train_batch_size.
        train_eval_steps = int(train_size / FLAGS.eval_batch_size)

        tf.logging.info("***** Running training & evaluation *****")
        tf.logging.info("  Supervised batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Unsupervised batch size = %d",
                        FLAGS.train_batch_size * FLAGS.unsup_ratio)
        tf.logging.info("  training size = %d", train_size)
        tf.logging.info("  training num steps = %d", train_steps)
        tf.logging.info("  evaluation batch size = %d", FLAGS.eval_batch_size)
        tf.logging.info("  dev num steps = %d", eval_steps)
        # NOTE: logged as "Best acc" but tracks the dev eval_precision metric.
        best_acc = 0
        for chunk_start in range(0, FLAGS.num_train_steps,
                                 save_checkpoints_steps):
            tf.logging.info("*** Running training ***")
            # max_steps is an absolute global-step target. Training therefore
            # stops exactly at num_train_steps, even when that is not a
            # multiple of save_checkpoints_steps.
            estimator.train(input_fn=train_input_fn,
                            max_steps=min(chunk_start + save_checkpoints_steps,
                                          FLAGS.num_train_steps),
                            hooks=hooks)
            tf.logging.info("*** Running evaluation ***")

            _log_and_unwrap(
                estimator.evaluate(input_fn=train_eval_input_fn,
                                   steps=train_eval_steps),
                ">> Train Results:")
            dev_result = _log_and_unwrap(
                estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps),
                ">> Results:")
            best_acc = max(best_acc, dev_result["eval_precision"])
        tf.logging.info("***** Final evaluation result *****")
        tf.logging.info("Best acc: {:.3f}\n\n".format(best_acc))
    elif FLAGS.do_train:
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Supervised batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Unsupervised batch size = %d",
                        FLAGS.train_batch_size * FLAGS.unsup_ratio)
        tf.logging.info("  Num steps = %d", FLAGS.num_train_steps)
        # Without the broadcast hook each rank would train from its own
        # random initialisation.
        estimator.train(input_fn=train_input_fn,
                        max_steps=FLAGS.num_train_steps,
                        hooks=hooks)
    elif FLAGS.do_eval:
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Base evaluation batch size = %d",
                        FLAGS.eval_batch_size)
        tf.logging.info("  Num steps = %d", eval_steps)
        checkpoint_state = tf.train.get_checkpoint_state(FLAGS.model_dir)
        # get_checkpoint_state returns None when there is no checkpoint file;
        # fail with a clear message instead of an AttributeError.
        if checkpoint_state is None:
            raise ValueError(
                "No checkpoint found in {}".format(FLAGS.model_dir))

        best_acc = 0
        for ckpt_path in checkpoint_state.all_model_checkpoint_paths:
            if not tf.gfile.Exists(ckpt_path + ".data-00000-of-00001"):
                tf.logging.info(
                    "Warning: checkpoint {:s} does not exist".format(
                        ckpt_path))
                continue
            tf.logging.info("Evaluating {:s}".format(ckpt_path))
            dev_result = _log_and_unwrap(
                estimator.evaluate(
                    input_fn=eval_input_fn,
                    steps=eval_steps,
                    checkpoint_path=ckpt_path,
                ), ">> Results:")
            best_acc = max(best_acc, dev_result["eval_precision"])
        tf.logging.info("***** Final evaluation result *****")
        tf.logging.info("Best acc: {:.3f}\n\n".format(best_acc))
        # The dump path is relative to the working directory. Let only rank 0
        # write it, so ranks that share a directory do not clobber each other.
        if hvd.rank() == 0:
            _write_token_label_dump(estimator.predict(input_fn=eval_input_fn),
                                    "label_test.txt")
Пример #3
0
Файл: main.py Проект: ys2899/uda
def main(_):
  """Run UDA training and/or evaluation as selected by FLAGS.

  Before anything runs, every flag value is dumped to
  ``<model_dir>/FLAGS.json``. Phases are selected by ``FLAGS.do_train`` and
  ``FLAGS.do_eval``:

  * train + eval: train in chunks of ``save_checkpoints_steps`` up to exactly
    ``FLAGS.num_train_steps``, evaluating on the dev set after each chunk and
    logging the best ``eval_classify_accuracy``.
  * train only: train up to ``FLAGS.num_train_steps``.
  * eval only: evaluate every checkpoint listed in ``FLAGS.model_dir``.

  Args:
    _: Unused (presumably the argv list handed over by ``tf.app.run``).

  Raises:
    ValueError: If evaluation-only mode finds no checkpoint in
      ``FLAGS.model_dir``.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  processor = raw_data_utils.get_processor(FLAGS.task_name)
  label_list = processor.get_labels()

  bert_config = modeling.BertConfig.from_json_file(
      FLAGS.bert_config_file,
      FLAGS.model_dropout)

  tf.gfile.MakeDirs(FLAGS.model_dir)

  # Keep a record of this run's configuration next to its checkpoints.
  flags_dict = tf.app.flags.FLAGS.flag_values_dict()
  with tf.gfile.Open(os.path.join(FLAGS.model_dir, "FLAGS.json"), "w") as ouf:
    json.dump(flags_dict, ouf)

  tf.logging.info("warmup steps {}/{}".format(
      FLAGS.num_warmup_steps, FLAGS.num_train_steps))

  # Clamp to >= 1: when num_train_steps < save_checkpoints_num the integer
  # division yields 0, and range(..., 0) in the train/eval loop would raise.
  save_checkpoints_steps = max(
      1, FLAGS.num_train_steps // FLAGS.save_checkpoints_num)

  tf.logging.info("setting save checkpoints steps to {:d}".format(
      save_checkpoints_steps))

  # Never run more device iterations per loop than there are steps between
  # two checkpoints.
  FLAGS.iterations_per_loop = min(save_checkpoints_steps,
                                  FLAGS.iterations_per_loop)

  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
  else:
    tpu_cluster_resolver = None

  is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=save_checkpoints_steps,
      # Large limit, presumably so eval-only mode can score every checkpoint.
      keep_checkpoint_max=1000,
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          per_host_input_for_training=is_per_host))

  model_fn = uda.model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      clip_norm=FLAGS.clip_norm,
      num_train_steps=FLAGS.num_train_steps,
      num_warmup_steps=FLAGS.num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_one_hot_embeddings,
      num_labels=len(label_list),
      unsup_ratio=FLAGS.unsup_ratio,
      uda_coeff=FLAGS.uda_coeff,
      tsa=FLAGS.tsa,
      print_feature=False,
      print_structure=False,
  )

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      params={"model_dir": FLAGS.model_dir},
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size)

  def _log_and_unwrap(result):
    """Log every metric of an evaluate() result.

    Each value is converted in place to a plain Python scalar with ``.item()``,
    and the same dict is returned.
    """
    tf.logging.info(">> Results:")
    for key in result.keys():
      tf.logging.info("  %s = %s", key, str(result[key]))
      result[key] = result[key].item()
    return result

  if FLAGS.do_train:
    tf.logging.info("  >>> sup data dir : {}".format(FLAGS.sup_train_data_dir))
    if FLAGS.unsup_ratio > 0:
      tf.logging.info("  >>> unsup data dir : {}".format(
          FLAGS.unsup_data_dir))

    train_input_fn = proc_data_utils.training_input_fn_builder(
        FLAGS.sup_train_data_dir,
        FLAGS.unsup_data_dir,
        FLAGS.aug_ops,
        FLAGS.aug_copy,
        FLAGS.unsup_ratio)

  if FLAGS.do_eval:
    tf.logging.info("  >>> dev data dir : {}".format(FLAGS.eval_data_dir))
    eval_input_fn = proc_data_utils.evaluation_input_fn_builder(
        FLAGS.eval_data_dir,
        "clas")

    eval_size = processor.get_dev_size()
    eval_steps = int(eval_size / FLAGS.eval_batch_size)

  if FLAGS.do_train and FLAGS.do_eval:
    tf.logging.info("***** Running training & evaluation *****")
    tf.logging.info("  Supervised batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Unsupervised batch size = %d",
                    FLAGS.train_batch_size * FLAGS.unsup_ratio)
    tf.logging.info("  Num steps = %d", FLAGS.num_train_steps)
    tf.logging.info("  Base evaluation batch size = %d", FLAGS.eval_batch_size)
    tf.logging.info("  Num steps = %d", eval_steps)
    best_acc = 0
    # The leftover pdb.set_trace() (which blocked every iteration) and the
    # in-loop `save_checkpoints_steps /= 15`, which turned the chunk size into
    # an ever-shrinking float, are both removed.
    for chunk_start in range(0, FLAGS.num_train_steps, save_checkpoints_steps):
      tf.logging.info("*** Running training ***")
      # max_steps is an absolute global-step target. The last chunk therefore
      # stops exactly at num_train_steps instead of overshooting it.
      estimator.train(
          input_fn=train_input_fn,
          max_steps=min(chunk_start + save_checkpoints_steps,
                        FLAGS.num_train_steps))
      tf.logging.info("*** Running evaluation ***")
      dev_result = _log_and_unwrap(
          estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps))
      best_acc = max(best_acc, dev_result["eval_classify_accuracy"])

    tf.logging.info("***** Final evaluation result *****")
    tf.logging.info("Best acc: {:.3f}\n\n".format(best_acc))
  elif FLAGS.do_train:
    tf.logging.info("***** Running training *****")
    tf.logging.info("  Supervised batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Unsupervised batch size = %d",
                    FLAGS.train_batch_size * FLAGS.unsup_ratio)
    tf.logging.info("  Num steps = %d", FLAGS.num_train_steps)
    estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
  elif FLAGS.do_eval:
    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Base evaluation batch size = %d", FLAGS.eval_batch_size)
    tf.logging.info("  Num steps = %d", eval_steps)
    checkpoint_state = tf.train.get_checkpoint_state(FLAGS.model_dir)
    # get_checkpoint_state returns None when model_dir has no checkpoint file;
    # fail with a clear message instead of an AttributeError.
    if checkpoint_state is None:
      raise ValueError("No checkpoint found in {}".format(FLAGS.model_dir))

    best_acc = 0
    for ckpt_path in checkpoint_state.all_model_checkpoint_paths:
      if not tf.gfile.Exists(ckpt_path + ".data-00000-of-00001"):
        tf.logging.info(
            "Warning: checkpoint {:s} does not exist".format(ckpt_path))
        continue
      tf.logging.info("Evaluating {:s}".format(ckpt_path))
      dev_result = _log_and_unwrap(estimator.evaluate(
          input_fn=eval_input_fn,
          steps=eval_steps,
          checkpoint_path=ckpt_path,
      ))
      best_acc = max(best_acc, dev_result["eval_classify_accuracy"])
    tf.logging.info("***** Final evaluation result *****")
    tf.logging.info("Best acc: {:.3f}\n\n".format(best_acc))