Exemplo n.º 1
0
    def test_init_from_checkpoint_global_step(self):
        """Checks that init-from-checkpoint also restores the global step.

        A first model is trained in its own directory. A second model is then
        initialized from that directory and trained in a fresh directory with
        a step budget that is 100 larger; the number of checkpoint meta files
        written by each run is asserted.
        """
        gin.bind_parameter(
            'tf.estimator.RunConfig.save_checkpoints_steps', 100)
        gin.bind_parameter(
            'tf.estimator.RunConfig.keep_checkpoint_max', 10)

        def saved_checkpoint_metas(directory):
            """Lists the checkpoint meta files matching `model*.meta`."""
            return tf.io.gfile.glob(os.path.join(directory, 'model*.meta'))

        # Arguments that are identical for the initial and the continued run.
        common_train_kwargs = dict(
            eval_steps=_EVAL_STEPS,
            eval_throttle_secs=_EVAL_THROTTLE_SECS,
            create_exporters_fn=train_eval.create_default_exporters)

        initial_dir = self.create_tempdir().full_path
        initial_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor)
        initial_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
        train_eval.train_eval_model(
            t2r_model=initial_model,
            input_generator_train=initial_generator,
            max_train_steps=_MAX_TRAIN_STEPS,
            model_dir=initial_dir,
            **common_train_kwargs)
        # The model trains for 1000 steps and saves a checkpoint each 100
        # steps and keeps 10 -> len == 10.
        self.assertLen(saved_checkpoint_metas(initial_dir), 10)

        # The continued training writes into a directory of its own and is
        # initialized from the directory of the first run.
        resumed_dir = self.create_tempdir().full_path
        resumed_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
            init_from_checkpoint_fn=functools.partial(
                abstract_model.default_init_from_checkpoint_fn,
                checkpoint=initial_dir))
        resumed_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
        train_eval.train_eval_model(
            t2r_model=resumed_model,
            input_generator_train=resumed_generator,
            model_dir=resumed_dir,
            max_train_steps=_MAX_TRAIN_STEPS + 100,
            **common_train_kwargs)
        # If the model was successfully restored including the global step,
        # only 1 additional checkpoint to the init one should be created
        # -> len == 2.
        self.assertLen(saved_checkpoint_metas(resumed_dir), 2)
    def test_create_serving_input_receiver_numpy(self):
        """Exports with the numpy receiver and checks prediction parity.

        The reference model is exported as a SavedModel using the numpy
        serving input receiver. The loaded SavedModel is then fed one feature
        row at a time and its logits are compared against the reference
        predictions.
        """
        reference = self._train_and_eval_reference_model('numpy')
        model_dir, mock_t2r_model, prediction_ref = reference

        export_generator = default_export_generator.DefaultExportGenerator()
        export_generator.set_specification_from_model(mock_t2r_model)

        # Estimator used solely to export the trained model for serving.
        export_estimator = tf.estimator.Estimator(
            model_fn=mock_t2r_model.model_fn,
            config=tf.estimator.RunConfig(model_dir=model_dir))

        receiver_fn = export_generator.create_serving_input_receiver_numpy_fn()
        latest_checkpoint = tf.train.latest_checkpoint(model_dir)
        saved_model_path = export_estimator.export_saved_model(
            export_dir_base=model_dir,
            serving_input_receiver_fn=receiver_fn,
            checkpoint_path=latest_checkpoint)

        # Load the exported model, run prediction and assert it matches the
        # predictions made before exporting.
        saved_model_predict = tf.contrib.predictor.from_saved_model(
            saved_model_path)
        data_generator = mocks.MockInputGenerator(batch_size=BATCH_SIZE)
        features, labels = data_generator.create_numpy_data()
        for pos, value in enumerate(prediction_ref):
            single_row = features[pos, :].reshape(1, -1)
            actual = saved_model_predict({'x': single_row})['logit'].flatten()
            predicted = value['logit'].flatten()
            np.testing.assert_almost_equal(
                actual=actual, desired=predicted, decimal=4)
            # The sign of the reference logit has to agree with the label.
            assert_sign = (
                self.assertGreater if labels[pos] > 0 else self.assertLess)
            assert_sign(predicted[0], 0)
    def _train_and_eval_reference_model(self, path):
        """Trains a mock model and produces reference predictions for it.

        Args:
          path: Not used inside this method.

        Returns:
          A tuple `(model_dir, mock_t2r_model, prediction_ref)`: the directory
          the training wrote into, the trained mock model, and the value
          returned by `Estimator.predict` for the mock input generator in
          EVAL mode.
        """
        model_dir = self.create_tempdir().full_path
        mock_t2r_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor)

        # A TPU estimator is used for training; whether a TPU is actually
        # used is decided by the model itself.
        tpu_run_config = tf.contrib.tpu.RunConfig(model_dir=model_dir)
        training_estimator = tf.contrib.tpu.TPUEstimator(
            model_fn=mock_t2r_model.model_fn,
            use_tpu=mock_t2r_model.is_device_tpu,
            config=tpu_run_config,
            train_batch_size=BATCH_SIZE,
            eval_batch_size=BATCH_SIZE)

        input_generator = mocks.MockInputGenerator(batch_size=BATCH_SIZE)
        input_generator.set_specification_from_model(
            mock_t2r_model, tf.estimator.ModeKeys.TRAIN)

        # Optimize the network.
        train_input_fn = input_generator.create_dataset_input_fn(
            mode=tf.estimator.ModeKeys.TRAIN)
        training_estimator.train(input_fn=train_input_fn, max_steps=MAX_STEPS)

        # A plain estimator on the same model_dir provides the reference
        # predictions that exported models are compared against.
        reference_estimator = tf.estimator.Estimator(
            model_fn=mock_t2r_model.model_fn,
            config=tf.estimator.RunConfig(model_dir=model_dir))
        eval_input_fn = input_generator.create_dataset_input_fn(
            mode=tf.estimator.ModeKeys.EVAL)
        prediction_ref = reference_estimator.predict(input_fn=eval_input_fn)

        return model_dir, mock_t2r_model, prediction_ref
