def test_init_from_checkpoint_global_step(self):
        """Tests that warm-starting from a checkpoint also restores global_step.

        A model is first trained in `model_dir`. A second model is then
        initialized from that directory through `init_from_checkpoint_fn` and
        trained, in a fresh directory, up to 100 steps further. If the global
        step was restored together with the weights, that second run only
        writes one checkpoint beyond the initial one.
        """
        gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps',
                           100)
        gin.bind_parameter('tf.estimator.RunConfig.keep_checkpoint_max', 10)

        def _checkpoint_metas(directory):
            # One `model*.meta` file exists per retained checkpoint.
            return tf.io.gfile.glob(os.path.join(directory, 'model*.meta'))

        def _train(t2r_model, target_dir, max_train_steps):
            # Shared training configuration for both runs of this test.
            train_eval.train_eval_model(
                t2r_model=t2r_model,
                input_generator_train=mocks.MockInputGenerator(
                    batch_size=_BATCH_SIZE),
                max_train_steps=max_train_steps,
                model_dir=target_dir,
                eval_steps=_EVAL_STEPS,
                eval_throttle_secs=_EVAL_THROTTLE_SECS,
                create_exporters_fn=train_eval.create_default_exporters)

        model_dir = self.create_tempdir().full_path
        _train(
            mocks.MockT2RModel(
                preprocessor_cls=noop_preprocessor.NoOpPreprocessor),
            model_dir, _MAX_TRAIN_STEPS)
        # Checkpoints are saved every 100 steps and at most 10 are kept; for a
        # run of 1000 steps (the value this expectation was written for) that
        # leaves exactly 10.
        self.assertLen(_checkpoint_metas(model_dir), 10)

        # The continued training gets its own directory and is warm-started
        # from `model_dir`.
        continue_model_dir = self.create_tempdir().full_path
        warm_start_fn = functools.partial(
            abstract_model.default_init_from_checkpoint_fn,
            checkpoint=model_dir)
        _train(
            mocks.MockT2RModel(
                preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
                init_from_checkpoint_fn=warm_start_fn),
            continue_model_dir, _MAX_TRAIN_STEPS + 100)
        # With the global step restored, only one checkpoint is added to the
        # initial one.
        self.assertLen(_checkpoint_metas(continue_model_dir), 2)
示例#2
0
    def test_predictor(self):
        """Checks the numpy SavedModel predictor before and after restore()."""
        generator = default_input_generator.DefaultRandomInputGenerator(
            batch_size=_BATCH_SIZE)
        model_dir = self.create_tempdir().full_path
        t2r_model = mocks.MockT2RModel()
        train_eval.train_eval_model(
            t2r_model=t2r_model,
            input_generator_train=generator,
            input_generator_eval=generator,
            max_train_steps=_MAX_TRAIN_STEPS,
            eval_steps=_MAX_EVAL_STEPS,
            model_dir=model_dir,
            create_exporters_fn=train_eval.create_default_exporters)

        latest_numpy_dir = os.path.join(model_dir, 'export',
                                        'latest_exporter_numpy')
        predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
            export_dir=latest_numpy_dir)

        # Nothing has been loaded yet, so every accessor must raise.
        with self.assertRaises(ValueError):
            predictor.get_feature_specification()
        with self.assertRaises(ValueError):
            predictor.predict({'does_not_matter': np.zeros(1)})
        with self.assertRaises(ValueError):
            _ = predictor.model_version

        self.assertTrue(predictor.restore())
        self.assertGreater(predictor.model_version, 0)

        expected_spec = t2r_model.preprocessor.get_in_feature_specification(
            tf.estimator.ModeKeys.PREDICT)
        tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                      expected_spec)
        random_features = tensorspec_utils.make_random_numpy(
            expected_spec, batch_size=_BATCH_SIZE)
        outputs = predictor.predict(random_features)
        self.assertLen(outputs, 1)
        self.assertEqual(outputs['logit'].shape, (2, 1))
示例#3
0
    def test_with_mock_training(self):
        """Trains a TPU mock model with an async export hook and restores it."""
        model_dir = self.create_tempdir().full_path
        tpu_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
            device_type='tpu',
            use_avg_model_params=True)
        input_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)

        export_dir = os.path.join(model_dir, _EXPORT_DIR)
        export_hook = async_export_hook_builder.AsyncExportHookBuilder(
            export_dir=export_dir,
            create_export_fn=async_export_hook_builder.default_create_export_fn)

        # Run one TPU iteration per loop and write a checkpoint on every step.
        for binding in ('tf.contrib.tpu.TPUConfig.iterations_per_loop=1',
                        'tf.estimator.RunConfig.save_checkpoints_steps=1'):
            gin.parse_config(binding)

        # We optimize our network.
        train_eval.train_eval_model(
            t2r_model=tpu_model,
            input_generator_train=input_generator,
            train_hook_builders=[export_hook],
            model_dir=model_dir,
            max_train_steps=_MAX_STEPS)

        self.assertNotEmpty(tf.io.gfile.listdir(model_dir))
        self.assertNotEmpty(tf.io.gfile.listdir(export_dir))
        # Every individual export directory must contain files.
        for version_dir in tf.io.gfile.listdir(export_dir):
            version_path = os.path.join(export_dir, version_dir)
            self.assertNotEmpty(tf.io.gfile.listdir(version_path))

        predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
            export_dir=export_dir)
        self.assertTrue(predictor.restore())
