def test_estimator_with_strategy_hooks(self, distribution,
                                       use_train_and_evaluate):
        """Smoke-tests Estimator train/eval runs that carry per-mode hooks.

        `distribution` is used only as the RunConfig's `eval_distribute`
        strategy. When `use_train_and_evaluate` is true the estimator is
        driven through `training.train_and_evaluate` (max_steps=10);
        otherwise `train` and `evaluate` are called directly with steps=10.
        Both the TRAIN and the EVAL spec attach a `StepCounterHook` plus a
        `MagicMock` that wraps a plain `SessionRunHook`.

        NOTE(review): nothing is asserted and the mock hooks are never
        inspected, so this only checks that the runs complete without raising.
        """
        total_steps = 10
        run_cfg = run_config.RunConfig(eval_distribute=distribution)

        def _to_features_and_labels(value):
            # The same tensor is used as the 'feature' input and as the label.
            return {'feature': value}, value

        def input_fn():
            # Ten copies of [1.], grouped into batches of five.
            dataset = tf.data.Dataset.from_tensors([1.])
            dataset = dataset.repeat(10).batch(5)
            return dataset.map(_to_features_and_labels)

        def _make_hooks():
            # One real StepCounterHook and one spec-restricted mock that
            # forwards every call to a no-op SessionRunHook.
            counter = tf.compat.v1.train.StepCounterHook(
                every_n_steps=1, output_dir=self.get_temp_dir())
            recorder = tf.compat.v1.test.mock.MagicMock(
                wraps=tf.compat.v1.train.SessionRunHook(),
                spec=tf.compat.v1.train.SessionRunHook)
            return [counter, recorder]

        # Parameter names must stay `features, labels, mode`: Estimator
        # inspects the model_fn signature to decide what to pass.
        def model_fn(features, labels, mode):
            del features, labels
            step = tf.compat.v1.train.get_global_step()
            if mode == model_fn_lib.ModeKeys.TRAIN:
                hooks = _make_hooks()
                return model_fn_lib.EstimatorSpec(
                    mode=mode,
                    loss=tf.constant(1.),
                    train_op=step.assign_add(1),
                    training_hooks=hooks)
            if mode == model_fn_lib.ModeKeys.EVAL:
                hooks = _make_hooks()
                return model_fn_lib.EstimatorSpec(
                    mode=mode,
                    loss=tf.constant(1.),
                    evaluation_hooks=hooks)
            # Any other mode (e.g. PREDICT) is not exercised by this test.
            return None

        estimator = estimator_lib.EstimatorV2(
            model_fn=model_fn, model_dir=self.get_temp_dir(), config=run_cfg)
        if not use_train_and_evaluate:
            estimator.train(input_fn, steps=total_steps)
            estimator.evaluate(input_fn, steps=total_steps)
            return
        train_spec = training.TrainSpec(input_fn, max_steps=total_steps)
        eval_spec = training.EvalSpec(input_fn)
        training.train_and_evaluate(estimator, train_spec, eval_spec)
  def _make_estimator(self, model_dir):
    """Builds an `EstimatorV2` around `SubclassedModel` plus a matching input_fn.

    The model_fn puts the model, an `adam.Adam(0.01)` optimizer and the
    global step into a `util.Checkpoint`, and hands that checkpoint to the
    Scaffold as its saver. The predictions are the model output, plus the
    `dense_two` bias and the step value, each tiled once per output row.

    Args:
      model_dir: Directory passed to the estimator as its `model_dir`.

    Returns:
      An `(estimator, input_fn)` tuple. `input_fn` yields an endlessly
      repeated dataset of `[1.]`, batched by 10 and mapped to
      `({'feature': x}, x)` pairs.
    """

    # Parameter names must stay `features, labels, mode`: Estimator inspects
    # the model_fn signature to decide what to pass.
    def _model_fn(features, labels, mode):
      del labels
      net = SubclassedModel()
      opt = adam.Adam(0.01)
      ckpt = util.Checkpoint(
          step=tf.compat.v1.train.get_or_create_global_step(),
          optimizer=opt,
          model=net)
      # Access save_counter so it is created while the graph is being built;
      # a later assert_consumed() assertion expects it to exist.
      ckpt.save_counter  # pylint: disable=pointless-statement
      with tf.GradientTape() as tape:
        model_out = net(features['feature'])
        loss = tf.math.reduce_sum(model_out)
      weights = net.trainable_variables
      grads = tape.gradient(loss, weights)
      update_op = tf.group(
          opt.apply_gradients(zip(grads, weights)),
          ckpt.step.assign_add(1))
      # Repeat per-model values once for every row of the model output.
      bias_per_row = tf.tile(
          net.dense_two.bias[None, :],
          [tf.compat.v1.shape(model_out)[0], 1])
      step_per_row = tf.tile(
          ckpt.step[None],
          [tf.compat.v1.shape(model_out)[0]])
      predictions = {
          'output': model_out,
          'bias': bias_per_row,
          'step': step_per_row,
      }
      return model_fn_lib.EstimatorSpec(
          mode,
          loss=loss,
          train_op=update_op,
          predictions=predictions,
          scaffold=tf.compat.v1.train.Scaffold(saver=ckpt))

    def _split_features_labels(value):
      """Pairs `{'feature': value}` with `value` itself as the label."""
      return {'feature': value}, value

    def _input_fn():
      ones = tf.compat.v1.data.Dataset.from_tensors([1.])
      return ones.repeat().batch(10).map(_split_features_labels)

    estimator = estimator_lib.EstimatorV2(
        model_fn=_model_fn, model_dir=model_dir)
    return estimator, _input_fn