Exemplo n.º 4
0
    def test_with_mock_training(self):
        """Trains a TPU-configured mock model with the async export hook.

        After training, the model directory and the export directory must be
        non-empty, every entry of the export directory must be non-empty, and
        an `ExportedSavedModelPredictor` must be able to restore from it.
        """
        model_dir = self.create_tempdir().full_path
        tpu_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
            device_type='tpu',
            use_avg_model_params=True)
        input_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)

        export_dir = os.path.join(model_dir, _EXPORT_DIR)
        create_export_fn = async_export_hook_builder.default_create_export_fn
        export_hook_builder = async_export_hook_builder.AsyncExportHookBuilder(
            export_dir=export_dir, create_export_fn=create_export_fn)

        for gin_binding in (
                'tf.contrib.tpu.TPUConfig.iterations_per_loop=1',
                'tf.estimator.RunConfig.save_checkpoints_steps=1'):
            gin.parse_config(gin_binding)

        # Optimize the network with the export hook attached.
        train_eval.train_eval_model(
            t2r_model=tpu_model,
            input_generator_train=input_generator,
            train_hook_builders=[export_hook_builder],
            model_dir=model_dir,
            max_train_steps=_MAX_STEPS)

        self.assertNotEmpty(tf.io.gfile.listdir(model_dir))
        self.assertNotEmpty(tf.io.gfile.listdir(export_dir))
        # Every exported entry has to contain files of its own.
        for entry in tf.io.gfile.listdir(export_dir):
            entry_path = os.path.join(export_dir, entry)
            self.assertNotEmpty(tf.io.gfile.listdir(entry_path))

        saved_model_predictor = (
            exported_savedmodel_predictor.ExportedSavedModelPredictor(
                export_dir=export_dir))
        self.assertTrue(saved_model_predictor.restore())
Exemplo n.º 5
0
  def test_with_mock_training(self):
    """Trains a CPU mock model with the async export hook and checks exports.

    The export function is the default one with `batch_sizes_for_export`
    bound to `_BATCH_SIZES_FOR_EXPORT`. After training, the model directory
    and the export directory must be non-empty, every entry of the export
    directory must be non-empty, and an `ExportedSavedModelPredictor` must be
    able to restore from the export directory.
    """
    model_dir = self.create_tempdir().full_path
    mock_t2r_model = mocks.MockT2RModel(
        preprocessor_cls=noop_preprocessor.NoOpPreprocessor, device_type='cpu')

    mock_input_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
    export_dir = os.path.join(model_dir, _EXPORT_DIR)
    # Built exactly once; it is only needed to construct the hook builder.
    default_create_export_fn = functools.partial(
        async_export_hook_builder.default_create_export_fn,
        batch_sizes_for_export=_BATCH_SIZES_FOR_EXPORT)
    hook_builder = async_export_hook_builder.AsyncExportHookBuilder(
        export_dir=export_dir, create_export_fn=default_create_export_fn)

    # We optimize our network.
    train_eval.train_eval_model(
        t2r_model=mock_t2r_model,
        input_generator_train=mock_input_generator,
        train_hook_builders=[hook_builder],
        model_dir=model_dir,
        max_train_steps=_MAX_STEPS)
    self.assertNotEmpty(tf.io.gfile.listdir(model_dir))
    self.assertNotEmpty(tf.io.gfile.listdir(export_dir))
    # Every exported entry has to contain files of its own.
    for exported_model_dir in tf.io.gfile.listdir(export_dir):
      self.assertNotEmpty(
          tf.io.gfile.listdir(os.path.join(export_dir, exported_model_dir)))
    predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
        export_dir=export_dir)
    self.assertTrue(predictor.restore())