示例#4
0
 def test_predictor_load_final_model(self):
     """Restores the most recent numpy export directly and checks metadata.

     The predictor is pointed at a single export version directory under
     `latest_exporter_numpy` (the lexicographically largest one), not at the
     exporter root.
     """
     input_generator = default_input_generator.DefaultRandomInputGenerator(
         batch_size=_BATCH_SIZE)
     model_dir = self.create_tempdir().full_path
     mock_model = mocks.MockT2RModel()
     train_eval.train_eval_model(
         t2r_model=mock_model,
         input_generator_train=input_generator,
         input_generator_eval=input_generator,
         max_train_steps=_MAX_TRAIN_STEPS,
         eval_steps=_MAX_EVAL_STEPS,
         model_dir=model_dir,
         create_exporters_fn=train_eval.create_default_exporters)
     export_dir = os.path.join(model_dir, 'export', 'latest_exporter_numpy')
     exported_versions = tf.io.gfile.glob(os.path.join(export_dir, '*'))
     # Fail with a clear assertion instead of an IndexError if nothing was
     # exported.
     self.assertNotEmpty(exported_versions)
     # Version directories are presumably named by export timestamp, so the
     # lexicographic maximum is the final export.
     final_export_dir = max(exported_versions)
     predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
         export_dir=final_export_dir)
     # restore() reports success; check it like the other predictor tests do
     # so a failed load is reported here rather than in a later assertion.
     self.assertTrue(predictor.restore())
     self.assertGreater(predictor.model_version, 0)
     self.assertEqual(predictor.global_step, 3)
     ref_feature_spec = mock_model.preprocessor.get_in_feature_specification(
         tf.estimator.ModeKeys.PREDICT)
     tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                   ref_feature_spec)
示例#5
0
    def test_predictor_init_with_default_exporter(self, restore_model_option):
        """Checks predictor metadata for the given restore option.

        Only RESTORE_SYNCHRONOUSLY gets an explicit restore() call here; for any
        other option the test expects the model to be available without one.

        Args:
          restore_model_option: An `exported_savedmodel_predictor.RestoreOptions`
            value passed to the predictor constructor.
        """
        generator = default_input_generator.DefaultRandomInputGenerator(
            batch_size=_BATCH_SIZE)
        model_dir = self.create_tempdir().full_path
        t2r_model = mocks.MockT2RModel()
        train_eval.train_eval_model(
            t2r_model=t2r_model,
            input_generator_train=generator,
            input_generator_eval=generator,
            max_train_steps=_MAX_TRAIN_STEPS,
            eval_steps=_MAX_EVAL_STEPS,
            model_dir=model_dir,
            create_exporters_fn=train_eval.create_default_exporters)

        latest_numpy_dir = os.path.join(model_dir, 'export',
                                        'latest_exporter_numpy')
        predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
            export_dir=latest_numpy_dir,
            restore_model_option=restore_model_option)
        needs_explicit_restore = (
            restore_model_option ==
            exported_savedmodel_predictor.RestoreOptions.RESTORE_SYNCHRONOUSLY)
        if needs_explicit_restore:
            predictor.restore()

        self.assertGreater(predictor.model_version, 0)
        self.assertEqual(predictor.global_step, 3)
        expected_spec = t2r_model.preprocessor.get_in_feature_specification(
            tf.estimator.ModeKeys.PREDICT)
        tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                      expected_spec)
