Example #1
    def train_and_eval(self, local_training=True):
        """Launches train and evaluation jobs locally."""
        # TODO: Support distributed training and evaluation.
        if not local_training:
            raise ValueError("Non-local training is not supported yet.")

        train_spec, eval_spec, _ = self._train_eval_specs()
        tf_estimator.train_and_evaluate(self._estimator, train_spec, eval_spec)
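All of the examples on this page share the same contract: build an Estimator, a TrainSpec, and an EvalSpec, then hand them to train_and_evaluate. The following self-contained toy sketch (a hypothetical one-weight linear model, not taken from any project above) shows that minimal wiring:

import tensorflow as tf
from tensorflow import estimator as tf_estimator  # same alias as the examples


def input_fn():
    # Tiny in-memory dataset; real code reads TFRecords, libsvm files, etc.
    ds = tf.data.Dataset.from_tensor_slices(
        ({"x": [[1.0], [2.0], [3.0], [4.0]]}, [[2.0], [4.0], [6.0], [8.0]]))
    return ds.repeat().batch(2)


def model_fn(features, labels, mode):
    # One-weight linear model: y = w * x.
    w = tf.compat.v1.get_variable("w", [1, 1])
    predictions = tf.matmul(features["x"], w)
    loss = tf.reduce_mean(tf.square(predictions - labels))
    train_op = tf.compat.v1.train.GradientDescentOptimizer(0.01).minimize(
        loss, global_step=tf.compat.v1.train.get_global_step())
    return tf_estimator.EstimatorSpec(
        mode, loss=loss, train_op=train_op, predictions=predictions)


toy_estimator = tf_estimator.Estimator(model_fn=model_fn)
train_spec = tf_estimator.TrainSpec(input_fn=input_fn, max_steps=100)
eval_spec = tf_estimator.EvalSpec(input_fn=input_fn, steps=5,
                                  start_delay_secs=0, throttle_secs=0)
tf_estimator.train_and_evaluate(toy_estimator, train_spec, eval_spec)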
Example #2
def train_and_eval():
    """Train and Evaluate."""
    train_input_fn = make_input_fn(FLAGS.train_path, FLAGS.batch_size)
    eval_input_fn = make_input_fn(FLAGS.eval_path,
                                  FLAGS.batch_size,
                                  randomize_input=False,
                                  num_epochs=1)

    optimizer = tf.compat.v1.train.AdagradOptimizer(
        learning_rate=FLAGS.learning_rate)

    def _train_op_fn(loss):
        """Defines the train op used in the ranking head."""
        # Group the minimize op with UPDATE_OPS so that ops registered by
        # layers such as batch normalization run on every training step.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)
        minimize_op = optimizer.minimize(
            loss=loss, global_step=tf.compat.v1.train.get_global_step())
        train_op = tf.group([minimize_op, update_ops])
        return train_op

    ranking_head = tfr.head.create_ranking_head(
        loss_fn=tfr.losses.make_loss_fn(
            FLAGS.loss, weights_feature_name=FLAGS.weights_feature_name),
        eval_metric_fns=eval_metric_fns(),
        train_op_fn=_train_op_fn)

    estimator = tf_estimator.Estimator(
        model_fn=tfr.model.make_groupwise_ranking_fn(
            group_score_fn=make_score_fn(),
            group_size=FLAGS.group_size,
            transform_fn=make_transform_fn(),
            ranking_head=ranking_head),
        model_dir=FLAGS.model_dir,
        config=tf_estimator.RunConfig(save_checkpoints_steps=1000))

    train_spec = tf_estimator.TrainSpec(input_fn=train_input_fn,
                                        max_steps=FLAGS.num_train_steps)

    exporters = tf_estimator.LatestExporter(
        "saved_model_exporter",
        serving_input_receiver_fn=make_serving_input_fn())

    eval_spec = tf_estimator.EvalSpec(name="eval",
                                      input_fn=eval_input_fn,
                                      steps=1,
                                      exporters=exporters,
                                      start_delay_secs=0,
                                      throttle_secs=15)

    # Train and validate.
    tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)
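make_input_fn, eval_metric_fns, make_score_fn, make_transform_fn, and make_serving_input_fn are defined elsewhere in this project. As an illustration only, here is a plausible make_input_fn for ELWC-formatted TFRecords modeled on the public TF-Ranking examples; the _LABEL_FEATURE name, the context_feature_columns()/example_feature_columns() helpers, and FLAGS.list_size are assumptions, not the project's actual code:

_LABEL_FEATURE = "relevance"  # hypothetical label feature name


def make_input_fn(file_pattern, batch_size,
                  randomize_input=True, num_epochs=None):
    """Returns an input_fn over ELWC TFRecords (sketch, not the original)."""
    context_feature_spec = tf.feature_column.make_parse_example_spec(
        context_feature_columns().values())
    example_feature_spec = tf.feature_column.make_parse_example_spec(
        list(example_feature_columns().values()) +
        [tf.feature_column.numeric_column(
            _LABEL_FEATURE, dtype=tf.int64, default_value=-1)])

    def _input_fn():
        dataset = tfr.data.build_ranking_dataset(
            file_pattern=file_pattern,
            data_format=tfr.data.ELWC,
            batch_size=batch_size,
            list_size=FLAGS.list_size,
            context_feature_spec=context_feature_spec,
            example_feature_spec=example_feature_spec,
            reader=tf.data.TFRecordDataset,
            shuffle=randomize_input,
            num_epochs=num_epochs)
        features = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
        # The label is parsed as [batch, list_size, 1]; squeeze and cast it.
        label = tf.cast(
            tf.squeeze(features.pop(_LABEL_FEATURE), axis=2), tf.float32)
        return features, label

    return _input_fn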
Example #3
File: task.py  Project: zhoufek/beam
def train_and_maybe_evaluate(hparams):
  """Run the training and evaluate using the high level API.

  Args:
    hparams: Holds hyperparameters used to train the model as name/value pairs.

  Returns:
    The estimator that was used for training (and maybe eval).
  """
  schema = taxi.read_schema(hparams.schema_file)
  tf_transform_output = tft.TFTransformOutput(hparams.tf_transform_dir)

  train_input = lambda: model.input_fn(
      hparams.train_files, tf_transform_output, batch_size=TRAIN_BATCH_SIZE)

  eval_input = lambda: model.input_fn(
      hparams.eval_files, tf_transform_output, batch_size=EVAL_BATCH_SIZE)

  train_spec = tf_estimator.TrainSpec(
      train_input, max_steps=hparams.train_steps)

  serving_receiver_fn = lambda: model.example_serving_receiver_fn(
      tf_transform_output, schema)

  exporter = tf_estimator.FinalExporter('chicago-taxi', serving_receiver_fn)
  eval_spec = tf_estimator.EvalSpec(
      eval_input,
      steps=hparams.eval_steps,
      exporters=[exporter],
      name='chicago-taxi-eval')

  run_config = tf_estimator.RunConfig(
      save_checkpoints_steps=999, keep_checkpoint_max=1)

  serving_model_dir = os.path.join(hparams.output_dir, SERVING_MODEL_DIR)
  run_config = run_config.replace(model_dir=serving_model_dir)

  estimator = model.build_estimator(
      tf_transform_output,
      # Construct layer sizes with exponential decay.
      hidden_units=[
          max(2, int(FIRST_DNN_LAYER_SIZE * DNN_DECAY_FACTOR**i))
          for i in range(NUM_DNN_LAYERS)
      ],
      config=run_config)

  tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)

  return estimator
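The hidden_units list comprehension above shrinks each DNN layer geometrically; a worked illustration with assumed constants (the real values are defined elsewhere in the project):

FIRST_DNN_LAYER_SIZE = 100  # assumed for illustration
NUM_DNN_LAYERS = 4          # assumed for illustration
DNN_DECAY_FACTOR = 0.7      # assumed for illustration

hidden_units = [
    max(2, int(FIRST_DNN_LAYER_SIZE * DNN_DECAY_FACTOR**i))
    for i in range(NUM_DNN_LAYERS)
]
print(hidden_units)  # [100, 70, 48, 34]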
Example #4
    def test_model_to_estimator_missing_custom_objects(self):
        keras_model = model.create_keras_model(network=self._network,
                                               loss=self._loss,
                                               metrics=self._eval_metrics,
                                               optimizer=self._optimizer,
                                               size_feature_name=_SIZE)
        estimator = estimator_lib.model_to_estimator(model=keras_model,
                                                     config=self._config,
                                                     custom_objects=None)
        self.assertIsInstance(estimator, tf_compat_v1_estimator.Estimator)

        # Train and export model.
        train_spec = tf_estimator.TrainSpec(input_fn=self._make_input_fn(),
                                            max_steps=1)
        eval_spec = tf_estimator.EvalSpec(name='eval',
                                          input_fn=self._make_input_fn(),
                                          steps=10)

        with self.assertRaises(AttributeError):
            tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)
Example #5
    def test_experiment(self, listwise_inference):
        tmp_dir = self.create_tempdir().full_path
        data_file = os.path.join(tmp_dir, "elwc.tfrecord")
        with tf.io.TFRecordWriter(data_file) as writer:
            for _ in range(10):
                writer.write(ELWC_PROTO.SerializeToString())

        estimator = tfr_estimator.make_gam_ranking_estimator(
            example_feature_columns=example_feature_columns(),
            example_hidden_units=["2", "2"],
            optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=0.05),
            learning_rate=0.05,
            loss="softmax_loss",
            use_batch_norm=False,
            model_dir=None)
        train_spec = tf_estimator.TrainSpec(input_fn=_inner_input_fn,
                                            max_steps=1)
        eval_spec = tf_estimator.EvalSpec(name="eval",
                                          input_fn=_inner_input_fn,
                                          steps=10)
        tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)
Example #6
def train():
    encoders = get_encoders()

    generator = get_generator(encoders)

    checkpoint_dir = path.join(config.CHECKPOINT_DIR, "tf")
    estimator = K.estimator.model_to_estimator(
        keras_model=generator,
        model_dir=checkpoint_dir,
        checkpoint_format="saver"
    )  # TODO: use 'checkpoint' once object-based checkpoints supported

    def input_fn():
        dataset = get_dataset(encoders)
        return dataset.batch(config.BATCH_SIZE)

    train_spec = E.TrainSpec(input_fn=input_fn)
    eval_spec = E.EvalSpec(
        input_fn=input_fn,
        hooks=[E.CheckpointSaverHook(checkpoint_dir, save_steps=1000)])

    E.train_and_evaluate(estimator, train_spec, eval_spec)
Example #7
def train_and_eval():
    """Train and Evaluate."""
    train_input_fn = make_input_fn(FLAGS.train_path, FLAGS.batch_size)
    eval_input_fn = make_input_fn(FLAGS.eval_path,
                                  FLAGS.batch_size,
                                  randomize_input=False,
                                  num_epochs=1)

    estimator = get_estimator()
    train_spec = tf_estimator.TrainSpec(input_fn=train_input_fn,
                                        max_steps=FLAGS.num_train_steps)
    exporters = tf_estimator.LatestExporter(
        "saved_model_exporter",
        serving_input_receiver_fn=make_serving_input_fn())
    eval_spec = tf_estimator.EvalSpec(name="eval",
                                      input_fn=eval_input_fn,
                                      steps=1,
                                      exporters=exporters,
                                      start_delay_secs=0,
                                      throttle_secs=15)

    # Train and validate.
    tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)
Example #8
    def __call__(self):
        train_spec = estimator.TrainSpec(input_fn=lambda: self._input_fn(),
                                         max_steps=self._iter_num)
        eval_spec = estimator.EvalSpec(input_fn=lambda: self._input_fn())
        return estimator.train_and_evaluate(self._estimator,
                                            train_spec=train_spec,
                                            eval_spec=eval_spec)
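The EvalSpec above relies entirely on defaults. For reference, a sketch of the equivalent explicit call using the default values documented for tf.estimator.EvalSpec (input_fn stands in for the lambda above):

eval_spec = estimator.EvalSpec(
    input_fn=input_fn,
    steps=100,             # default: evaluate for 100 batches
    start_delay_secs=120,  # default: wait 120 seconds before the first eval
    throttle_secs=600)     # default: at most one evaluation every 600 seconds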
Example #9
def train_and_eval():
    """Train and Evaluate."""

    features, labels = load_libsvm_data(FLAGS.train_path, FLAGS.list_size)
    train_input_fn, train_hook = get_train_inputs(features, labels,
                                                  FLAGS.train_batch_size)

    features_vali, labels_vali = load_libsvm_data(FLAGS.vali_path,
                                                  FLAGS.list_size)
    vali_input_fn, vali_hook = get_eval_inputs(features_vali, labels_vali)

    features_test, labels_test = load_libsvm_data(FLAGS.test_path,
                                                  FLAGS.list_size)
    test_input_fn, test_hook = get_eval_inputs(features_test, labels_test)

    optimizer = tf.compat.v1.train.AdagradOptimizer(
        learning_rate=FLAGS.learning_rate)

    def _train_op_fn(loss):
        """Defines train op used in ranking head."""
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)
        minimize_op = optimizer.minimize(
            loss=loss, global_step=tf.compat.v1.train.get_global_step())
        train_op = tf.group([minimize_op, update_ops])
        return train_op

    if _use_multi_head():
        primary_head = tfr.head.create_ranking_head(
            loss_fn=tfr.losses.make_loss_fn(FLAGS.loss),
            eval_metric_fns=get_eval_metric_fns(),
            train_op_fn=_train_op_fn,
            name=_PRIMARY_HEAD)
        secondary_head = tfr.head.create_ranking_head(
            loss_fn=tfr.losses.make_loss_fn(FLAGS.secondary_loss),
            eval_metric_fns=get_eval_metric_fns(),
            train_op_fn=_train_op_fn,
            name=_SECONDARY_HEAD)
        ranking_head = tfr.head.create_multi_ranking_head(
            [primary_head, secondary_head], [1.0, FLAGS.secondary_loss_weight])
    else:
        ranking_head = tfr.head.create_ranking_head(
            loss_fn=tfr.losses.make_loss_fn(FLAGS.loss),
            eval_metric_fns=get_eval_metric_fns(),
            train_op_fn=_train_op_fn)

    estimator = tf_estimator.Estimator(
        model_fn=tfr.model.make_groupwise_ranking_fn(
            group_score_fn=make_score_fn(),
            group_size=FLAGS.group_size,
            transform_fn=make_transform_fn(),
            ranking_head=ranking_head),
        config=tf_estimator.RunConfig(FLAGS.output_dir,
                                      save_checkpoints_steps=1000))

    train_spec = tf_estimator.TrainSpec(input_fn=train_input_fn,
                                        hooks=[train_hook],
                                        max_steps=FLAGS.num_train_steps)
    # Export model to accept tf.Example when group_size = 1.
    if FLAGS.group_size == 1:
        vali_spec = tf_estimator.EvalSpec(
            input_fn=vali_input_fn,
            hooks=[vali_hook],
            steps=1,
            exporters=tf_estimator.LatestExporter(
                "latest_exporter",
                serving_input_receiver_fn=make_serving_input_fn()),
            start_delay_secs=0,
            throttle_secs=30)
    else:
        vali_spec = tf_estimator.EvalSpec(input_fn=vali_input_fn,
                                          hooks=[vali_hook],
                                          steps=1,
                                          start_delay_secs=0,
                                          throttle_secs=30)

    # Train and validate
    tf_estimator.train_and_evaluate(estimator, train_spec, vali_spec)

    # Evaluate on the test data.
    estimator.evaluate(input_fn=test_input_fn, hooks=[test_hook])
Example #10
def train_bert_multitask(
        problem='weibo_ner',
        num_gpus=1,
        num_epochs=10,
        model_dir='',
        params=None,
        problem_type_dict={},
        processing_fn_dict={},
        model=None):
    """Train Multi-task Bert model

    About problem: 
        There are two types of chaining operations can be used to chain problems.
            - `&`. If two problems have the same inputs, they can be chained using `&`. Problems chained by `&` will be trained at the same time.
            - `|`. If two problems don't have the same inputs, they need to be chained using `|`. Problems chained by `|` will be sampled to train at every instance.

        For example, `cws|NER|weibo_ner&weibo_cws`, one problem will be sampled at each turn, say `weibo_ner&weibo_cws`, then `weibo_ner` and `weibo_cws` will trained for this turn together. Therefore, in a particular batch, some tasks might not be sampled, and their loss could be 0 in this batch.

    About problem_type_dict and processing_fn_dict:
        If the problem is not predefined, you need to tell the model what's the new problem's problem_type
        and preprocessing function.
            For example, a new problem: fake_classification
            problem_type_dict = {'fake_classification': 'cls'}
            processing_fn_dict = {'fake_classification': lambda: return ...}

        Available problem type:
            cls: Classification
            seq_tag: Sequence Labeling
            seq2seq_tag: Sequence to Sequence tag problem
            seq2seq_text: Sequence to Sequence text generation problem

        Preprocessing function example:
        Please refer to https://github.com/JayYip/bert-multitask-learning/blob/master/README.md

    Keyword Arguments:
        problem {str} -- Problems to train (default: {'weibo_ner'})
        num_gpus {int} -- Number of GPU to use (default: {1})
        num_epochs {int} -- Number of epochs to train (default: {10})
        model_dir {str} -- model dir (default: {''})
        params {BaseParams} -- Params to define training and models (default: {DynamicBatchSizeParams()})
        problem_type_dict {dict} -- Key: problem name, value: problem type (default: {{}})
        processing_fn_dict {dict} -- Key: problem name, value: problem data preprocessing fn (default: {{}})
    """
    if params is None:
        params = DynamicBatchSizeParams()

    if not os.path.exists('models'):
        os.mkdir('models')

    if model_dir:
        base_dir, dir_name = os.path.split(model_dir)
    else:
        base_dir, dir_name = None, None
    params.train_epoch = num_epochs
    # add new problem to params if problem_type_dict and processing_fn_dict provided
    if processing_fn_dict:
        for new_problem, new_problem_processing_fn in processing_fn_dict.items():
            print('Adding new problem {0}, problem type: {1}'.format(
                new_problem, problem_type_dict[new_problem]))
            params.add_problem(
                problem_name=new_problem, problem_type=problem_type_dict[new_problem], processing_fn=new_problem_processing_fn)
    params.assign_problem(problem, gpu=int(num_gpus),
                          base_dir=base_dir, dir_name=dir_name)
    params.to_json()

    estimator = _create_estimator(
        num_gpus=num_gpus, params=params, model=model)

    train_hook = RestoreCheckpointHook(params)

    def train_input_fn(): return train_eval_input_fn(params)
    def eval_input_fn(): return train_eval_input_fn(params, mode=EVAL)

    train_spec = TrainSpec(
        input_fn=train_input_fn, max_steps=params.train_steps, hooks=[train_hook])
    eval_spec = EvalSpec(
        eval_input_fn, throttle_secs=params.eval_throttle_secs)

    # estimator.train(
    #     train_input_fn, max_steps=params.train_steps, hooks=[train_hook])
    train_and_evaluate(estimator, train_spec, eval_spec)
    return estimator
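A hedged usage sketch for train_bert_multitask, following the docstring above; the chained problem string, model_dir, and the preprocessing callable are hypothetical placeholders, and the real preprocessing-fn contract is in the README linked in the docstring:

def fake_classification_fn(params, mode):
    # Hypothetical preprocessing fn for the new problem; the actual
    # signature and contract are defined by bert-multitask-learning.
    ...

estimator = train_bert_multitask(
    problem='weibo_ner&weibo_cws|fake_classification',
    num_gpus=1,
    num_epochs=5,
    model_dir='models/demo',  # hypothetical
    problem_type_dict={'fake_classification': 'cls'},
    processing_fn_dict={'fake_classification': fake_classification_fn})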
Example #11
pprint(params)
print('placing model artifacts in {}'.format(model_dir))

# define model and data
model = MNISTModel()
mnist_data = Mnist(params['batch_size'])

run_config = estimator.RunConfig(
    save_checkpoints_steps=params['steps_per_epoch'],
    save_summary_steps=200,
    keep_checkpoint_max=10)

mnist_estimator = estimator.Estimator(model_dir=model_dir,
                                      model_fn=model.model_fn,
                                      params=params,
                                      config=run_config)

# training/evaluation specs for run
train_spec = estimator.TrainSpec(
    input_fn=mnist_data.build_training_data,
    max_steps=params['total_steps_train'],
)

eval_spec = estimator.EvalSpec(input_fn=mnist_data.build_validation_data,
                               steps=None,
                               throttle_secs=params['throttle_eval'],
                               start_delay_secs=0)

# run train and evaluate
estimator.train_and_evaluate(mnist_estimator, train_spec, eval_spec)
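The snippet above reads four keys from params; a hypothetical dict with illustrative values makes the expected shape explicit:

params = {
    'batch_size': 128,
    'steps_per_epoch': 469,     # roughly 60000 MNIST training images / 128
    'total_steps_train': 4690,  # 10 epochs of 469 steps
    'throttle_eval': 60,        # seconds between evaluation runs
}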
Example #12
    def test_model_to_estimator(self, weights_feature_name, serving_default):
        keras_model = model.create_keras_model(network=self._network,
                                               loss=self._loss,
                                               metrics=self._eval_metrics,
                                               optimizer=self._optimizer,
                                               size_feature_name=_SIZE)
        estimator = estimator_lib.model_to_estimator(
            model=keras_model,
            config=self._config,
            weights_feature_name=weights_feature_name,
            custom_objects=self._custom_objects,
            serving_default=serving_default)
        self.assertIsInstance(estimator, tf_compat_v1_estimator.Estimator)

        # Train and export model.
        train_spec = tf_estimator.TrainSpec(
            input_fn=self._make_input_fn(weights_feature_name), max_steps=1)
        eval_spec = tf_estimator.EvalSpec(
            name='eval',
            input_fn=self._make_input_fn(weights_feature_name),
            steps=10)
        tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)

        context_feature_spec = tf.feature_column.make_parse_example_spec(
            self._context_feature_columns.values())
        example_feature_spec = tf.feature_column.make_parse_example_spec(
            self._example_feature_columns.values())

        def _make_serving_input_fn(serving_default):
            if serving_default == 'predict':
                return data.build_ranking_serving_input_receiver_fn(
                    data.ELWC,
                    context_feature_spec=context_feature_spec,
                    example_feature_spec=example_feature_spec,
                    size_feature_name=_SIZE)
            else:

                def pointwise_serving_fn():
                    serialized = tf.compat.v1.placeholder(
                        dtype=tf.string,
                        shape=[None],
                        name='input_ranking_tensor')
                    receiver_tensors = {'input_ranking_data': serialized}
                    features = data.parse_from_tf_example(
                        serialized,
                        context_feature_spec=context_feature_spec,
                        example_feature_spec=example_feature_spec,
                        size_feature_name=_SIZE)
                    return tf_estimator.export.ServingInputReceiver(
                        features, receiver_tensors)

                return pointwise_serving_fn

        serving_input_receiver_fn = _make_serving_input_fn(serving_default)
        export_dir = os.path.join(tf.compat.v1.test.get_temp_dir(), 'export')
        estimator.export_saved_model(export_dir, serving_input_receiver_fn)

        # Confirm model ran and created checkpoints and saved model.
        final_ckpt_path = os.path.join(estimator.model_dir,
                                       'model.ckpt-1.meta')
        self.assertTrue(tf.io.gfile.exists(final_ckpt_path))

        saved_model_pb = os.path.join(export_dir,
                                      tf.io.gfile.listdir(export_dir)[0],
                                      'saved_model.pb')
        self.assertTrue(tf.io.gfile.exists(saved_model_pb))
Example #13
def train_and_eval(params,
                   model_fn,
                   input_fn,
                   keep_checkpoint_every_n_hours=0.5,
                   save_checkpoints_secs=100,
                   eval_steps=0,
                   eval_start_delay_secs=10,
                   eval_throttle_secs=100,
                   save_summary_steps=50):
    """Trains and evaluates our model.

  Supports local and distributed training.

  Args:
    params: ConfigParams class with model training and network parameters.
    model_fn: A func with prototype model_fn(features, labels, mode, hparams).
    input_fn: A input function for the tf.estimator.Estimator.
    keep_checkpoint_every_n_hours: Number of hours between each checkpoint to be
      saved.
    save_checkpoints_secs: Save checkpoints every this many seconds.
    eval_steps: Number of steps to evaluate model; 0 for one epoch.
    eval_start_delay_secs: Start evaluating after waiting for this many seconds.
    eval_throttle_secs: Do not re-evaluate unless the last evaluation was
      started at least this many seconds ago
    save_summary_steps: Save summaries every this many steps.
  """

    mparams = params.model_params

    run_config = tf_estimator.RunConfig(
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
        save_checkpoints_secs=save_checkpoints_secs,
        save_summary_steps=save_summary_steps)

    if run_config.model_dir:
        params.model_dir = run_config.model_dir
    print('\nCreating estimator with model dir %s' % params.model_dir)
    estimator = tf_estimator.Estimator(model_fn=model_fn,
                                       model_dir=params.model_dir,
                                       config=run_config,
                                       params=params)

    print('\nCreating train_spec')
    train_spec = tf_estimator.TrainSpec(input_fn=input_fn(params,
                                                          split='train'),
                                        max_steps=params.steps)

    print('\nCreating eval_spec')

    def serving_input_receiver_fn():
        """Serving input_fn that builds features from placeholders.

    Returns:
      A tf.estimator.export.ServingInputReceiver.
    """
        modelx = mparams.modelx
        modely = mparams.modely
        offsets = keras.Input(shape=(3, ), name='offsets', dtype='float32')
        hom = keras.Input(shape=(3, 3), name='hom', dtype='float32')
        to_world = keras.Input(shape=(4, 4),
                               name='to_world_L',
                               dtype='float32')
        img_l = keras.Input(shape=(modely, modelx, 3),
                            name='img_L',
                            dtype='float32')
        img_r = keras.Input(shape=(modely, modelx, 3),
                            name='img_R',
                            dtype='float32')
        features = {
            'img_L': img_l,
            'img_R': img_r,
            'to_world_L': to_world,
            'offsets': offsets,
            'hom': hom
        }
        return tf_estimator.export.build_raw_serving_input_receiver_fn(
            features)

    class SaveModel(tf_estimator.SessionRunHook):
        """Saves a model in SavedModel format."""
        def __init__(self, estimator, output_dir):
            self.output_dir = output_dir
            self.estimator = estimator
            self.save_num = 0

        def begin(self):
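            # begin() runs once when the evaluation session is created; export
            # a SavedModel at most once every 4000 global steps, keyed off the
            # step number embedded in the latest checkpoint name.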
            ckpt = self.estimator.latest_checkpoint()
            print('Latest checkpoint in hook:', ckpt)
            ckpt_num_str = ckpt.split('.ckpt-')[1]
            if (int(ckpt_num_str) - self.save_num) > 4000:
                fname = os.path.join(self.output_dir,
                                     'saved_model-' + ckpt_num_str)
                print('**** Saving model in train hook: %s' % fname)
                self.estimator.export_saved_model(fname,
                                                  serving_input_receiver_fn())
                self.save_num = int(ckpt_num_str)

    saver_hook = SaveModel(estimator, params.model_dir)

    if eval_steps == 0:
        eval_steps = None
    eval_spec = tf_estimator.EvalSpec(input_fn=input_fn(params, split='val'),
                                      steps=eval_steps,
                                      hooks=[saver_hook],
                                      start_delay_secs=eval_start_delay_secs,
                                      throttle_secs=eval_throttle_secs)

    if run_config.is_chief:
        outdir = params.model_dir
        if outdir is not None:
            print('Writing params to %s' % outdir)
            os.makedirs(outdir, exist_ok=True)
            params.write_yaml(os.path.join(outdir, 'params.yaml'))

    print('\nRunning estimator')
    tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)

    print('\nSaving last model')
    ckpt = estimator.latest_checkpoint()
    print('Last checkpoint:', ckpt)
    ckpt_num_str = ckpt.split('.ckpt-')[1]
    fname = os.path.join(params.model_dir, 'saved_model-' + ckpt_num_str)
    print('**** Saving last model: %s' % fname)
    estimator.export_saved_model(fname, serving_input_receiver_fn())
Example #14
def main(_):
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

  # Load emotion categories
  with open(FLAGS.emotion_file, "r") as f:
    all_emotions = f.read().splitlines()
    if FLAGS.add_neutral:
      all_emotions = all_emotions + ["neutral"]
    idx2emotion = {i: e for i, e in enumerate(all_emotions)}
  num_labels = len(all_emotions)
  print("%d labels" % num_labels)
  print("Multilabel: %r" % FLAGS.multilabel)

  tf.logging.set_verbosity(tf.logging.INFO)

  tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                FLAGS.init_checkpoint)

  if not FLAGS.do_train and not FLAGS.do_predict:
    raise ValueError("At least one of `do_train` or `do_predict' must be True.")

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  if FLAGS.max_seq_length > bert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, bert_config.max_position_embeddings))

  tf.gfile.MakeDirs(FLAGS.output_dir)

  processor = DataProcessor(num_labels, FLAGS.data_dir)  # set up preprocessor

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  run_config = tf_estimator.RunConfig(
      model_dir=FLAGS.output_dir,
      save_summary_steps=FLAGS.save_summary_steps,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max)

  train_examples = None
  num_train_steps = None
  num_warmup_steps = None

  if FLAGS.do_train:
    train_examples = processor.get_examples("train", FLAGS.train_fname)
    eval_examples = processor.get_examples("dev", FLAGS.dev_fname)
    num_eval_examples = len(eval_examples)
    num_train_steps = int(
        len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    params = {
        "num_labels": num_labels,
        "learning_rate": FLAGS.learning_rate,
        "num_train_epochs": FLAGS.num_train_epochs,
        "warmup_proportion": FLAGS.warmup_proportion,
        "batch_size": FLAGS.train_batch_size,
        "num_train_examples": len(train_examples),
        "num_eval_examples": num_eval_examples,
        "data_dir": FLAGS.data_dir,
        "output_dir": FLAGS.output_dir,
        "train_fname": FLAGS.train_fname,
        "dev_fname": FLAGS.dev_fname,
        "test_fname": FLAGS.test_fname
    }
    with open(os.path.join(FLAGS.output_dir, "config.json"), "w") as f:
      json.dump(params, f)

  model_fn = model_fn_builder(
      bert_config=bert_config,
      num_labels=num_labels,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      multilabel=FLAGS.multilabel,
      idx2emotion=idx2emotion)

  estimator = tf_estimator.Estimator(
      model_fn=model_fn,
      config=run_config,
      params={"batch_size": FLAGS.train_batch_size})

  if FLAGS.do_train:
    train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
    file_based_convert_examples_to_features(train_examples,
                                            FLAGS.max_seq_length, tokenizer,
                                            train_file)
    eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
    file_based_convert_examples_to_features(eval_examples, FLAGS.max_seq_length,
                                            tokenizer, eval_file)

    tf.logging.info("***** Running training and evaluation *****")
    tf.logging.info("  Num train examples = %d", len(train_examples))
    tf.logging.info("  Num eval examples = %d", num_eval_examples)
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num training steps = %d", num_train_steps)
    train_input_fn = file_based_input_fn_builder(
        input_file=train_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True,
        num_labels=num_labels)
    train_spec = tf_estimator.TrainSpec(
        input_fn=train_input_fn, max_steps=num_train_steps)
    eval_input_fn = file_based_input_fn_builder(
        input_file=eval_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=False,
        num_labels=num_labels)
    eval_spec = tf_estimator.EvalSpec(
        input_fn=eval_input_fn,
        steps=FLAGS.eval_steps,
        start_delay_secs=0,
        throttle_secs=1000)

    tf_estimator.train_and_evaluate(
        estimator, train_spec=train_spec, eval_spec=eval_spec)

  if FLAGS.calculate_metrics:

    # Setting the parameter to "dev" ensures that we get labels for the examples
    eval_examples = processor.get_examples("dev", FLAGS.test_fname)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Num eval examples = %d", len(eval_examples))
    eval_file = os.path.join(FLAGS.output_dir, FLAGS.test_fname + ".tf_record")
    file_based_convert_examples_to_features(eval_examples, FLAGS.max_seq_length,
                                            tokenizer, eval_file)
    eval_input_fn = file_based_input_fn_builder(
        input_file=eval_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=False,
        num_labels=num_labels)

    result = estimator.evaluate(input_fn=eval_input_fn, steps=None)
    output_eval_file = os.path.join(FLAGS.output_dir,
                                    FLAGS.test_fname + ".eval_results.txt")
    with tf.gfile.GFile(output_eval_file, "w") as writer:
      tf.logging.info("***** Eval results *****")
      for key in sorted(result.keys()):
        tf.logging.info("  %s = %s", key, str(result[key]))
        writer.write("%s = %s\n" % (key, str(result[key])))

  if FLAGS.do_predict:
    predict_examples = processor.get_examples("test", FLAGS.test_fname)
    num_actual_predict_examples = len(predict_examples)

    predict_file = os.path.join(FLAGS.output_dir,
                                FLAGS.test_fname + ".tf_record")
    file_based_convert_examples_to_features(predict_examples,
                                            FLAGS.max_seq_length, tokenizer,
                                            predict_file)

    tf.logging.info("***** Running prediction*****")
    tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                    len(predict_examples), num_actual_predict_examples,
                    len(predict_examples) - num_actual_predict_examples)
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)

    predict_input_fn = file_based_input_fn_builder(
        input_file=predict_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=False,
        num_labels=num_labels)

    result = estimator.predict(input_fn=predict_input_fn)

    output_predict_file = os.path.join(FLAGS.output_dir,
                                       FLAGS.test_fname + ".predictions.tsv")
    output_labels = os.path.join(FLAGS.output_dir,
                                 FLAGS.test_fname + ".label_predictions.tsv")

    with tf.gfile.GFile(output_predict_file, "w") as writer:
      with tf.gfile.GFile(output_labels, "w") as writer2:
        writer.write("\t".join(all_emotions) + "\n")
        writer2.write("\t".join([
            "text", "emotion_1", "prob_1", "emotion_2", "prob_2", "emotion_3",
            "prob_3"
        ]) + "\n")
        tf.logging.info("***** Predict results *****")
        num_written_lines = 0
        for (i, prediction) in enumerate(result):
          probabilities = prediction["probabilities"]
          if i >= num_actual_predict_examples:
            break
          output_line = "\t".join(
              str(class_probability)
              for class_probability in probabilities) + "\n"
          sorted_idx = np.argsort(-probabilities)
          top_3_emotion = [idx2emotion[idx] for idx in sorted_idx[:3]]
          top_3_prob = [probabilities[idx] for idx in sorted_idx[:3]]
          pred_line = []
          for emotion, prob in zip(top_3_emotion, top_3_prob):
            if prob >= FLAGS.pred_cutoff:
              pred_line.extend([emotion, "%.4f" % prob])
            else:
              pred_line.extend(["", ""])
          writer.write(output_line)
          writer2.write(predict_examples[i].text + "\t" + "\t".join(pred_line) +
                        "\n")
          num_written_lines += 1
    assert num_written_lines == num_actual_predict_examples