Exemplo n.º 6
0
  def test_set_preprocess_fn(self):
    """Checks that `set_preprocess_fn` raises for both callables tried here.

    Both the raw `preprocess` method and a partial of it that only binds
    `labels` are expected to raise a ValueError.
    """
    input_generator = mocks.MockInputGenerator(batch_size=BATCH_SIZE)
    noop = noop_preprocessor.NoOpPreprocessor()
    rejected_fns = (
        # A function with `mode` not already filled in either by a closure
        # or functools.partial.
        noop.preprocess,
        # A partial function where `mode` is still not abstracted away.
        functools.partial(noop.preprocess, labels=None),
    )
    for candidate_fn in rejected_fns:
      with self.assertRaises(ValueError):
        input_generator.set_preprocess_fn(candidate_fn)
    def test_create_serving_input_receiver_tf_example(self, multi_dataset):
        """Exports with the tf.Example receiver and checks prediction parity.

        Args:
          multi_dataset: Forwarded to the reference-model helper. When truthy,
            the serialized example is fed under the two dataset-specific
            input keys instead of the single `input_example_tensor` key.
        """
        reference = self._train_and_eval_reference_model(
            'tf_example', multi_dataset=multi_dataset)
        model_dir, mock_t2r_model, prediction_ref = reference

        # Estimator used solely to export the trained model for serving.
        export_estimator = tf.estimator.Estimator(
            model_fn=mock_t2r_model.model_fn,
            config=tf.estimator.RunConfig(model_dir=model_dir))

        export_generator = default_export_generator.DefaultExportGenerator()
        export_generator.set_specification_from_model(mock_t2r_model)
        receiver_fn = (
            export_generator.create_serving_input_receiver_tf_example_fn())
        latest_checkpoint = tf.train.latest_checkpoint(model_dir)
        saved_model_path = export_estimator.export_saved_model(
            export_dir_base=model_dir,
            serving_input_receiver_fn=receiver_fn,
            checkpoint_path=latest_checkpoint)

        # The exported graph is loaded on its own; there are no dependencies
        # on the model_fn or preprocessor anymore.
        saved_model_predict = tf.contrib.predictor.from_saved_model(
            saved_model_path)
        data_generator = mocks.MockInputGenerator(batch_size=BATCH_SIZE)
        features, labels = data_generator.create_numpy_data()

        # Names under which the serialized example is fed to the predictor.
        if multi_dataset:
            input_keys = ('input_example_dataset1', 'input_example_dataset2')
        else:
            input_keys = ('input_example_tensor',)

        for pos, value in enumerate(prediction_ref):
            # Build the serialized tf.Example proto for this feature row.
            example = tf.train.Example()
            measured = example.features.feature['measured_position']
            measured.float_list.value.extend(features[pos])
            serialized_example = np.array(
                example.SerializeToString()).reshape(1)
            feed_dict = {key: serialized_example for key in input_keys}

            actual = saved_model_predict(feed_dict)['logit'].flatten()
            predicted = value['logit'].flatten()
            np.testing.assert_almost_equal(
                actual=actual, desired=predicted, decimal=4)
            # The sign of the reference logit has to agree with the label.
            assert_sign = (
                self.assertGreater if labels[pos] > 0 else self.assertLess)
            assert_sign(predicted[0], 0)