示例#6
0
  def test_with_mock_training(self):
    """Trains a CPU mock model with an async export hook and restores it.

    Checks that training wrote to `model_dir`, that the hook produced at least
    one non-empty export under `export_dir`, and that the exported model can
    be restored by `ExportedSavedModelPredictor`.
    """
    model_dir = self.create_tempdir().full_path
    mock_t2r_model = mocks.MockT2RModel(
        preprocessor_cls=noop_preprocessor.NoOpPreprocessor, device_type='cpu')

    mock_input_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
    export_dir = os.path.join(model_dir, _EXPORT_DIR)
    # Export fn bound to the serving batch sizes; built once and handed to
    # the async export hook.
    default_create_export_fn = functools.partial(
        async_export_hook_builder.default_create_export_fn,
        batch_sizes_for_export=_BATCH_SIZES_FOR_EXPORT)
    hook_builder = async_export_hook_builder.AsyncExportHookBuilder(
        export_dir=export_dir, create_export_fn=default_create_export_fn)

    # We optimize our network.
    train_eval.train_eval_model(
        t2r_model=mock_t2r_model,
        input_generator_train=mock_input_generator,
        train_hook_builders=[hook_builder],
        model_dir=model_dir,
        max_train_steps=_MAX_STEPS)
    self.assertNotEmpty(tf.io.gfile.listdir(model_dir))
    self.assertNotEmpty(tf.io.gfile.listdir(export_dir))
    # Every individual export directory must contain files.
    for exported_model_dir in tf.io.gfile.listdir(export_dir):
      self.assertNotEmpty(
          tf.io.gfile.listdir(os.path.join(export_dir, exported_model_dir)))
    predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
        export_dir=export_dir)
    self.assertTrue(predictor.restore())
    def _train_and_eval_reference_model(self, path):
        """Trains a mock model with a TPUEstimator and predicts on eval data.

        Args:
          path: Unused; the model is always trained in a fresh temp directory.
            NOTE(review): confirm whether callers expect `path` to be used.

        Returns:
          A `(model_dir, mock_t2r_model, prediction_ref)` tuple, where
          `prediction_ref` is the result of `Estimator.predict` over the EVAL
          input pipeline, using the checkpoint in `model_dir`.
        """
        del path  # Unused.
        model_dir = self.create_tempdir().full_path
        t2r_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor)

        # We create a tpu estimator for potential training; it only targets a
        # TPU when the model's device type says so.
        train_estimator = tf.contrib.tpu.TPUEstimator(
            model_fn=t2r_model.model_fn,
            use_tpu=t2r_model.is_device_tpu,
            config=tf.contrib.tpu.RunConfig(model_dir=model_dir),
            train_batch_size=BATCH_SIZE,
            eval_batch_size=BATCH_SIZE)

        input_generator = mocks.MockInputGenerator(batch_size=BATCH_SIZE)
        input_generator.set_specification_from_model(
            t2r_model, tf.estimator.ModeKeys.TRAIN)

        # We optimize our network.
        train_input_fn = input_generator.create_dataset_input_fn(
            mode=tf.estimator.ModeKeys.TRAIN)
        train_estimator.train(input_fn=train_input_fn, max_steps=MAX_STEPS)

        # A plain Estimator on the same model_dir provides the reference
        # predictions that serving exports are compared against.
        predict_estimator = tf.estimator.Estimator(
            model_fn=t2r_model.model_fn,
            config=tf.estimator.RunConfig(model_dir=model_dir))
        eval_input_fn = input_generator.create_dataset_input_fn(
            mode=tf.estimator.ModeKeys.EVAL)
        prediction_ref = predict_estimator.predict(input_fn=eval_input_fn)
        return model_dir, t2r_model, prediction_ref
示例#8
0
    def test_predictor(self):
        """Checks CheckpointPredictor state before and after restore()."""
        random_inputs = default_input_generator.DefaultRandomInputGenerator(
            batch_size=_BATCH_SIZE)
        model_dir = self.create_tempdir().full_path
        t2r_model = mocks.MockT2RModel()
        train_eval.train_eval_model(t2r_model=t2r_model,
                                    input_generator_train=random_inputs,
                                    max_train_steps=_MAX_TRAIN_STEPS,
                                    model_dir=model_dir)

        predictor = checkpoint_predictor.CheckpointPredictor(
            t2r_model=t2r_model, checkpoint_dir=model_dir, use_gpu=False)

        # Nothing restored yet: predict() raises and versions report -1.
        with self.assertRaises(ValueError):
            predictor.predict({'does_not_matter': np.zeros(1)})
        self.assertEqual(predictor.model_version, -1)
        self.assertEqual(predictor.global_step, -1)

        self.assertTrue(predictor.restore())
        self.assertGreater(predictor.model_version, 0)
        self.assertEqual(predictor.global_step, 3)

        spec = t2r_model.preprocessor.get_in_feature_specification(
            tf.estimator.ModeKeys.PREDICT)
        tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                      spec)
        outputs = predictor.predict(
            tensorspec_utils.make_random_numpy(spec, batch_size=_BATCH_SIZE))
        self.assertLen(outputs, 1)
        # assertCountEqual is order-insensitive, so no sorting is required.
        self.assertCountEqual(outputs.keys(), ['logit'])
        self.assertEqual(outputs['logit'].shape, (2, 1))
