def _create_estimator(self, params):
  tf.logging.info("Setting random seed to {}".format(42))
  np.random.seed(42)

  # Small bert model for testing.
  bert_config = modeling.BertConfig.from_dict({
      "vocab_size": 10,
      "type_vocab_size": [3, 256, 256, 2, 256, 256, 10],
      "num_hidden_layers": 2,
      "num_attention_heads": 2,
      "hidden_size": 128,
      "intermediate_size": 512,
  })
  model_fn = tapas_pretraining_model.model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=params["init_checkpoint"],
      learning_rate=params["learning_rate"],
      num_train_steps=params["num_train_steps"],
      num_warmup_steps=params["num_warmup_steps"],
      use_tpu=params["use_tpu"])
  estimator = tf.estimator.tpu.TPUEstimator(
      use_tpu=params["use_tpu"],
      model_fn=model_fn,
      config=tf.estimator.tpu.RunConfig(
          model_dir=self.get_temp_dir(),
          save_summary_steps=params["num_train_steps"],
          save_checkpoints_steps=params["num_train_steps"]),
      train_batch_size=params["batch_size"],
      predict_batch_size=params["batch_size"],
      eval_batch_size=params["batch_size"])
  return estimator
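
# A minimal usage sketch (hypothetical, not part of the test suite): the keys
# below are exactly the ones _create_estimator reads; the concrete values are
# illustrative assumptions for a small CPU test run.
def _make_test_params(use_tpu=False, batch_size=2):
  return {
      "init_checkpoint": None,  # No warm start: train the tiny model from scratch.
      "learning_rate": 5e-5,
      "num_train_steps": 10,
      "num_warmup_steps": 2,
      "use_tpu": use_tpu,
      "batch_size": batch_size,
  }

# Usage inside a test: estimator = self._create_estimator(_make_test_params())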
def main(_):
  bert_config = experiment_utils.bert_config_from_flags()
  model_fn = tapas_pretraining_model.model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=experiment_utils.num_train_steps(),
      num_warmup_steps=experiment_utils.num_warmup_steps(),
      use_tpu=FLAGS.use_tpu,
      restrict_attention_mode=attention_utils.RestrictAttentionMode(
          FLAGS.restrict_attention_mode),
      restrict_attention_bucket_size=FLAGS.restrict_attention_bucket_size,
      restrict_attention_header_size=FLAGS.restrict_attention_header_size,
      restrict_attention_row_heads_ratio=(
          FLAGS.restrict_attention_row_heads_ratio),
      disabled_features=FLAGS.disabled_features,
      disable_position_embeddings=FLAGS.disable_position_embeddings,
      reset_position_index_per_cell=FLAGS.reset_position_index_per_cell,
      proj_value_length=(
          FLAGS.proj_value_length if FLAGS.proj_value_length > 0 else None),
      attention_bias_disabled=FLAGS.attention_bias_disabled,
      attention_bias_use_relative_scalar_only=(
          FLAGS.attention_bias_use_relative_scalar_only),
  )
  estimator = experiment_utils.build_estimator(model_fn)

  if FLAGS.do_train:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    bert_config.to_json_file(os.path.join(FLAGS.model_dir, "bert_config.json"))
    train_input_fn = functools.partial(
        tapas_pretraining_model.input_fn,
        name="train",
        file_patterns=FLAGS.input_file_train,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=True,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq)
    estimator.train(
        input_fn=train_input_fn, max_steps=experiment_utils.num_train_steps())

  if FLAGS.do_eval:
    eval_input_fn = functools.partial(
        tapas_pretraining_model.input_fn,
        name="eval",
        file_patterns=FLAGS.input_file_eval,
        data_format=FLAGS.data_format,
        compression_type=FLAGS.compression_type,
        is_training=False,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq)

    # Evaluate every new checkpoint until training reaches its final step.
    current_step = 0
    prev_checkpoint = None
    while True:
      checkpoint = estimator.latest_checkpoint()

      # No new checkpoint yet: wait and poll again.
      if checkpoint == prev_checkpoint:
        tf.logging.info("Sleeping 5 mins before evaluation")
        time.sleep(5 * 60)
        continue

      tf.logging.info("Running eval: %s", FLAGS.eval_name)
      result = estimator.evaluate(
          input_fn=eval_input_fn,
          steps=FLAGS.num_eval_steps,
          name=FLAGS.eval_name)
      tf.logging.info("Eval result:\n%s", result)

      # Checkpoint paths end in "-<global_step>"; stop once training is done.
      current_step = int(os.path.basename(checkpoint).split("-")[1])
      if current_step >= experiment_utils.num_train_steps():
        tf.logging.info("Evaluation finished after training step %d",
                        current_step)
        break

      prev_checkpoint = checkpoint
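
# For reference, the step parsing in the eval loop above relies on Estimator
# checkpoints being named "<model_dir>/model.ckpt-<global_step>". A hypothetical
# helper (a sketch for illustration, not part of this module) making that
# explicit:
def _step_from_checkpoint(checkpoint_path):
  """Returns the global step encoded in an Estimator checkpoint path."""
  # "model.ckpt-1234".split("-") -> ["model.ckpt", "1234"]
  return int(os.path.basename(checkpoint_path).split("-")[1])

# _step_from_checkpoint("/tmp/model/model.ckpt-1234") == 1234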