Exemplo n.º 8
0
    def test_freezing_some_variables(self):
        """Tests that training can be frozen for parts of the network.

        A filter that only accepts non-bias variables is bound to
        `create_train_op.filter_trainables_fn`. After training, the first and
        the latest checkpoint are compared: bias variables must be unchanged
        while all other (non batch-norm) variables must have moved.
        """
        def freeze_biases(var):
            # Only variables without 'bias' in their name pass the filter.
            return 'bias' not in var.name

        gin.bind_parameter(
            'tf.estimator.RunConfig.save_checkpoints_steps', 100)
        gin.bind_parameter(
            'create_train_op.filter_trainables_fn', freeze_biases)

        model_dir = self.create_tempdir().full_path
        model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor)
        train_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)

        train_eval.train_eval_model(
            t2r_model=model,
            input_generator_train=train_generator,
            max_train_steps=_MAX_TRAIN_STEPS,
            model_dir=model_dir)

        first_ckpt_path = os.path.join(model_dir, 'model.ckpt-0')
        first_reader = tf.train.NewCheckpointReader(first_ckpt_path)
        final_reader = tf.train.NewCheckpointReader(
            tf.train.latest_checkpoint(model_dir))

        for var_name, _ in tf.train.list_variables(model_dir):
            # Some of the batch norm moving averages are constant over
            # training on the mock data used, so they are not checked.
            if 'batch_norm' in var_name:
                continue
            before = first_reader.get_tensor(var_name)
            after = final_reader.get_tensor(var_name)
            if 'bias' in var_name:
                # Frozen: should not update.
                self.assertAllClose(before, after, atol=1e-3)
            else:
                # Trainable: should update.
                self.assertNotAllClose(before, after, atol=1e-3)
Exemplo n.º 9
0
    def test_train_eval_model(self):
        """Tests that a simple model trains and exported models are valid.

        The test trains and evaluates a mock model with the default exporters,
        checks that the hook builder's hook was started, checks the number of
        exported numpy and tf_example models, and finally compares the total
        loss of the reference estimator with the losses of the latest
        exported numpy and tf_example SavedModels.
        """
        gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps',
                           100)
        model_dir = self.create_tempdir().full_path
        mock_t2r_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor)

        mock_input_generator_train = mocks.MockInputGenerator(
            batch_size=_BATCH_SIZE)
        mock_input_generator_eval = mocks.MockInputGenerator(batch_size=1)
        # The same builder instance is used for both train and eval hooks.
        fake_hook_builder = FakeHookBuilder()

        train_eval.train_eval_model(
            t2r_model=mock_t2r_model,
            input_generator_train=mock_input_generator_train,
            input_generator_eval=mock_input_generator_eval,
            max_train_steps=_MAX_TRAIN_STEPS,
            model_dir=model_dir,
            train_hook_builders=[fake_hook_builder],
            eval_hook_builders=[fake_hook_builder],
            eval_steps=_EVAL_STEPS,
            eval_throttle_secs=_EVAL_THROTTLE_SECS,
            create_exporters_fn=train_eval.create_default_exporters)

        self.assertTrue(fake_hook_builder.hook_mock.begin.called)

        # We ensure that both numpy and tf_example inference models are exported.
        best_exporter_numpy_path = os.path.join(model_dir, 'export',
                                                'best_exporter_numpy', '*')
        numpy_model_paths = sorted(tf.io.gfile.glob(best_exporter_numpy_path))
        # This mock network converges nicely, which is why we have several best
        # models. By default we keep the best 5 and the latest one is always
        # the best.
        self.assertLen(numpy_model_paths, 5)

        best_exporter_tf_example_path = os.path.join(
            model_dir, 'export', 'best_exporter_tf_example', '*')

        tf_example_model_paths = sorted(
            tf.io.gfile.glob(best_exporter_tf_example_path))
        # The same reasoning as for the numpy exporter applies: the best 5
        # models are kept and the latest one is always the best.
        self.assertLen(tf_example_model_paths, 5)

        # We test both saved models within one test since the bulk of the time
        # is spent training the model in the first place.

        # Reference predictions come from a plain estimator that reads the
        # checkpoints in model_dir; the exported models are compared to these.
        estimator_predict = tf.estimator.Estimator(
            model_fn=mock_t2r_model.model_fn,
            config=tf.estimator.RunConfig(model_dir=model_dir))

        prediction_ref = estimator_predict.predict(
            input_fn=mock_input_generator_eval.create_dataset_input_fn(
                mode=tf.estimator.ModeKeys.EVAL))

        # Now we can load our exported estimator graph with the numpy feed_dict
        # interface, there are no dependencies on the model_fn or preprocessor
        # anymore.
        # We load the latest model since it had the best eval performance.
        numpy_predictor_fn = tf.contrib.predictor.from_saved_model(
            numpy_model_paths[-1])

        features, labels = mock_input_generator_eval.create_numpy_data()

        # prediction_ref is iterated here, when the reference loss is computed.
        ref_error = self._compute_total_loss(
            labels, [val['logit'].flatten() for val in prediction_ref])

        numpy_predictions = []
        for feature, label in zip(features, labels):
            predicted = numpy_predictor_fn({'x': feature.reshape(1, -1)
                                            })['logit'].flatten()
            numpy_predictions.append(predicted)
            # This ensures that we actually achieve perfect classification.
            if label > 0:
                self.assertGreater(predicted[0], 0)
            else:
                self.assertLess(predicted[0], 0)
        numpy_error = self._compute_total_loss(labels, numpy_predictions)

        # Now we can load our exported estimator graph with the tf_example feed_dict
        # interface, there are no dependencies on the model_fn or preprocessor
        # anymore.
        # We load the latest model since it had the best eval performance.
        tf_example_predictor_fn = tf.contrib.predictor.from_saved_model(
            tf_example_model_paths[-1])
        tf_example_predictions = []
        for feature, label in zip(features, labels):
            # We have to create our serialized tf.Example proto.
            example = tf.train.Example()
            example.features.feature[
                'measured_position'].float_list.value.extend(feature)
            feed_dict = {
                'input_example_tensor':
                np.array(example.SerializeToString()).reshape(1, )
            }
            predicted = tf_example_predictor_fn(feed_dict)['logit'].flatten()
            tf_example_predictions.append(predicted)
            # This ensures that we actually achieve perfect classification.
            if label > 0:
                self.assertGreater(predicted[0], 0)
            else:
                self.assertLess(predicted[0], 0)
        tf_example_error = self._compute_total_loss(labels,
                                                    tf_example_predictions)

        np.testing.assert_almost_equal(tf_example_error, numpy_error)
        # The exported saved models both have to have the same performance, and
        # since we train and eval on the same fixed dataset the latest and
        # greatest model error should also be the best.
        np.testing.assert_almost_equal(ref_error, tf_example_error)