示例#9
0
 def test_predictor_timeout(self):
     """restore() returns False for a nonexistent dir with a 1s timeout."""
     predictor = checkpoint_predictor.CheckpointPredictor(
         t2r_model=mocks.MockT2RModel(),
         checkpoint_dir='/random/path/which/does/not/exist',
         timeout=1)
     self.assertFalse(predictor.restore())
示例#10
0
 def test_predictor_raises(self):
     """restore() raises ValueError when no checkpoint_dir was provided."""
     predictor = checkpoint_predictor.CheckpointPredictor(
         t2r_model=mocks.MockT2RModel())
     # No checkpoint_dir has been set, so calling restore() must raise.
     with self.assertRaises(ValueError):
         predictor.restore()
示例#11
0
  def test_maml_model(self, num_inner_loop_steps):
    """Trains a MAML-wrapped mock model and checks its exported signatures.

    Args:
      num_inner_loop_steps: Number of MAML inner-loop steps; also used to give
        each parameterization its own model_dir.
    """
    model_dir = os.path.join(FLAGS.test_tmpdir, str(num_inner_loop_steps))
    gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps',
                       _MAX_STEPS // 2)
    # Start from an empty directory so leftovers from earlier runs cannot leak
    # into this one.
    if tf.io.gfile.exists(model_dir):
      tf.io.gfile.rmtree(model_dir)

    mock_base_model = mocks.MockT2RModel(
        preprocessor_cls=noop_preprocessor.NoOpPreprocessor)

    mock_tf_model = MockMAMLModel(
        base_model=mock_base_model, num_inner_loop_steps=num_inner_loop_steps)

    # Note, we by choice use the same amount of conditioning samples for
    # inference as well during train and change the model for eval/inference
    # to only produce one output sample.
    mock_input_generator_train = MockMetaInputGenerator(
        batch_size=_BATCH_SIZE,
        num_condition_samples_per_task=_NUM_CONDITION_SAMPLES_PER_TASK,
        num_inference_samples_per_task=_NUM_CONDITION_SAMPLES_PER_TASK)
    mock_input_generator_train.set_specification_from_model(
        mock_tf_model, mode=tf.estimator.ModeKeys.TRAIN)

    mock_input_generator_eval = MockMetaInputGenerator(
        batch_size=_BATCH_SIZE,
        num_condition_samples_per_task=_NUM_CONDITION_SAMPLES_PER_TASK,
        num_inference_samples_per_task=1)
    mock_input_generator_eval.set_specification_from_model(
        mock_tf_model, mode=tf.estimator.ModeKeys.TRAIN)
    mock_export_generator = MockMetaExportGenerator(
        num_condition_samples_per_task=_NUM_CONDITION_SAMPLES_PER_TASK,
        num_inference_samples_per_task=1)

    train_eval.train_eval_model(
        t2r_model=mock_tf_model,
        input_generator_train=mock_input_generator_train,
        input_generator_eval=mock_input_generator_eval,
        max_train_steps=_MAX_STEPS,
        model_dir=model_dir,
        export_generator=mock_export_generator,
        create_exporters_fn=train_eval.create_default_exporters)
    export_dir = os.path.join(model_dir, 'export')
    # create_default_exporters produces four exporter directories (presumably
    # best_/latest_ exporters for both numpy and tf_example inputs).
    self.assertLen(tf.io.gfile.glob(os.path.join(export_dir, '*')), 4)

    def _newest_export(exporter_name):
      # gfile.glob does not guarantee any ordering, so sort before taking the
      # last entry (version dirs are presumably timestamp-named).
      return sorted(
          tf.io.gfile.glob(os.path.join(export_dir, exporter_name, '*')))[-1]

    numpy_predictor_fn = contrib_predictor.from_saved_model(
        _newest_export('best_exporter_numpy'))

    feed_tensor_keys = sorted(numpy_predictor_fn.feed_tensors.keys())
    self.assertCountEqual(
        ['condition/features/x', 'condition/labels/y', 'inference/features/x'],
        feed_tensor_keys,
    )

    tf_example_predictor_fn = contrib_predictor.from_saved_model(
        _newest_export('best_exporter_tf_example'))
    self.assertCountEqual(['input_example_tensor'],
                          list(tf_example_predictor_fn.feed_tensors.keys()))
  def test_create_warmup_requests_numpy(self):
    """Checks that one warmup PredictionLog is written per batch size, in order.

    Each record must target the 'MockT2RModel' model spec, and every input
    tensor's leading dimension must equal the corresponding batch size.
    """
    mock_t2r_model = mocks.MockT2RModel(
        preprocessor_cls=noop_preprocessor.NoOpPreprocessor)
    exporter = mocks.MockExportGenerator()
    exporter.set_specification_from_model(mock_t2r_model)

    export_dir = self.create_tempdir()
    batch_sizes = [2, 4]
    request_filename = exporter.create_warmup_requests_numpy(
        batch_sizes=batch_sizes, export_dir=export_dir.full_path)

    # Materialize the records first: zip() alone stops at the shorter
    # sequence, so missing (or extra) records would pass unnoticed.
    records = list(tf.compat.v1.io.tf_record_iterator(request_filename))
    self.assertLen(records, len(batch_sizes))

    for expected_batch_size, record in zip(batch_sizes, records):
      record_proto = prediction_log_pb2.PredictionLog()
      record_proto.ParseFromString(record)
      request = record_proto.predict_log.request
      self.assertEqual(request.model_spec.name, 'MockT2RModel')
      for in_tensor in request.inputs.values():
        self.assertEqual(in_tensor.tensor_shape.dim[0].size,
                         expected_batch_size)