Ejemplo n.º 3
0
  def _make_estimator(self, model_dir):
    """Builds an `EstimatorV2` around `SubclassedModel` plus a matching input_fn.

    The model_fn puts the model, an `adam.Adam(0.01)` optimizer and the
    global step into a `util.Checkpoint`, and hands that checkpoint to the
    Scaffold as its saver. The predictions are the model output, plus the
    `dense_two` bias and the step value, each tiled once per output row.

    Args:
      model_dir: Directory passed to the estimator as its `model_dir`.

    Returns:
      An `(estimator, input_fn)` tuple. `input_fn` yields an endlessly
      repeated dataset of `[1.]` batched by 10. The elements are bare
      tensors with no label component, and `_model_fn` feeds `features`
      straight into the model.
    """

    # Parameter names must stay `features, labels, mode`: Estimator inspects
    # the model_fn signature to decide what to pass.
    def _model_fn(features, labels, mode):
      del labels
      net = SubclassedModel()
      opt = adam.Adam(0.01)
      ckpt = util.Checkpoint(
          step=training_util.get_or_create_global_step(),
          optimizer=opt,
          model=net)
      # Access save_counter so it is created while the graph is being built;
      # a later assert_consumed() assertion expects it to exist.
      ckpt.save_counter  # pylint: disable=pointless-statement
      with backprop.GradientTape() as tape:
        model_out = net(features)
        loss = math_ops.reduce_sum(model_out)
      weights = net.trainable_variables
      grads = tape.gradient(loss, weights)
      update_op = control_flow_ops.group(
          opt.apply_gradients(zip(grads, weights)),
          ckpt.step.assign_add(1))
      # Repeat per-model values once for every row of the model output.
      bias_per_row = array_ops.tile(
          net.dense_two.bias[None, :],
          [array_ops.shape(model_out)[0], 1])
      step_per_row = array_ops.tile(
          ckpt.step[None],
          [array_ops.shape(model_out)[0]])
      predictions = {
          'output': model_out,
          'bias': bias_per_row,
          'step': step_per_row,
      }
      return model_fn_lib.EstimatorSpec(
          mode,
          loss=loss,
          train_op=update_op,
          predictions=predictions,
          scaffold=monitored_session.Scaffold(saver=ckpt))

    def _input_fn():
      dataset = dataset_ops.Dataset.from_tensors([1.])
      return dataset.repeat().batch(10)

    estimator = estimator_lib.EstimatorV2(
        model_fn=_model_fn, model_dir=model_dir)
    return estimator, _input_fn