def test_create_model_not_train(self):
     """Checks the input/output signature of the model built with is_train False.

     The model is expected to expose exactly one input and two outputs; the
     static shape and dtype of each of them are verified.
     """
     model = transformer.create_model(self.params, False)
     inputs, outputs = model.inputs, model.outputs

     # Arity first, so the unpacking below cannot fail with a confusing error.
     self.assertEqual(len(inputs), 1)
     self.assertEqual(len(outputs), 2)

     (only_input,) = inputs
     first_output, second_output = outputs

     # (tensor, expected static shape, expected dtype), checked in order.
     expected_signature = [
         (only_input, [None, None], tf.int64),
         (first_output, [None, None], tf.int32),
         (second_output, [None], tf.float32),
     ]
     for tensor, shape, dtype in expected_signature:
         self.assertEqual(tensor.shape.as_list(), shape)
         self.assertEqual(tensor.dtype, dtype)
    def test_forward_pass_not_train(self):
        """Compares inference outputs of the original and the refactored model.

        Both models are built with the same params, the refactored model
        receives the original model's weights, and the first two outputs of a
        forward pass with training=False must be equal.
        """
        batch = np.asarray([[5, 2, 1], [7, 5, 0], [1, 4, 0], [7, 5, 11]])

        # Reference: the original model from before the refactoring.
        reference = transformer.create_model(self.params, False)
        reference_param_count = _count_params(reference)
        reference_weights = reference.get_weights()
        reference_output = reference([batch], training=False)

        # Candidate: the refactored model.
        candidate = _create_model(self.params, False)
        # Weight transfer below only makes sense if the parameter counts agree.
        self.assertEqual(reference_param_count, _count_params(candidate))
        candidate.set_weights(reference_weights)
        candidate_output = candidate([batch], training=False)

        # Outputs at index 0 and 1 must be equal.
        for idx in (0, 1):
            self.assertAllEqual(reference_output[idx], candidate_output[idx])
Beispiel #3
0
  def eval(self):
    """Evaluates the model.

    Lazily creates `self.predict_model` (with is_train False), loads the latest
    checkpoint found in `flags_obj.model_dir` into it, prints its summary and
    delegates the scoring to `evaluate_and_log_bleu`.

    Returns:
      Whatever `evaluate_and_log_bleu` returns for the prediction model.
    """
    # The model is only created under a distribution strategy scope for the
    # TPU case. With `None`, get_strategy_scope falls back to a no-op
    # DummyContextManager.
    strategy = None
    if self.use_tpu:
      strategy = self.distribution_strategy

    with distribute_utils.get_strategy_scope(strategy):
      if not self.predict_model:
        self.predict_model = transformer.create_model(self.params, False)
      newest_checkpoint = tf.train.latest_checkpoint(self.flags_obj.model_dir)
      self._load_weights_if_possible(self.predict_model, newest_checkpoint)
      self.predict_model.summary()

    flags = self.flags_obj
    return evaluate_and_log_bleu(
        self.predict_model, self.params, flags.bleu_source, flags.bleu_ref,
        flags.vocab_file, strategy)
Beispiel #4
0
  def predict(self):
    """Predicts result from the model.

    Builds the inference model (is_train=False) under the "model" name scope,
    loads the latest checkpoint from `flags_obj.model_dir` into it, runs
    `model.predict` on `_SINGLE_SAMPLE` elements of the eval dataset and hands
    each predicted output to `translate.translate_from_input`.
    """
    params = self.params
    flags_obj = self.flags_obj

    with tf.name_scope("model"):
      model = transformer.create_model(params, is_train=False)
      self._load_weights_if_possible(
          model, tf.train.latest_checkpoint(flags_obj.model_dir))
      model.summary()
    subtokenizer = tokenizer.Subtokenizer(flags_obj.vocab_file)

    ds = data_pipeline.eval_input_fn(params)
    # Keep only the first element of each (x, y) dataset pair; the second one
    # is not fed to the inference model.
    ds = ds.map(lambda x, y: x).take(_SINGLE_SAMPLE)
    # The model yields two outputs; only the first one is translated.
    val_outputs, _ = model.predict(ds)
    for val_output in val_outputs:
      translate.translate_from_input(val_output, subtokenizer)
    def test_forward_pass_train(self):
        """Compares training outputs of the original and the refactored model.

        Both models are built with the same params, the refactored model
        receives the original model's weights, and a forward pass with
        training=True on the same (inputs, targets) must give equal outputs.
        """
        # The input length (3) deliberately differs from the target length (4).
        source_ids = np.asarray([[5, 2, 1], [7, 5, 0], [1, 4, 0], [7, 5, 11]])
        target_ids = np.asarray([[4, 3, 4, 0], [13, 19, 17, 8],
                                 [20, 14, 1, 2], [5, 7, 3, 0]])
        features = [source_ids, target_ids]

        # Reference: the original model from before the refactoring.
        reference = transformer.create_model(self.params, True)
        reference_param_count = _count_params(reference)
        reference_weights = reference.get_weights()
        reference_output = reference(features, training=True)

        # Candidate: the refactored model.
        candidate = _create_model(self.params, True)
        # Weight transfer below only makes sense if the parameter counts agree.
        self.assertEqual(reference_param_count, _count_params(candidate))
        candidate.set_weights(reference_weights)
        candidate_output = candidate(features, training=True)

        self.assertAllEqual(reference_output, candidate_output)