Exemplo n.º 10
0
    def test_init_from_checkpoint_use_avg_model_params_and_weights(self):
        """Checks that initializing from a checkpoint carries over weights.

        A model using averaged parameters is trained. A second model is then
        initialized from that model's checkpoint directory and trained in its
        own directory. The test compares the stored variables of both runs
        and their predictions, and checks that a freshly initialized
        estimator predicts differently.
        """
        gin.bind_parameter(
            'tf.estimator.RunConfig.save_checkpoints_steps', 100)
        gin.bind_parameter(
            'tf.estimator.RunConfig.keep_checkpoint_max', 10)
        model_dir = self.create_tempdir().full_path
        mock_t2r_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
            use_avg_model_params=True)

        train_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)

        predict_generator = mocks.MockInputGenerator(batch_size=1)
        predict_generator.set_specification_from_model(
            mock_t2r_model, tf.estimator.ModeKeys.TRAIN)

        def predict_logits(estimator):
            """Collects the 'logit' output of every EVAL-mode prediction."""
            eval_input_fn = predict_generator.create_dataset_input_fn(
                mode=tf.estimator.ModeKeys.EVAL)
            return [
                prediction['logit']
                for prediction in estimator.predict(input_fn=eval_input_fn)
            ]

        train_eval.train_eval_model(
            t2r_model=mock_t2r_model,
            input_generator_train=train_generator,
            max_train_steps=_MAX_TRAIN_STEPS,
            model_dir=model_dir)

        init_reader = tf.train.NewCheckpointReader(
            tf.train.latest_checkpoint(model_dir))

        # Predictions of the initially trained model serve as the baseline.
        initial_predictions = predict_logits(
            tf.estimator.Estimator(
                model_fn=mock_t2r_model.model_fn,
                config=tf.estimator.RunConfig(model_dir=model_dir)))

        # The continued training has its own directory.
        continue_model_dir = self.create_tempdir().full_path
        continue_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
            init_from_checkpoint_fn=functools.partial(
                abstract_model.default_init_from_checkpoint_fn,
                checkpoint=model_dir))
        continue_train_generator = mocks.MockInputGenerator(
            batch_size=_BATCH_SIZE)
        # Re-initialize from the first run and train with the same step
        # budget; the result should basically perform like the original model.
        train_eval.train_eval_model(
            t2r_model=continue_model,
            input_generator_train=continue_train_generator,
            model_dir=continue_model_dir,
            max_train_steps=_MAX_TRAIN_STEPS)

        continue_reader = tf.train.NewCheckpointReader(
            tf.train.latest_checkpoint(continue_model_dir))

        # Variables whose name contains one of these markers are not compared:
        # - 'ExponentialMovingAverage': replaced by the swapping saver when
        #   use_avg_model_params is used.
        # - 'Adam': the optimizer values are not required.
        # - 'global_step': allowed to differ between the two runs.
        ignored_markers = ('ExponentialMovingAverage', 'Adam', 'global_step')
        for tensor_name, _ in tf.train.list_variables(model_dir):
            if any(marker in tensor_name for marker in ignored_markers):
                continue
            self.assertAllClose(
                init_reader.get_tensor(tensor_name),
                continue_reader.get_tensor(tensor_name),
                atol=1e-3)

        # Predictions made from the continued run's checkpoints have to match
        # the baseline.
        continue_predictions = predict_logits(
            tf.estimator.Estimator(
                model_fn=mock_t2r_model.model_fn,
                config=tf.estimator.RunConfig(model_dir=continue_model_dir)))
        self.assertTrue(
            np.allclose(initial_predictions, continue_predictions, atol=1e-2))

        # An estimator without a configured model_dir stands in for a randomly
        # initialized model; its predictions must differ from the baseline.
        random_predictions = predict_logits(
            tf.estimator.Estimator(model_fn=mock_t2r_model.model_fn))
        self.assertFalse(
            np.allclose(initial_predictions, random_predictions, atol=1e-2))
