def _run_model(config, bert_config, dataset_fn):
    """Builds the model and optimizer and runs a single training step."""
    batch_size = config["batch_size"]
    model = nqg_model.Model(batch_size,
                            config,
                            bert_config,
                            training=True,
                            verbose=False)
    optimizer = optimization.create_optimizer(config["learning_rate"],
                                              config["training_steps"],
                                              config["warmup_steps"])

    training_step = training_utils.get_training_step(optimizer, model)
    mean_loss = training_step(next(iter(dataset_fn(ctx=None))))
    return mean_loss
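A minimal sketch of how _run_model could be exercised, assuming the test_utils and input_utils helpers used in the other examples on this page; the TFRecord path is a placeholder, not a real file from the project.

# Hedged usage sketch (not from the original source); the input path below is
# a placeholder.
config = test_utils.get_test_config()
bert_config = test_utils.get_test_bert_config()
dataset_fn = input_utils.get_dataset_fn("/tmp/train.tfrecord", config)
mean_loss = _run_model(config, bert_config, dataset_fn)
print("mean_loss: %s" % mean_loss)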
Example 2
    def __init__(self,
                 tokenizer,
                 rules,
                 config,
                 bert_config,
                 target_grammar_rules=None,
                 verbose=False):
        """Constructs the model for inference (training=False, batch size 1)."""
        self.tokenizer = tokenizer
        self.config = config
        self.batch_size = 1
        self.model = nqg_model.Model(self.batch_size,
                                     config,
                                     bert_config,
                                     training=False)
        self.checkpoint = tf.train.Checkpoint(model=self.model)
        self.rules = rules
        self.target_grammar_rules = target_grammar_rules
        self.verbose = verbose
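The enclosing class is not shown in this snippet, so the name InferenceWrapper below is a placeholder; this is a hedged sketch of restoring trained weights through the Checkpoint created in __init__ before running inference.

# Hedged sketch (the class name "InferenceWrapper" and the model directory
# are placeholders; only __init__ is shown above).
wrapper = InferenceWrapper(tokenizer, rules, config, bert_config)
latest_checkpoint = tf.train.latest_checkpoint("/tmp/model_dir")
wrapper.checkpoint.restore(latest_checkpoint).expect_partial()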
Example 3
def _run_model_with_strategy(strategy, config, bert_config, dataset_fn):
    """Builds the model under the strategy and trains for steps_per_iteration steps."""
    dataset_iterator = iter(
        strategy.experimental_distribute_datasets_from_function(dataset_fn))
    batch_size = int(config["batch_size"] / strategy.num_replicas_in_sync)
    with strategy.scope():
        model = nqg_model.Model(batch_size,
                                config,
                                bert_config,
                                training=True,
                                verbose=False)
        optimizer = optimization.create_optimizer(config["learning_rate"],
                                                  config["training_steps"],
                                                  config["warmup_steps"])
        train_for_n_steps_fn = training_utils.get_train_for_n_steps_fn(
            strategy, optimizer, model)
        mean_loss = train_for_n_steps_fn(
            dataset_iterator,
            tf.convert_to_tensor(config["steps_per_iteration"],
                                 dtype=tf.int32))
        return mean_loss
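A hedged sketch of driving _run_model_with_strategy; MirroredStrategy is one conventional way to obtain a tf.distribute strategy on a single host and is an assumption here, not necessarily the project's actual setup.

# Hedged sketch (not from the original source); the strategy choice and input
# path are assumptions.
strategy = tf.distribute.MirroredStrategy()
config = test_utils.get_test_config()
bert_config = test_utils.get_test_bert_config()
dataset_fn = input_utils.get_dataset_fn("/tmp/train.tfrecord", config)
mean_loss = _run_model_with_strategy(strategy, config, bert_config, dataset_fn)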
Example 4
    def test_get_wordpiece_encodings(self):
        config = test_utils.get_test_config()
        batch_size = config["batch_size"]
        bert_config = test_utils.get_test_bert_config()
        model = nqg_model.Model(batch_size,
                                config,
                                bert_config,
                                training=True,
                                verbose=False)

        wordpiece_ids_batch = tf.constant(np.random.randint(
            0, bert_config.vocab_size,
            [batch_size, config["max_num_wordpieces"]]),
                                          dtype=tf.int32)

        num_wordpieces = tf.constant([[3]] * batch_size, dtype=tf.int32)

        wordpiece_encodings = model.get_wordpiece_encodings(
            wordpiece_ids_batch, num_wordpieces)
        print("wordpiece_encodings: %s" % wordpiece_encodings)
        self.assertEqual(wordpiece_encodings.shape,
                         (batch_size, config["max_num_wordpieces"],
                          bert_config.hidden_size))
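The test above presumably lives in a tf.test.TestCase subclass whose definition is not part of this snippet; if so, the standard TensorFlow test entry point applies.

# Hedged sketch: standard entry point, assuming the method above belongs to a
# tf.test.TestCase subclass that is not shown here.
if __name__ == "__main__":
    tf.test.main()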
Example 5
def train_model(strategy):
    """Run model training."""
    config = config_utils.json_file_to_dict(FLAGS.config)
    dataset_fn = input_utils.get_dataset_fn(FLAGS.input, config)

    writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.model_dir, "train"))

    dataset_iterator = iter(
        strategy.experimental_distribute_datasets_from_function(dataset_fn))

    bert_config = configs.BertConfig.from_json_file(
        os.path.join(FLAGS.bert_dir, "bert_config.json"))
    logging.info("Loaded BERT config: %s", bert_config.to_dict())
    batch_size = int(config["batch_size"] / strategy.num_replicas_in_sync)
    logging.info("num_replicas: %s.", strategy.num_replicas_in_sync)
    logging.info("per replica batch_size: %s.", batch_size)

    with strategy.scope():
        model = nqg_model.Model(batch_size,
                                config,
                                bert_config,
                                training=True,
                                verbose=FLAGS.verbose)
        optimizer = optimization.create_optimizer(config["learning_rate"],
                                                  config["training_steps"],
                                                  config["warmup_steps"])
        train_for_n_steps_fn = training_utils.get_train_for_n_steps_fn(
            strategy, optimizer, model)

        if FLAGS.init_bert_checkpoint:
            bert_checkpoint = tf.train.Checkpoint(model=model.bert_encoder)
            bert_checkpoint_path = os.path.join(FLAGS.bert_dir,
                                                "bert_model.ckpt")
            logging.info("Restoring bert checkpoint: %s", bert_checkpoint_path)
            logging.info("Bert vars: %s",
                         model.bert_encoder.trainable_variables)
            logging.info("Checkpoint vars: %s",
                         tf.train.list_variables(bert_checkpoint_path))
            status = bert_checkpoint.restore(bert_checkpoint_path)
            status.assert_existing_objects_matched()

        checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
        current_step = 0

        if FLAGS.restore_checkpoint:
            latest_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
            # TODO(petershaw): This is a hacky way to read current step.
            current_step = int(latest_checkpoint.split("-")[-2])
            logging.info("Restoring %s at step %s.", latest_checkpoint,
                         current_step)
            status = checkpoint.restore(latest_checkpoint)
            status.assert_existing_objects_matched()

        with writer.as_default():
            while current_step < config["training_steps"]:
                logging.info("current_step: %s.", current_step)
                mean_loss = train_for_n_steps_fn(
                    dataset_iterator,
                    tf.convert_to_tensor(config["steps_per_iteration"],
                                         dtype=tf.int32))
                tf.summary.scalar("loss", mean_loss, step=current_step)
                current_step += config["steps_per_iteration"]

                if current_step and current_step % config[
                        "save_checkpoint_every"] == 0:
                    checkpoint_prefix = os.path.join(FLAGS.model_dir,
                                                     "ckpt-%s" % current_step)
                    logging.info("Saving checkpoint to %s.", checkpoint_prefix)
                    checkpoint.save(file_prefix=checkpoint_prefix)
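train_model reads absl flags (FLAGS.config, FLAGS.input, FLAGS.model_dir, FLAGS.bert_dir, ...) that are defined elsewhere in the module; below is a hedged sketch of a driver that supplies a strategy, assuming the usual absl entry point and a single-host MirroredStrategy.

# Hedged sketch of a driver (not from the original source); assumes the absl
# flags used by train_model are defined elsewhere in the module.
from absl import app


def main(unused_argv):
    strategy = tf.distribute.MirroredStrategy()
    train_model(strategy)


if __name__ == "__main__":
    app.run(main)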