Beispiel #6
0
  def train(self):
    """Trains the model.

    Creates the model and optimizer under the distribution strategy scope,
    restores the latest checkpoint from `flags_obj.model_dir` when one exists,
    and then trains in chunks of at most `flags_obj.steps_between_evals` steps
    until `flags_obj.train_steps` is reached. With `params["use_ctl"]` a custom
    training loop is run (TPU only); otherwise `model.fit` is used (non-TPU
    only). After each chunk, BLEU is evaluated via `self.eval()` when both
    `flags_obj.bleu_source` and `flags_obj.bleu_ref` are set.

    Returns:
      A dict of training stats. It holds "loss" for the custom training loop,
      whatever `misc.update_stats` adds, and the "bleu_*" entries when BLEU
      was evaluated. If the restored step is already at or beyond
      `flags_obj.train_steps`, no training runs and no "loss" is reported.

    Raises:
      NotImplementedError: If the custom training loop is requested without
        TPU, or `model.fit` is requested on TPU.
    """
    params = self.params
    flags_obj = self.flags_obj
    # Sets config options.
    keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)

    _ensure_dir(flags_obj.model_dir)
    with distribute_utils.get_strategy_scope(self.distribution_strategy):
      model = transformer.create_model(params, is_train=True)
      opt = self._create_optimizer()

      current_step = 0
      checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
      latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
      if latest_checkpoint:
        checkpoint.restore(latest_checkpoint)
        logging.info("Loaded checkpoint %s", latest_checkpoint)
        # Resume counting from the optimizer's restored iteration counter.
        current_step = opt.iterations.numpy()

      if params["use_ctl"]:
        train_loss_metric = tf.keras.metrics.Mean(
            "training_loss", dtype=tf.float32)
        if params["enable_tensorboard"]:
          summary_writer = tf.summary.create_file_writer(
              os.path.join(flags_obj.model_dir, "summary"))
        else:
          summary_writer = tf.summary.create_noop_writer()
        train_metrics = [train_loss_metric]
        if params["enable_metrics_in_training"]:
          train_metrics = train_metrics + model.metrics
      else:
        model.compile(opt)

    model.summary()

    if self.use_tpu:
      # Different from experimental_distribute_dataset,
      # distribute_datasets_from_function requires
      # per-replica/local batch size.
      # NOTE(review): true division turns the batch size into a float and
      # mutates self.params in place -- confirm the data pipeline expects that.
      params["batch_size"] /= self.distribution_strategy.num_replicas_in_sync
      train_ds = (
          self.distribution_strategy.distribute_datasets_from_function(
              lambda ctx: data_pipeline.train_input_fn(params, ctx)))
    else:
      train_ds = data_pipeline.train_input_fn(params)
      map_data_fn = data_pipeline.map_data_for_transformer_fn
      train_ds = train_ds.map(
          map_data_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if params["use_ctl"]:
      train_ds_iterator = iter(train_ds)

    callbacks = self._create_callbacks(flags_obj.model_dir, params)

    # Only TimeHistory callback is supported for CTL
    if params["use_ctl"]:
      callbacks = [cb for cb in callbacks
                   if isinstance(cb, keras_utils.TimeHistory)]

    @tf.function
    def train_steps(iterator, steps):
      """Training steps function for TPU runs.

      Args:
        iterator: The input iterator of the training dataset.
        steps: An integer, the number of training steps.

      Returns:
        A float, the loss value.
      """

      def _step_fn(inputs):
        """Per-replica step function."""
        inputs, targets = inputs
        with tf.GradientTape() as tape:
          logits = model([inputs, targets], training=True)
          loss = metrics.transformer_loss(logits, targets,
                                          params["label_smoothing"],
                                          params["vocab_size"])
          # Scales the loss, which results in using the average loss across all
          # of the replicas for backprop.
          scaled_loss = loss / self.distribution_strategy.num_replicas_in_sync

        # De-dupes variables due to keras tracking issues.
        tvars = list({id(v): v for v in model.trainable_variables}.values())
        grads = tape.gradient(scaled_loss, tvars)
        opt.apply_gradients(zip(grads, tvars))
        # For reporting, the metric takes the mean of losses.
        train_loss_metric.update_state(loss)

      for _ in tf.range(steps):
        train_loss_metric.reset_states()
        self.distribution_strategy.run(
            _step_fn, args=(next(iterator),))

    cased_score, uncased_score = None, None
    cased_score_history, uncased_score_history = [], []
    # Bound up front so the stats code below is safe even when the loop body
    # never runs (restored step already at or past the requested train_steps).
    history = None
    train_loss = None
    while current_step < flags_obj.train_steps:
      remaining_steps = flags_obj.train_steps - current_step
      # Never train past the requested total number of steps.
      train_steps_per_eval = min(remaining_steps,
                                 flags_obj.steps_between_evals)
      current_iteration = current_step // flags_obj.steps_between_evals

      logging.info(
          "Start train iteration at global step:{}".format(current_step))
      history = None
      if params["use_ctl"]:
        if not self.use_tpu:
          raise NotImplementedError(
              "Custom training loop on GPUs is not implemented.")

        # Runs training steps.
        with summary_writer.as_default():
          for cb in callbacks:
            cb.on_epoch_begin(current_iteration)
            cb.on_batch_begin(0)

          train_steps(
              train_ds_iterator,
              tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))
          current_step += train_steps_per_eval
          train_loss = train_loss_metric.result().numpy().astype(float)
          logging.info("Train Step: %d/%d / loss = %s", current_step,
                       flags_obj.train_steps, train_loss)

          for cb in callbacks:
            cb.on_batch_end(train_steps_per_eval - 1)
            cb.on_epoch_end(current_iteration)

          if params["enable_tensorboard"]:
            for metric_obj in train_metrics:
              tf.summary.scalar(metric_obj.name, metric_obj.result(),
                                current_step)
            # One flush after all scalars are written is enough.
            summary_writer.flush()

        for cb in callbacks:
          cb.on_train_end()

        if flags_obj.enable_checkpointing:
          # avoid check-pointing when running for benchmarking.
          checkpoint_name = checkpoint.save(
              os.path.join(flags_obj.model_dir,
                           "ctl_step_{}.ckpt".format(current_step)))
          logging.info("Saved checkpoint to %s", checkpoint_name)
      else:
        if self.use_tpu:
          raise NotImplementedError(
              "Keras model.fit on TPUs is not implemented.")
        history = model.fit(
            train_ds,
            initial_epoch=current_iteration,
            epochs=current_iteration + 1,
            steps_per_epoch=train_steps_per_eval,
            callbacks=callbacks,
            # If TimeHistory is enabled, progress bar would be messy. Increase
            # the verbose level to get rid of it.
            verbose=(2 if flags_obj.enable_time_history else 1))
        current_step += train_steps_per_eval
        logging.info("Train history: {}".format(history.history))

      logging.info("End train iteration at global step:{}".format(current_step))

      if (flags_obj.bleu_source and flags_obj.bleu_ref):
        uncased_score, cased_score = self.eval()
        cased_score_history.append([current_iteration + 1, cased_score])
        uncased_score_history.append([current_iteration + 1, uncased_score])

    # The custom training loop has no Keras history, so its loss is reported
    # directly; nothing is reported when no training step ran at all.
    stats = {}
    if history is None and train_loss is not None:
      stats["loss"] = train_loss
    misc.update_stats(history, stats, callbacks)
    # Compare against None so that a legitimate score of 0.0 is still reported.
    if uncased_score is not None and cased_score is not None:
      stats["bleu_uncased"] = uncased_score
      stats["bleu_cased"] = cased_score
      stats["bleu_uncased_history"] = uncased_score_history
      stats["bleu_cased_history"] = cased_score_history
    return stats