示例#13
0
    def test_freezing_some_variables(self):
        """Tests we can freeze training for parts of the network.

        Trains with a `filter_trainables_fn` that rejects every variable whose
        name contains 'bias'. It then compares the step-0 checkpoint with the
        latest one: non-bias variables must have changed and bias variables
        must not. Batch-norm variables are skipped.
        """

        def freeze_biases(var):
            # Only variables without 'bias' in their name stay trainable.
            return 'bias' not in var.name

        gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps',
                           100)
        gin.bind_parameter('create_train_op.filter_trainables_fn',
                           freeze_biases)
        model_dir = self.create_tempdir().full_path
        train_eval.train_eval_model(
            t2r_model=mocks.MockT2RModel(
                preprocessor_cls=noop_preprocessor.NoOpPreprocessor),
            input_generator_train=mocks.MockInputGenerator(
                batch_size=_BATCH_SIZE),
            max_train_steps=_MAX_TRAIN_STEPS,
            model_dir=model_dir)

        before = tf.train.NewCheckpointReader(
            os.path.join(model_dir, 'model.ckpt-0'))
        after = tf.train.NewCheckpointReader(
            tf.train.latest_checkpoint(model_dir))
        for name, _ in tf.train.list_variables(model_dir):
            # Some of the batch norm moving averages are constant over training
            # on the mock data used, so they are not checked.
            if 'batch_norm' in name:
                continue
            initial_value = before.get_tensor(name)
            final_value = after.get_tensor(name)
            if 'bias' in name:
                # Frozen: the value must not have moved.
                self.assertAllClose(initial_value, final_value, atol=1e-3)
            else:
                # Trainable: the value must have been updated.
                self.assertNotAllClose(initial_value, final_value, atol=1e-3)
示例#14
0
  def test_hooks(self, mock_create_warmup_requests_numpy,
                 mock_create_serving_input_receiver_numpy_fn,
                 mock_checkpoint_init, mock_export_saved_model):
    """TD3Hooks builds exactly one hook and performs a warmup-equipped export.

    The checkpoint-init mock immediately invokes the export callback at global
    step 1. Building the hooks should therefore create warmup requests and
    export a SavedModel with those requests attached as extra assets.
    """

    def _export_immediately(export_fn, export_dir, **kwargs):
      # Stand-in for the checkpoint initializer: export right away.
      del kwargs
      export_fn(export_dir, global_step=1)
      return None

    mock_checkpoint_init.side_effect = _export_immediately

    hook_builder = td3.TD3Hooks(
        export_dir=_EXPORT_DIR,
        lagged_export_dir=_LAGGED_EXPORT_DIR,
        batch_sizes_for_export=_BATCH_SIZES_FOR_EXPORT,
        export_generator=mocks.MockExportGenerator())
    t2r_model = mocks.MockT2RModel()
    estimator = MockEstimator()

    mock_create_warmup_requests_numpy.return_value = _NUMPY_WARMUP_REQUESTS

    hooks = hook_builder.create_hooks(t2r_model=t2r_model, estimator=estimator)
    self.assertLen(hooks, 1)

    mock_create_warmup_requests_numpy.assert_called_with(
        batch_sizes=_BATCH_SIZES_FOR_EXPORT, export_dir=_MODEL_DIR)

    expected_assets = {
        "tf_serving_warmup_requests": _NUMPY_WARMUP_REQUESTS,
        tensorspec_utils.T2R_ASSETS_FILENAME: mock.ANY
    }
    mock_export_saved_model.assert_called_with(
        serving_input_receiver_fn=mock.ANY,
        export_dir_base=_EXPORT_DIR,
        assets_extra=expected_assets)

    mock_create_serving_input_receiver_numpy_fn.assert_called()
    def test_predictor_with_async_hook(self):
        """Predictor restores and serves a model exported by the async hook."""
        model_dir = self.create_tempdir().full_path
        export_dir = os.path.join(model_dir, _EXPORT_DIR)
        hook_builder = async_export_hook_builder.AsyncExportHookBuilder(
            export_dir=export_dir,
            create_export_fn=functools.partial(
                async_export_hook_builder.default_create_export_fn,
                batch_sizes_for_export=_BATCH_SIZES_FOR_EXPORT))
        generator = default_input_generator.DefaultRandomInputGenerator(
            batch_size=_BATCH_SIZE)
        t2r_model = mocks.MockT2RModel()
        train_eval.train_eval_model(t2r_model=t2r_model,
                                    input_generator_train=generator,
                                    train_hook_builders=[hook_builder],
                                    max_train_steps=_MAX_TRAIN_STEPS,
                                    model_dir=model_dir)

        predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
            export_dir=export_dir)

        # Until restore() has succeeded, each of these accessors must raise.
        unloaded_accessors = (
            predictor.get_feature_specification,
            lambda: predictor.predict({'does_not_matter': np.zeros(1)}),
            lambda: predictor.model_version,
        )
        for access in unloaded_accessors:
            with self.assertRaises(ValueError):
                access()
        self.assertEqual(predictor.global_step, -1)

        self.assertTrue(predictor.restore())
        self.assertGreater(predictor.model_version, 0)
        # NOTE: The async hook builder will export the global step.
        self.assertEqual(predictor.global_step, 3)

        spec = t2r_model.preprocessor.get_in_feature_specification(
            tf.estimator.ModeKeys.PREDICT)
        tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                      spec)
        outputs = predictor.predict(
            tensorspec_utils.make_random_numpy(spec, batch_size=_BATCH_SIZE))
        self.assertLen(outputs, 1)
        # Order-insensitive comparison; sorting the keys is unnecessary.
        self.assertCountEqual(outputs.keys(), ['logit'])
        self.assertEqual(outputs['logit'].shape, (2, 1))
