# Imports assumed from the TAPAS repository layout; FLAGS are absl flags
# defined in these modules.
import functools
import os
import time

from absl import app
from absl import flags

import numpy as np
import tensorflow.compat.v1 as tf

from tapas.models import tapas_pretraining_model
from tapas.models.bert import modeling
from tapas.utils import attention_utils
from tapas.utils import experiment_utils

FLAGS = flags.FLAGS


# Test-case wrapper restored so the method below parses; the class name is
# assumed, not taken from the original file.
class TapasPretrainingModelTest(tf.test.TestCase):

    def _create_estimator(self, params):
        # Fix the numpy seed so the test's randomly generated inputs are
        # deterministic across runs.
        tf.logging.info("Setting random seed to {}".format(42))
        np.random.seed(42)

        # Small bert model for testing.
        bert_config = modeling.BertConfig.from_dict({
            "vocab_size": 10,
            "type_vocab_size": [3, 256, 256, 2, 256, 256, 10],
            "num_hidden_layers": 2,
            "num_attention_heads": 2,
            "hidden_size": 128,
            "intermediate_size": 512,
        })
        model_fn = tapas_pretraining_model.model_fn_builder(
            bert_config=bert_config,
            init_checkpoint=params["init_checkpoint"],
            learning_rate=params["learning_rate"],
            num_train_steps=params["num_train_steps"],
            num_warmup_steps=params["num_warmup_steps"],
            use_tpu=params["use_tpu"])

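        # With both save_*_steps equal to num_train_steps, summaries and
        # checkpoints are written only once, at the end of the short test run.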
        estimator = tf.estimator.tpu.TPUEstimator(
            use_tpu=params["use_tpu"],
            model_fn=model_fn,
            config=tf.estimator.tpu.RunConfig(
                model_dir=self.get_temp_dir(),
                save_summary_steps=params["num_train_steps"],
                save_checkpoints_steps=params["num_train_steps"]),
            train_batch_size=params["batch_size"],
            predict_batch_size=params["batch_size"],
            eval_batch_size=params["batch_size"])

        return estimator
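

# A minimal sketch of exercising the helper above from a test method; the
# hyperparameter values below are illustrative, not taken from the original
# test:
#
#   estimator = self._create_estimator({
#       "init_checkpoint": None,
#       "learning_rate": 5e-5,
#       "num_train_steps": 10,
#       "num_warmup_steps": 2,
#       "use_tpu": False,
#       "batch_size": 8,
#   })
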
def main(_):
    bert_config = experiment_utils.bert_config_from_flags()
    model_fn = tapas_pretraining_model.model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=experiment_utils.num_train_steps(),
        num_warmup_steps=experiment_utils.num_warmup_steps(),
        use_tpu=FLAGS.use_tpu,
        restrict_attention_mode=attention_utils.RestrictAttentionMode(
            FLAGS.restrict_attention_mode),
        restrict_attention_bucket_size=FLAGS.restrict_attention_bucket_size,
        restrict_attention_header_size=FLAGS.restrict_attention_header_size,
        restrict_attention_row_heads_ratio=(
            FLAGS.restrict_attention_row_heads_ratio),
        disabled_features=FLAGS.disabled_features,
        disable_position_embeddings=FLAGS.disable_position_embeddings,
        reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
        proj_value_length=(FLAGS.proj_value_length
                           if FLAGS.proj_value_length > 0 else None),
        attention_bias_disabled=FLAGS.attention_bias_disabled,
        attention_bias_use_relative_scalar_only=(
            FLAGS.attention_bias_use_relative_scalar_only),
    )
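    # experiment_utils.build_estimator wraps model_fn in a TPUEstimator
    # configured from the shared command-line flags (model_dir, batch sizes,
    # TPU settings).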
    estimator = experiment_utils.build_estimator(model_fn)

    if FLAGS.do_train:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        bert_config.to_json_file(
            os.path.join(FLAGS.model_dir, "bert_config.json"))
        train_input_fn = functools.partial(
            tapas_pretraining_model.input_fn,
            name="train",
            file_patterns=FLAGS.input_file_train,
            data_format=FLAGS.data_format,
            compression_type=FLAGS.compression_type,
            is_training=True,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq)
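        # Blocks until `max_steps` global steps have been trained.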
        estimator.train(input_fn=train_input_fn,
                        max_steps=experiment_utils.num_train_steps())

    if FLAGS.do_eval:
        eval_input_fn = functools.partial(
            tapas_pretraining_model.input_fn,
            name="eval",
            file_patterns=FLAGS.input_file_eval,
            data_format=FLAGS.data_format,
            compression_type=FLAGS.compression_type,
            is_training=False,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq)

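        # Continuous evaluation: poll for a new checkpoint, evaluate it, and
        # stop once the checkpoint for the final training step has been seen.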
        current_step = 0
        prev_checkpoint = None
        while True:
            checkpoint = estimator.latest_checkpoint()

            if checkpoint == prev_checkpoint:
                tf.logging.info("Sleeping 5 mins before evaluation")
                time.sleep(5 * 60)
                continue

            tf.logging.info("Running eval: %s", FLAGS.eval_name)
            result = estimator.evaluate(input_fn=eval_input_fn,
                                        steps=FLAGS.num_eval_steps,
                                        name=FLAGS.eval_name)
            tf.logging.info("Eval result:\n%s", result)

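            # Checkpoint basenames look like "model.ckpt-<global_step>".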
            current_step = int(os.path.basename(checkpoint).split("-")[1])
            if current_step >= experiment_utils.num_train_steps():
                tf.logging.info("Evaluation finished after training step %d",
                                current_step)
                break

            prev_checkpoint = checkpoint
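

# Standard absl entry point; a minimal sketch, assuming the flags used above
# are defined when the module is imported.
if __name__ == "__main__":
    app.run(main)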