Exemplo n.º 11
0
  def test_init_from_checkpoint_use_avg_model_params(self):
    """Checks predictions after initializing from an averaged-params model.

    A model using averaged parameters is trained. A second model is then
    initialized from that model's checkpoint directory and trained in its own
    directory with a step budget that is one larger. Predictions from both
    runs must agree, and a freshly initialized estimator must predict
    differently.
    """
    gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps', 100)
    gin.bind_parameter('tf.estimator.RunConfig.keep_checkpoint_max', 10)
    model_dir = self.create_tempdir().full_path
    mock_t2r_model = mocks.MockT2RModel(
        preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
        use_avg_model_params=True)

    train_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)

    predict_generator = mocks.MockInputGenerator(batch_size=1)
    predict_generator.set_specification_from_model(
        mock_t2r_model, tf.estimator.ModeKeys.TRAIN)

    def predict_logits(estimator):
      """Collects the 'logit' output of every EVAL-mode prediction."""
      eval_input_fn = predict_generator.create_dataset_input_fn(
          mode=tf.estimator.ModeKeys.EVAL)
      return [
          prediction['logit']
          for prediction in estimator.predict(input_fn=eval_input_fn)
      ]

    train_eval.train_eval_model(
        t2r_model=mock_t2r_model,
        input_generator_train=train_generator,
        max_train_steps=_MAX_TRAIN_STEPS,
        model_dir=model_dir)

    # Predictions of the initially trained model serve as the baseline.
    initial_predictions = predict_logits(
        tf.estimator.Estimator(
            model_fn=mock_t2r_model.model_fn,
            config=tf.estimator.RunConfig(model_dir=model_dir)))

    # The continued training has its own directory.
    continue_model_dir = self.create_tempdir().full_path
    continue_model = mocks.MockT2RModel(
        preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
        init_from_checkpoint_fn=functools.partial(
            abstract_model.default_init_from_checkpoint_fn,
            checkpoint=model_dir))
    continue_train_generator = mocks.MockInputGenerator(
        batch_size=_BATCH_SIZE)
    # Re-initialize from the first run and train with a step budget that is
    # one larger; the result should basically perform like the original model.
    train_eval.train_eval_model(
        t2r_model=continue_model,
        input_generator_train=continue_train_generator,
        model_dir=continue_model_dir,
        max_train_steps=_MAX_TRAIN_STEPS + 1)

    # Predictions made from the continued run's checkpoints have to match the
    # baseline.
    continue_predictions = predict_logits(
        tf.estimator.Estimator(
            model_fn=mock_t2r_model.model_fn,
            config=tf.estimator.RunConfig(model_dir=continue_model_dir)))
    self.assertTrue(
        np.allclose(initial_predictions, continue_predictions, atol=1e-2))

    # An estimator without a configured model_dir stands in for a randomly
    # initialized model; its predictions must differ from the baseline.
    random_predictions = predict_logits(
        tf.estimator.Estimator(model_fn=mock_t2r_model.model_fn))
    self.assertFalse(
        np.allclose(initial_predictions, random_predictions, atol=1e-2))