示例#16
0
  def test_init_from_checkpoint_use_avg_model_params(self):
    """Warm-starting from a use_avg_model_params checkpoint keeps predictions.

    The test proceeds in three steps:

    1. Train a model with `use_avg_model_params=True`.
    2. Warm-start a second model from that checkpoint in a fresh directory and
       train it for one extra step.
    3. Check that the warm-started predictions match the original ones, while
       a randomly initialized model's predictions do not.
    """
    gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps', 100)
    gin.bind_parameter('tf.estimator.RunConfig.keep_checkpoint_max', 10)
    model_dir = self.create_tempdir().full_path
    mock_t2r_model = mocks.MockT2RModel(
        preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
        use_avg_model_params=True)

    mock_input_generator_train = mocks.MockInputGenerator(
        batch_size=_BATCH_SIZE)

    # Single-example generator shared by every prediction pass below.
    eval_generator = mocks.MockInputGenerator(batch_size=1)
    eval_generator.set_specification_from_model(
        mock_t2r_model, tf.estimator.ModeKeys.TRAIN)

    def _collect_logits(estimator):
      # Runs `estimator` over the EVAL pipeline and gathers each 'logit'.
      predictions = estimator.predict(
          input_fn=eval_generator.create_dataset_input_fn(
              mode=tf.estimator.ModeKeys.EVAL))
      return [prediction['logit'] for prediction in predictions]

    train_eval.train_eval_model(
        t2r_model=mock_t2r_model,
        input_generator_train=mock_input_generator_train,
        max_train_steps=_MAX_TRAIN_STEPS,
        model_dir=model_dir)

    initial_predictions = _collect_logits(
        tf.estimator.Estimator(
            model_fn=mock_t2r_model.model_fn,
            config=tf.estimator.RunConfig(model_dir=model_dir)))

    # The continuous training has its own directory and is warm-started from
    # `model_dir`.
    continue_model_dir = self.create_tempdir().full_path
    init_from_checkpoint_fn = functools.partial(
        abstract_model.default_init_from_checkpoint_fn, checkpoint=model_dir)
    continue_mock_t2r_model = mocks.MockT2RModel(
        preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
        init_from_checkpoint_fn=init_from_checkpoint_fn)
    # Re-initialize the model and train for one step, basically the same
    # performance as the original model.
    train_eval.train_eval_model(
        t2r_model=continue_mock_t2r_model,
        input_generator_train=mocks.MockInputGenerator(batch_size=_BATCH_SIZE),
        model_dir=continue_model_dir,
        max_train_steps=_MAX_TRAIN_STEPS + 1)

    # Predict on the warm-started checkpoint with the first model's model_fn
    # (the use_avg_model_params=True one).
    continue_predictions = _collect_logits(
        tf.estimator.Estimator(
            model_fn=mock_t2r_model.model_fn,
            config=tf.estimator.RunConfig(model_dir=continue_model_dir)))
    self.assertTrue(
        np.allclose(initial_predictions, continue_predictions, atol=1e-2))

    # An estimator without a model_dir starts from random weights and must
    # therefore disagree with the trained model.
    random_predictions = _collect_logits(
        tf.estimator.Estimator(model_fn=mock_t2r_model.model_fn))
    self.assertFalse(
        np.allclose(initial_predictions, random_predictions, atol=1e-2))
    def test_train_eval_model(self):
        """Tests that a simple model trains and exported models are valid.

        The test trains with train/eval hooks and the default exporters. It
        then checks that the numpy and tf.Example best-exporter SavedModels
        both perfectly classify the eval data. Finally, it checks that their
        loss matches the in-process estimator's loss.
        """
        gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps',
                           100)
        model_dir = self.create_tempdir().full_path
        mock_t2r_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor)

        mock_input_generator_train = mocks.MockInputGenerator(
            batch_size=_BATCH_SIZE)
        mock_input_generator_eval = mocks.MockInputGenerator(batch_size=1)
        fake_hook_builder = FakeHookBuilder()

        train_eval.train_eval_model(
            t2r_model=mock_t2r_model,
            input_generator_train=mock_input_generator_train,
            input_generator_eval=mock_input_generator_eval,
            max_train_steps=_MAX_TRAIN_STEPS,
            model_dir=model_dir,
            train_hook_builders=[fake_hook_builder],
            eval_hook_builders=[fake_hook_builder],
            eval_steps=_EVAL_STEPS,
            eval_throttle_secs=_EVAL_THROTTLE_SECS,
            create_exporters_fn=train_eval.create_default_exporters)

        self.assertTrue(fake_hook_builder.hook_mock.begin.called)

        def _exported_models(exporter_name):
            # Sorted so that the last entry is the most recent export.
            return sorted(tf.io.gfile.glob(
                os.path.join(model_dir, 'export', exporter_name, '*')))

        def _check_sign(label, predicted):
            # Perfect classification: the logit's sign must match the label.
            if label > 0:
                self.assertGreater(predicted[0], 0)
            else:
                self.assertLess(predicted[0], 0)

        # Both numpy and tf_example inference models must be exported. This
        # mock network converges nicely, so several exports qualify as "best";
        # by default the best 5 are kept and the latest one is always the best.
        numpy_model_paths = _exported_models('best_exporter_numpy')
        self.assertLen(numpy_model_paths, 5)
        tf_example_model_paths = _exported_models('best_exporter_tf_example')
        self.assertLen(tf_example_model_paths, 5)

        # Both saved models are tested here since most of the time is spent
        # training the model in the first place.

        # Reference predictions from the regular estimator with all the
        # parameters restored from model_dir.
        estimator_predict = tf.estimator.Estimator(
            model_fn=mock_t2r_model.model_fn,
            config=tf.estimator.RunConfig(model_dir=model_dir))
        prediction_ref = estimator_predict.predict(
            input_fn=mock_input_generator_eval.create_dataset_input_fn(
                mode=tf.estimator.ModeKeys.EVAL))

        # The exported graph with the numpy feed_dict interface no longer
        # depends on the model_fn or preprocessor. The latest export is loaded
        # since it had the best eval performance.
        numpy_predictor_fn = tf.contrib.predictor.from_saved_model(
            numpy_model_paths[-1])

        features, labels = mock_input_generator_eval.create_numpy_data()

        ref_error = self._compute_total_loss(
            labels, [val['logit'].flatten() for val in prediction_ref])

        numpy_predictions = []
        for feature, label in zip(features, labels):
            predicted = numpy_predictor_fn(
                {'x': feature.reshape(1, -1)})['logit'].flatten()
            numpy_predictions.append(predicted)
            _check_sign(label, predicted)
        numpy_error = self._compute_total_loss(labels, numpy_predictions)

        # Same for the tf_example feed_dict interface, again using the latest
        # (best) export.
        tf_example_predictor_fn = tf.contrib.predictor.from_saved_model(
            tf_example_model_paths[-1])
        tf_example_predictions = []
        for feature, label in zip(features, labels):
            # We have to create our serialized tf.Example proto.
            example = tf.train.Example()
            example.features.feature[
                'measured_position'].float_list.value.extend(feature)
            serialized = np.array(example.SerializeToString()).reshape(1, )
            predicted = tf_example_predictor_fn(
                {'input_example_tensor': serialized})['logit'].flatten()
            tf_example_predictions.append(predicted)
            _check_sign(label, predicted)
        tf_example_error = self._compute_total_loss(labels,
                                                    tf_example_predictions)

        np.testing.assert_almost_equal(tf_example_error, numpy_error)
        # The exported saved models must both perform equally well. Since we
        # train and eval on the same fixed dataset, the latest model's error
        # must also equal the reference estimator's error.
        np.testing.assert_almost_equal(ref_error, tf_example_error)
    def test_init_from_checkpoint_use_avg_model_params_and_weights(self):
        """Tests warm-starting from a checkpoint trained with averaged params.

        A mock model is trained with `use_avg_model_params=True`. A second mock
        model (without parameter averaging) is then initialized from that
        checkpoint via `abstract_model.default_init_from_checkpoint_fn` and
        trained in a fresh model directory. The test verifies that:
          * the restored model variables match the original checkpoint
            (ignoring moving-average, Adam and global_step variables),
          * the restored model predicts (almost) the same logits as the
            original model, and
          * a randomly initialized model does not.
        """
        gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps',
                           100)
        gin.bind_parameter('tf.estimator.RunConfig.keep_checkpoint_max', 10)
        model_dir = self.create_tempdir().full_path
        mock_t2r_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
            use_avg_model_params=True)

        mock_input_generator_train = mocks.MockInputGenerator(
            batch_size=_BATCH_SIZE)

        # A separate batch_size=1 generator that only feeds the prediction
        # comparisons below.
        mock_input_generator = mocks.MockInputGenerator(batch_size=1)
        mock_input_generator.set_specification_from_model(
            mock_t2r_model, tf.estimator.ModeKeys.TRAIN)

        def predict_logits(estimator):
            """Returns the 'logit' prediction of `estimator` per EVAL example."""
            predictions = estimator.predict(
                input_fn=mock_input_generator.create_dataset_input_fn(
                    mode=tf.estimator.ModeKeys.EVAL))
            return [prediction['logit'] for prediction in predictions]

        train_eval.train_eval_model(
            t2r_model=mock_t2r_model,
            input_generator_train=mock_input_generator_train,
            max_train_steps=_MAX_TRAIN_STEPS,
            model_dir=model_dir)

        init_checkpoint = tf.train.NewCheckpointReader(
            tf.train.latest_checkpoint(model_dir))

        # Reference predictions of the originally trained model, restored
        # from model_dir.
        initial_estimator_predict = tf.estimator.Estimator(
            model_fn=mock_t2r_model.model_fn,
            config=tf.estimator.RunConfig(model_dir=model_dir))
        initial_predictions = predict_logits(initial_estimator_predict)

        # The continuous training has its own directory.
        continue_model_dir = self.create_tempdir().full_path
        init_from_checkpoint_fn = functools.partial(
            abstract_model.default_init_from_checkpoint_fn,
            checkpoint=model_dir)
        continue_mock_t2r_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
            init_from_checkpoint_fn=init_from_checkpoint_fn)
        continue_mock_input_generator_train = mocks.MockInputGenerator(
            batch_size=_BATCH_SIZE)
        # Re-initialize the model and train for one step, basically the same
        # performance as the original model.
        train_eval.train_eval_model(
            t2r_model=continue_mock_t2r_model,
            input_generator_train=continue_mock_input_generator_train,
            model_dir=continue_model_dir,
            max_train_steps=_MAX_TRAIN_STEPS)

        continue_checkpoint = tf.train.NewCheckpointReader(
            tf.train.latest_checkpoint(continue_model_dir))

        # Variables that are not expected to carry over unchanged:
        # - ExponentialMovingAverage: replaced by the swapping saver when using
        #   use_avg_model_params.
        # - Adam: the adam optimizer values are not required.
        # - global_step: incremented by the additional training step.
        skipped_substrings = ('ExponentialMovingAverage', 'Adam',
                              'global_step')
        for tensor_name, _ in tf.train.list_variables(model_dir):
            if any(sub in tensor_name for sub in skipped_substrings):
                continue
            self.assertAllClose(init_checkpoint.get_tensor(tensor_name),
                                continue_checkpoint.get_tensor(tensor_name),
                                atol=1e-3,
                                msg=tensor_name)

        # Predictions of the warm-started model from continue_model_dir.
        # NOTE(review): this reuses the averaging model's model_fn rather than
        # continue_mock_t2r_model.model_fn; confirm that is intended.
        continue_estimator_predict = tf.estimator.Estimator(
            model_fn=mock_t2r_model.model_fn,
            config=tf.estimator.RunConfig(model_dir=continue_model_dir))
        continue_predictions = predict_logits(continue_estimator_predict)

        # rtol=1e-5 keeps the tolerance identical to np.allclose's default.
        self.assertAllClose(initial_predictions, continue_predictions,
                            rtol=1e-5, atol=1e-2)

        # A randomly initialized estimator (no model_dir, so nothing is
        # restored) must not reproduce the trained predictions.
        random_estimator_predict = tf.estimator.Estimator(
            model_fn=mock_t2r_model.model_fn)
        random_predictions = predict_logits(random_estimator_predict)
        self.assertFalse(
            np.allclose(initial_predictions, random_predictions, atol=1e-2))