Example No. 1
    def test_with_mock_training(self):
        model_dir = self.create_tempdir().full_path
        mock_t2r_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
            device_type='tpu',
            use_avg_model_params=True)

        mock_input_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
        export_dir = os.path.join(model_dir, _EXPORT_DIR)
        hook_builder = async_export_hook_builder.AsyncExportHookBuilder(
            export_dir=export_dir,
            create_export_fn=async_export_hook_builder.default_create_export_fn
        )

        gin.parse_config('tf.contrib.tpu.TPUConfig.iterations_per_loop=1')
        gin.parse_config('tf.estimator.RunConfig.save_checkpoints_steps=1')

        # We optimize our network.
        train_eval.train_eval_model(t2r_model=mock_t2r_model,
                                    input_generator_train=mock_input_generator,
                                    train_hook_builders=[hook_builder],
                                    model_dir=model_dir,
                                    max_train_steps=_MAX_STEPS)
        self.assertNotEmpty(tf.io.gfile.listdir(model_dir))
        self.assertNotEmpty(tf.io.gfile.listdir(export_dir))
        for exported_model_dir in tf.io.gfile.listdir(export_dir):
            self.assertNotEmpty(
                tf.io.gfile.listdir(
                    os.path.join(export_dir, exported_model_dir)))
        predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
            export_dir=export_dir)
        self.assertTrue(predictor.restore())
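These snippets omit their imports. A reconstruction of what they rely on is below; the tensor2robot module paths are assumptions based on the public repo layout, not copied from the original test files.

import functools
import os
from unittest import mock

import gin
import numpy as np
import tensorflow.compat.v1 as tf

from tensor2robot.hooks import async_export_hook_builder
from tensor2robot.input_generators import default_input_generator
from tensor2robot.predictors import checkpoint_predictor
from tensor2robot.predictors import exported_savedmodel_predictor
from tensor2robot.preprocessors import noop_preprocessor
from tensor2robot.utils import mocks
from tensor2robot.utils import tensorspec_utils
from tensor2robot.utils import train_eval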
Example No. 2
 def recordio_train(self, module_name, model_name, file_patterns,
                    **module_kwargs):
     """Trains the model with a RecordIO dataset for a few steps."""
     tf_model = getattr(module_name, model_name)(**module_kwargs)
     params = self._get_params(
         model_dir=self._test_case.create_tempdir().full_path,
         **module_kwargs)
     input_generator = default_input_generator.DefaultRecordInputGenerator(
         file_patterns, batch_size=params['batch_size'])
     initialize_system = tpu.initialize_system
     with mock.patch.object(tpu, 'initialize_system',
                            autospec=True) as mock_init:
         mock_init.side_effect = initialize_system
         train_eval.train_eval_model(
             t2r_model=tf_model,
             input_generator_train=input_generator,
             max_train_steps=params['max_train_steps'],
             model_dir=params['model_dir'],
             use_tpu_wrapper=params['use_tpu_wrapper'])
         if self._use_tpu:
             mock_init.assert_called()
         train_eval_test_utils.assert_output_files(
             test_case=self._test_case,
             model_dir=params['model_dir'],
             expected_output_filename_patterns=train_eval_test_utils.
             DEFAULT_TRAIN_FILENAME_PATTERNS)
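The patch above wraps the real tpu.initialize_system: autospec=True preserves the original signature, and side_effect forwards every call to the saved real function, so behavior is unchanged while the mock still records the calls. A minimal self-contained sketch of the same idiom, using math.sqrt as a stand-in:

import math
from unittest import mock

real_sqrt = math.sqrt
with mock.patch.object(math, 'sqrt', autospec=True) as mock_sqrt:
    # Forward to the real function so results are unchanged; the mock
    # still records every call for later assertions.
    mock_sqrt.side_effect = real_sqrt
    assert math.sqrt(9.0) == 3.0
    mock_sqrt.assert_called_once_with(9.0)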
Example No. 3
    def test_predictor(self):
        input_generator = default_input_generator.DefaultRandomInputGenerator(
            batch_size=_BATCH_SIZE)
        model_dir = self.create_tempdir().full_path
        mock_model = mocks.MockT2RModel()
        train_eval.train_eval_model(t2r_model=mock_model,
                                    input_generator_train=input_generator,
                                    max_train_steps=_MAX_TRAIN_STEPS,
                                    model_dir=model_dir)

        predictor = checkpoint_predictor.CheckpointPredictor(
            t2r_model=mock_model, checkpoint_dir=model_dir, use_gpu=False)
        with self.assertRaises(ValueError):
            predictor.predict({'does_not_matter': np.zeros(1)})
        self.assertEqual(predictor.model_version, -1)
        self.assertEqual(predictor.global_step, -1)
        self.assertTrue(predictor.restore())
        self.assertGreater(predictor.model_version, 0)
        self.assertEqual(predictor.global_step, 3)
        ref_feature_spec = mock_model.preprocessor.get_in_feature_specification(
            tf.estimator.ModeKeys.PREDICT)
        tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                      ref_feature_spec)
        features = tensorspec_utils.make_random_numpy(ref_feature_spec,
                                                      batch_size=_BATCH_SIZE)
        predictions = predictor.predict(features)
        self.assertLen(predictions, 1)
        self.assertCountEqual(sorted(predictions.keys()), ['logit'])
        self.assertEqual(predictions['logit'].shape, (2, 1))
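The assertions above pin down this test file's constants: global_step == 3 after training to _MAX_TRAIN_STEPS implies _MAX_TRAIN_STEPS == 3, and the (2, 1) logit shape implies _BATCH_SIZE == 2. A sketch of the inferred values (this file only; other examples on this page imply different settings):

_BATCH_SIZE = 2        # predictions['logit'].shape == (2, 1) above
_MAX_TRAIN_STEPS = 3   # predictor.global_step == 3 after training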
Example No. 4
 def test_predictor_load_final_model(self):
     input_generator = default_input_generator.DefaultRandomInputGenerator(
         batch_size=_BATCH_SIZE)
     model_dir = self.create_tempdir().full_path
     mock_model = mocks.MockT2RModel()
     train_eval.train_eval_model(
         t2r_model=mock_model,
         input_generator_train=input_generator,
         input_generator_eval=input_generator,
         max_train_steps=_MAX_TRAIN_STEPS,
         eval_steps=_MAX_EVAL_STEPS,
         model_dir=model_dir,
         create_exporters_fn=train_eval.create_default_exporters)
     export_dir = os.path.join(model_dir, 'export', 'latest_exporter_numpy')
     final_export_dir = sorted(
         tf.io.gfile.glob(os.path.join(export_dir, '*')), reverse=True)[0]
     predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
         export_dir=final_export_dir)
     predictor.restore()
     self.assertGreater(predictor.model_version, 0)
     self.assertEqual(predictor.global_step, 3)
     ref_feature_spec = mock_model.preprocessor.get_in_feature_specification(
         tf.estimator.ModeKeys.PREDICT)
     tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                   ref_feature_spec)
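The sorted(..., reverse=True)[0] selection works because Estimator exporters name export directories with a Unix timestamp, so the lexicographically largest entry is the newest. Under that same assumption, max() expresses the intent more directly:

# Equivalent: pick the newest timestamped export directory.
final_export_dir = max(tf.io.gfile.glob(os.path.join(export_dir, '*')))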
Example No. 5
    def test_predictor_init_with_default_exporter(self, restore_model_option):
        input_generator = default_input_generator.DefaultRandomInputGenerator(
            batch_size=_BATCH_SIZE)
        model_dir = self.create_tempdir().full_path
        mock_model = mocks.MockT2RModel()
        train_eval.train_eval_model(
            t2r_model=mock_model,
            input_generator_train=input_generator,
            input_generator_eval=input_generator,
            max_train_steps=_MAX_TRAIN_STEPS,
            eval_steps=_MAX_EVAL_STEPS,
            model_dir=model_dir,
            create_exporters_fn=train_eval.create_default_exporters)

        predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
            export_dir=os.path.join(model_dir, 'export',
                                    'latest_exporter_numpy'),
            restore_model_option=restore_model_option)
        if (restore_model_option ==
            exported_savedmodel_predictor.RestoreOptions.RESTORE_SYNCHRONOUSLY):
            predictor.restore()
        self.assertGreater(predictor.model_version, 0)
        self.assertEqual(predictor.global_step, 3)
        ref_feature_spec = mock_model.preprocessor.get_in_feature_specification(
            tf.estimator.ModeKeys.PREDICT)
        tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                      ref_feature_spec)
Example No. 6
  def test_maml_model(self, num_inner_loop_steps):
    model_dir = os.path.join(FLAGS.test_tmpdir, str(num_inner_loop_steps))
    gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps',
                       _MAX_STEPS // 2)
    if tf.io.gfile.exists(model_dir):
      tf.io.gfile.rmtree(model_dir)

    mock_base_model = mocks.MockT2RModel(
        preprocessor_cls=noop_preprocessor.NoOpPreprocessor)

    mock_tf_model = MockMAMLModel(
        base_model=mock_base_model, num_inner_loop_steps=num_inner_loop_steps)

    # Note: by choice, we use the same number of conditioning samples for
    # inference during training, and change the model for eval/inference to
    # produce only one output sample.
    mock_input_generator_train = MockMetaInputGenerator(
        batch_size=_BATCH_SIZE,
        num_condition_samples_per_task=_NUM_CONDITION_SAMPLES_PER_TASK,
        num_inference_samples_per_task=_NUM_CONDITION_SAMPLES_PER_TASK)
    mock_input_generator_train.set_specification_from_model(
        mock_tf_model, mode=tf.estimator.ModeKeys.TRAIN)

    mock_input_generator_eval = MockMetaInputGenerator(
        batch_size=_BATCH_SIZE,
        num_condition_samples_per_task=_NUM_CONDITION_SAMPLES_PER_TASK,
        num_inference_samples_per_task=1)
    mock_input_generator_eval.set_specification_from_model(
        mock_tf_model, mode=tf.estimator.ModeKeys.TRAIN)
    mock_export_generator = MockMetaExportGenerator(
        num_condition_samples_per_task=_NUM_CONDITION_SAMPLES_PER_TASK,
        num_inference_samples_per_task=1)

    train_eval.train_eval_model(
        t2r_model=mock_tf_model,
        input_generator_train=mock_input_generator_train,
        input_generator_eval=mock_input_generator_eval,
        max_train_steps=_MAX_STEPS,
        model_dir=model_dir,
        export_generator=mock_export_generator,
        create_exporters_fn=train_eval.create_default_exporters)
    export_dir = os.path.join(model_dir, 'export')
    # Four export dirs are expected: best_exporter_numpy,
    # best_exporter_tf_example, latest_exporter_numpy and
    # latest_exporter_tf_example.
    self.assertLen(tf.io.gfile.glob(os.path.join(export_dir, '*')), 4)
    numpy_predictor_fn = contrib_predictor.from_saved_model(
        tf.io.gfile.glob(os.path.join(export_dir, 'best_exporter_numpy',
                                      '*'))[-1])

    feed_tensor_keys = sorted(numpy_predictor_fn.feed_tensors.keys())
    self.assertCountEqual(
        ['condition/features/x', 'condition/labels/y', 'inference/features/x'],
        feed_tensor_keys,
    )

    tf_example_predictor_fn = contrib_predictor.from_saved_model(
        tf.io.gfile.glob(
            os.path.join(export_dir, 'best_exporter_tf_example', '*'))[-1])
    self.assertCountEqual(['input_example_tensor'],
                          list(tf_example_predictor_fn.feed_tensors.keys()))
Example No. 7
 def test_regression_maml(self):
     maml_model = pose_env_maml_models.PoseEnvRegressionModelMAML(
         base_model=pose_env_models.PoseEnvRegressionModel())
     train_eval.train_eval_model(
         t2r_model=maml_model,
         input_generator_train=self._meta_record_input_generator_train,
         input_generator_eval=self._meta_record_input_generator_eval,
         create_exporters_fn=None)
Example No. 8
    def test_init_from_checkpoint_global_step(self):
        """Tests that a simple model trains and exported models are valid."""
        gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps',
                           100)
        gin.bind_parameter('tf.estimator.RunConfig.keep_checkpoint_max', 3)
        model_dir = self.create_tempdir().full_path
        mock_t2r_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor)

        mock_input_generator_train = mocks.MockInputGenerator(
            batch_size=_BATCH_SIZE)

        train_eval.train_eval_model(
            t2r_model=mock_t2r_model,
            input_generator_train=mock_input_generator_train,
            max_train_steps=_MAX_TRAIN_STEPS,
            model_dir=model_dir,
            eval_steps=_EVAL_STEPS,
            eval_throttle_secs=_EVAL_THROTTLE_SECS,
            create_exporters_fn=train_eval.create_default_exporters)
        # The model trains for 200 steps, saves a checkpoint every 100 steps,
        # and keeps 3 -> len == 3.
        self.assertLen(
            tf.io.gfile.glob(os.path.join(model_dir, 'model*.meta')), 3)

        # The continuous training has its own directory.
        continue_model_dir = self.create_tempdir().full_path
        init_from_checkpoint_fn = functools.partial(
            abstract_model.default_init_from_checkpoint_fn,
            checkpoint=model_dir)
        continue_mock_t2r_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
            init_from_checkpoint_fn=init_from_checkpoint_fn)
        continue_mock_input_generator_train = mocks.MockInputGenerator(
            batch_size=_BATCH_SIZE)
        train_eval.train_eval_model(
            t2r_model=continue_mock_t2r_model,
            input_generator_train=continue_mock_input_generator_train,
            model_dir=continue_model_dir,
            max_train_steps=_MAX_TRAIN_STEPS + 100,
            eval_steps=_EVAL_STEPS,
            eval_throttle_secs=_EVAL_THROTTLE_SECS,
            create_exporters_fn=train_eval.create_default_exporters)
        # If the model was successfully restored, including the global step,
        # only 1 checkpoint in addition to the initial one should be created
        # -> len == 2.
        self.assertLen(
            tf.io.gfile.glob(os.path.join(continue_model_dir, 'model*.meta')),
            2)
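Both length checks follow from simple checkpoint arithmetic, assuming this test file sets _MAX_TRAIN_STEPS = 200 as the first comment implies:

# First run: checkpoints at steps 0, 100 and 200 -> 3 files, all within
# keep_checkpoint_max=3.
assert len(range(0, 200 + 1, 100)) == 3
# Continued run: the restored global step is 200, so training to
# _MAX_TRAIN_STEPS + 100 = 300 writes the initial checkpoint plus one at
# step 300 -> 2 files.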
Example No. 9
    def test_freezing_some_variables(self):
        """Tests we can freeze training for parts of the network."""
        def freeze_biases(var):
            # Update all variables except bias variables.
            return 'bias' not in var.name

        gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps',
                           100)
        gin.bind_parameter('create_train_op.filter_trainables_fn',
                           freeze_biases)
        model_dir = self.create_tempdir().full_path
        mock_t2r_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor)
        mock_input_generator_train = mocks.MockInputGenerator(
            batch_size=_BATCH_SIZE)

        train_eval.train_eval_model(
            t2r_model=mock_t2r_model,
            input_generator_train=mock_input_generator_train,
            max_train_steps=_MAX_TRAIN_STEPS,
            model_dir=model_dir)

        start_checkpoint = tf.train.NewCheckpointReader(
            os.path.join(model_dir, 'model.ckpt-0'))
        last_checkpoint = tf.train.NewCheckpointReader(
            tf.train.latest_checkpoint(model_dir))
        for var_name, _ in tf.train.list_variables(model_dir):
            # Some of the batch norm moving averages are constant over
            # training on the mock data used.
            if 'batch_norm' in var_name:
                continue
            if 'bias' not in var_name:
                # Should update.
                self.assertNotAllClose(start_checkpoint.get_tensor(var_name),
                                       last_checkpoint.get_tensor(var_name),
                                       atol=1e-3)
            else:
                # Should not update.
                self.assertAllClose(start_checkpoint.get_tensor(var_name),
                                    last_checkpoint.get_tensor(var_name),
                                    atol=1e-3)
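The gin binding routes freeze_biases into create_train_op as filter_trainables_fn, a predicate over variables. A hypothetical sketch of how such a predicate is applied (illustrative only, not T2R's actual wiring):

def filter_trainables(variables, filter_trainables_fn):
    # Keep only the variables the predicate approves; the optimizer then
    # updates this filtered list instead of all trainables.
    return [v for v in variables if filter_trainables_fn(v)]

# e.g. filter_trainables(tf.trainable_variables(), freeze_biases) drops
# every variable with 'bias' in its name from the gradient update.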
Example No. 10
    def test_predictor_with_async_hook(self):
        model_dir = self.create_tempdir().full_path
        export_dir = os.path.join(model_dir, _EXPORT_DIR)
        hook_builder = async_export_hook_builder.AsyncExportHookBuilder(
            export_dir=export_dir,
            create_export_fn=async_export_hook_builder.default_create_export_fn
        )
        input_generator = default_input_generator.DefaultRandomInputGenerator(
            batch_size=_BATCH_SIZE)
        mock_model = mocks.MockT2RModel()
        train_eval.train_eval_model(t2r_model=mock_model,
                                    input_generator_train=input_generator,
                                    train_hook_builders=[hook_builder],
                                    max_train_steps=_MAX_TRAIN_STEPS,
                                    model_dir=model_dir)

        predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
            export_dir=os.path.join(model_dir, _EXPORT_DIR))
        with self.assertRaises(ValueError):
            predictor.get_feature_specification()
        with self.assertRaises(ValueError):
            predictor.predict({'does_not_matter': np.zeros(1)})
        with self.assertRaises(ValueError):
            _ = predictor.model_version
        self.assertEqual(predictor.global_step, -1)
        self.assertTrue(predictor.restore())
        self.assertGreater(predictor.model_version, 0)
        # NOTE: The async hook builder will export the global step.
        self.assertEqual(predictor.global_step, 3)
        ref_feature_spec = mock_model.preprocessor.get_in_feature_specification(
            tf.estimator.ModeKeys.PREDICT)
        tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                      ref_feature_spec)
        features = tensorspec_utils.make_random_numpy(ref_feature_spec,
                                                      batch_size=_BATCH_SIZE)
        predictions = predictor.predict(features)
        self.assertLen(predictions, 1)
        self.assertCountEqual(sorted(predictions.keys()), ['logit'])
        self.assertEqual(predictions['logit'].shape, (2, 1))
Example No. 11
    def test_predictor_with_default_exporter(self, is_async):
        input_generator = default_input_generator.DefaultRandomInputGenerator(
            batch_size=_BATCH_SIZE)
        model_dir = self.create_tempdir().full_path
        mock_model = mocks.MockT2RModel()
        train_eval.train_eval_model(
            t2r_model=mock_model,
            input_generator_train=input_generator,
            input_generator_eval=input_generator,
            max_train_steps=_MAX_TRAIN_STEPS,
            eval_steps=_MAX_EVAL_STEPS,
            model_dir=model_dir,
            create_exporters_fn=train_eval.create_default_exporters)

        predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
            export_dir=os.path.join(model_dir, 'export',
                                    'latest_exporter_numpy'))
        with self.assertRaises(ValueError):
            predictor.get_feature_specification()
        with self.assertRaises(ValueError):
            predictor.predict({'does_not_matter': np.zeros(1)})
        with self.assertRaises(ValueError):
            _ = predictor.model_version
        self.assertEqual(predictor.global_step, -1)
        self.assertTrue(predictor.restore(is_async=is_async))
        self.assertGreater(predictor.model_version, 0)
        self.assertEqual(predictor.global_step, 3)
        ref_feature_spec = mock_model.preprocessor.get_in_feature_specification(
            tf.estimator.ModeKeys.PREDICT)
        tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                      ref_feature_spec)
        features = tensorspec_utils.make_random_numpy(ref_feature_spec,
                                                      batch_size=_BATCH_SIZE)
        predictions = predictor.predict(features)
        self.assertLen(predictions, 1)
        self.assertCountEqual(predictions.keys(), ['logit'])
        self.assertEqual(predictions['logit'].shape, (2, 1))
Example No. 12
def main(unused_argv):
  gin.parse_config_files_and_bindings(
      FLAGS.gin_configs, FLAGS.gin_bindings, print_includes_and_imports=True)
  train_eval.train_eval_model()
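This entry point presumes absl flags for the gin files and bindings. A plausible setup; the flag names follow the usage above, the rest is the standard absl pattern:

from absl import app
from absl import flags
import gin

from tensor2robot.utils import train_eval  # assumed module path

FLAGS = flags.FLAGS
flags.DEFINE_multi_string('gin_configs', None,
                          'Paths to the gin config files to parse.')
flags.DEFINE_multi_string('gin_bindings', [],
                          'Individual gin parameter bindings.')

if __name__ == '__main__':
  app.run(main)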
Example No. 13
def test_train_eval_gin(test_case,
                        model_dir,
                        full_gin_path,
                        max_train_steps,
                        eval_steps,
                        gin_overwrites_fn=None,
                        assert_train_output_files=True,
                        assert_eval_output_files=True):
    """Train and eval a runnable gin config.

  Until we have a proper gen_rule to create individual targets for every gin
  file automatically, gin files can be tested using the pattern below.
  Please use 'test_train_eval_gin' as the test function name so that it is
  easy to convert these tests as soon as the gen_rule is available.

  @parameterized.parameters(
      ('first.gin',),
      ('second.gin',),
      ('third.gin',),
  )
  def test_train_eval_gin(self, gin_file):
    full_gin_path = os.path.join(FLAGS.test_srcdir, BASE_GIN_PATH, gin_file)
    model_dir = os.path.join(FLAGS.test_tmpdir, 'test_train_eval_gin', gin_file)
    train_eval_test_utils.test_train_eval_gin(
        test_case=self,
        model_dir=model_dir,
        full_gin_path=full_gin_path,
        max_train_steps=MAX_TRAIN_STEPS,
        eval_steps=EVAL_STEPS)

  Args:
    test_case: The instance of the test used to assert that the output files are
      generated.
    model_dir: The path where the model should be stored.
    full_gin_path: The path of the gin file which parameterizes train_eval.
    max_train_steps: The maximum number of training steps, should be small since
      this is just for testing.
    eval_steps: The number of eval steps, should be small since this is just for
      testing.
    gin_overwrites_fn: Optional function which binds gin parameters to
      overwrite.
    assert_train_output_files: If True, the expected output files of the
      training run are checked, otherwise this check is skipped. If only
      evaluation is performed this should be set to False.
    assert_eval_output_files: If True, the expected output files of the
      evaluation run are checked, otherwise this check is skipped. If only
      training is performed this should be set to False. Note, if
      assert_train_output_files is set to False the model_dir is not deleted
      in order to load the model from training.
  """
    # We clear all prior parameters set by gin to ensure that we can call this
    # function sequentially for all parameterized tests.
    gin.clear_config(clear_constants=True)

    gin.parse_config_file(full_gin_path, print_includes_and_imports=True)
    gin.bind_parameter('train_eval_model.model_dir', model_dir)

    if gin_overwrites_fn is not None:
        gin_overwrites_fn()

    # Make sure that the model dir is empty. This is important for running
    # tests locally.
    if tf.io.gfile.exists(model_dir) and assert_train_output_files:
        tf.io.gfile.rmtree(model_dir)

    train_eval.train_eval_model(model_dir=model_dir,
                                max_train_steps=max_train_steps,
                                eval_steps=eval_steps,
                                create_exporters_fn=None)
    if assert_train_output_files:
        assert_output_files(
            test_case=test_case,
            model_dir=model_dir,
            expected_output_filename_patterns=DEFAULT_TRAIN_FILENAME_PATTERNS)
    if assert_eval_output_files:
        assert_output_files(
            test_case=test_case,
            model_dir=model_dir,
            expected_output_filename_patterns=DEFAULT_EVAL_FILENAME_PATTERNS)
Example No. 14
 def test_regression(self):
     train_eval.train_eval_model(
         t2r_model=pose_env_models.PoseEnvRegressionModel(),
         input_generator_train=self._record_input_generator,
         input_generator_eval=self._record_input_generator,
         create_exporters_fn=None)
Example No. 15
 def test_mc(self):
     train_eval.train_eval_model(
         t2r_model=pose_env_models.PoseEnvContinuousMCModel(),
         input_generator_train=self._record_input_generator,
         input_generator_eval=self._record_input_generator,
         create_exporters_fn=None)
Example No. 16
def main(unused_argv):
  gin.parse_config_files_and_bindings(FLAGS.gin_configs, FLAGS.gin_bindings)
  train_eval.train_eval_model()
Example No. 17
  def test_train_eval_model(self):
    """Tests that a simple model trains and exported models are valid."""
    gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps', 100)
    model_dir = self.create_tempdir().full_path
    mock_t2r_model = mocks.MockT2RModel(
        preprocessor_cls=noop_preprocessor.NoOpPreprocessor)

    mock_input_generator_train = mocks.MockInputGenerator(
        batch_size=_BATCH_SIZE)
    mock_input_generator_eval = mocks.MockInputGenerator(batch_size=1)
    fake_hook_builder = FakeHookBuilder()

    train_eval.train_eval_model(
        t2r_model=mock_t2r_model,
        input_generator_train=mock_input_generator_train,
        input_generator_eval=mock_input_generator_eval,
        max_train_steps=_MAX_TRAIN_STEPS,
        model_dir=model_dir,
        train_hook_builders=[fake_hook_builder],
        eval_hook_builders=[fake_hook_builder],
        eval_steps=_EVAL_STEPS,
        eval_throttle_secs=_EVAL_THROTTLE_SECS,
        create_exporters_fn=train_eval.create_default_exporters)

    self.assertTrue(fake_hook_builder.hook_mock.begin.called)

    # We ensure that both numpy and tf_example inference models are exported.
    best_exporter_numpy_path = os.path.join(model_dir, 'export',
                                            'best_exporter_numpy', '*')
    numpy_model_paths = sorted(tf.io.gfile.glob(best_exporter_numpy_path))
    # There should be at least 1 exported model.
    self.assertGreater(len(numpy_model_paths), 0)
    # This mock network converges nicely, which is why we have several best
    # models; by default we keep the best 5, and the latest one is always
    # the best.
    self.assertLessEqual(len(numpy_model_paths), 5)

    best_exporter_tf_example_path = os.path.join(
        model_dir, 'export', 'best_exporter_tf_example', '*')

    tf_example_model_paths = sorted(
        tf.io.gfile.glob(best_exporter_tf_example_path))
    # There should be at least 1 exported model.
    self.assertGreater(len(tf_example_model_paths), 0)
    # This mock network converges nicely, which is why we have several best
    # models; by default we keep the best 5, and the latest one is always
    # the best.
    self.assertLessEqual(len(tf_example_model_paths), 5)

    # We test both saved models within one test since the bulk of the time
    # is spent training the model in the first place.

    # Verify that the serving estimator does exactly the same as the normal
    # estimator with all the parameters.
    estimator_predict = tf.estimator.Estimator(
        model_fn=mock_t2r_model.model_fn,
        config=tf.estimator.RunConfig(model_dir=model_dir))

    prediction_ref = estimator_predict.predict(
        input_fn=mock_input_generator_eval.create_dataset_input_fn(
            mode=tf.estimator.ModeKeys.EVAL))

    # Now we can load our exported estimator graph with the numpy feed_dict
    # interface, there are no dependencies on the model_fn or preprocessor
    # anymore.
    # We load the latest model since it had the best eval performance.
    numpy_predictor_fn = contrib_predictor.from_saved_model(
        numpy_model_paths[-1])

    features, labels = mock_input_generator_eval.create_numpy_data()

    ref_error = self._compute_total_loss(
        labels, [val['logit'].flatten() for val in prediction_ref])

    numpy_predictions = []
    for feature, label in zip(features, labels):
      predicted = numpy_predictor_fn({'x': feature.reshape(
          1, -1)})['logit'].flatten()
      numpy_predictions.append(predicted)
      # This ensures that we actually achieve near-perfect classification.
      if label > 0:
        self.assertGreater(predicted[0], 0)
      else:
        self.assertLess(predicted[0], 0)
    numpy_error = self._compute_total_loss(labels, numpy_predictions)

    # Now we can load our exported estimator graph with the tf_example feed_dict
    # interface, there are no dependencies on the model_fn or preprocessor
    # anymore.
    # We load the latest model since it had the best eval performance.
    tf_example_predictor_fn = contrib_predictor.from_saved_model(
        tf_example_model_paths[-1])
    tf_example_predictions = []
    for feature, label in zip(features, labels):
      # We have to create our serialized tf.Example proto.
      example = tf.train.Example()
      example.features.feature['measured_position'].float_list.value.extend(
          feature)
      feed_dict = {
          'input_example_tensor':
              np.array(example.SerializeToString()).reshape(1,)
      }
      predicted = tf_example_predictor_fn(feed_dict)['logit'].flatten()
      tf_example_predictions.append(predicted)
      # This ensures that we actually achieve perfect classification.
      if label > 0:
        self.assertGreater(predicted[0], 0)
      else:
        self.assertLess(predicted[0], 0)
    tf_example_error = self._compute_total_loss(labels, tf_example_predictions)

    np.testing.assert_almost_equal(tf_example_error, numpy_error)
    # Both exported saved models have to have the same performance, and since
    # we train and eval on the same fixed dataset, the latest model should
    # also be the best.
    np.testing.assert_almost_equal(ref_error, tf_example_error, decimal=3)
Example No. 18
  def test_init_from_checkpoint_use_avg_model_params_and_weights(self):
    """Tests that a simple model trains and exported models are valid."""
    gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps', 100)
    gin.bind_parameter('tf.estimator.RunConfig.keep_checkpoint_max', 3)
    model_dir = self.create_tempdir().full_path
    mock_t2r_model = mocks.MockT2RModel(
        preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
        use_avg_model_params=True)

    mock_input_generator_train = mocks.MockInputGenerator(
        batch_size=_BATCH_SIZE)

    mock_input_generator = mocks.MockInputGenerator(batch_size=1)
    mock_input_generator.set_specification_from_model(
        mock_t2r_model, tf.estimator.ModeKeys.TRAIN)

    train_eval.train_eval_model(
        t2r_model=mock_t2r_model,
        input_generator_train=mock_input_generator_train,
        max_train_steps=_MAX_TRAIN_STEPS,
        model_dir=model_dir)

    init_checkpoint = tf.train.NewCheckpointReader(
        tf.train.latest_checkpoint(model_dir))

    # Verify that the serving estimator does exactly the same as the normal
    # estimator with all the parameters.
    initial_estimator_predict = tf.estimator.Estimator(
        model_fn=mock_t2r_model.model_fn,
        config=tf.estimator.RunConfig(model_dir=model_dir))

    # pylint: disable=g-complex-comprehension
    initial_predictions = [
        prediction['logit'] for prediction in list(
            initial_estimator_predict.predict(
                input_fn=mock_input_generator.create_dataset_input_fn(
                    mode=tf.estimator.ModeKeys.EVAL)))
    ]

    # The continuous training has its own directory.
    continue_model_dir = self.create_tempdir().full_path
    init_from_checkpoint_fn = functools.partial(
        abstract_model.default_init_from_checkpoint_fn, checkpoint=model_dir)
    continue_mock_t2r_model = mocks.MockT2RModel(
        preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
        init_from_checkpoint_fn=init_from_checkpoint_fn)
    continue_mock_input_generator_train = mocks.MockInputGenerator(
        batch_size=_BATCH_SIZE)
    # Re-initialize the model and train for one step; performance should be
    # basically the same as the original model's.
    train_eval.train_eval_model(
        t2r_model=continue_mock_t2r_model,
        input_generator_train=continue_mock_input_generator_train,
        model_dir=continue_model_dir,
        max_train_steps=_MAX_TRAIN_STEPS)

    continue_checkpoint = tf.train.NewCheckpointReader(
        tf.train.latest_checkpoint(continue_model_dir))

    for tensor_name, _ in tf.train.list_variables(model_dir):
      if 'ExponentialMovingAverage' in tensor_name:
        # These values are replaced by the swapping saver when using
        # use_avg_model_params.
        continue
      if 'Adam' in tensor_name:
        # The Adam optimizer values are not required.
        continue
      if 'global_step' in tensor_name:
        # The global step will be incremented by 1.
        continue
      self.assertAllClose(
          init_checkpoint.get_tensor(tensor_name),
          continue_checkpoint.get_tensor(tensor_name),
          atol=1e-3)

    # Verify that the serving estimator does exactly the same as the normal
    # estimator with all the parameters.
    continue_estimator_predict = tf.estimator.Estimator(
        model_fn=mock_t2r_model.model_fn,
        config=tf.estimator.RunConfig(model_dir=continue_model_dir))

    continue_predictions = [
        prediction['logit'] for prediction in list(
            continue_estimator_predict.predict(
                input_fn=mock_input_generator.create_dataset_input_fn(
                    mode=tf.estimator.ModeKeys.EVAL)))
    ]

    self.assertTrue(
        np.allclose(initial_predictions, continue_predictions, atol=1e-1))

    # A randomly initialized model estimator with all the parameters.
    random_estimator_predict = tf.estimator.Estimator(
        model_fn=mock_t2r_model.model_fn)

    random_predictions = [
        prediction['logit'] for prediction in list(
            random_estimator_predict.predict(
                input_fn=mock_input_generator.create_dataset_input_fn(
                    mode=tf.estimator.ModeKeys.EVAL)))
    ]
    self.assertFalse(
        np.allclose(initial_predictions, random_predictions, atol=1e-2))
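Background for the 'ExponentialMovingAverage' skip above: with use_avg_model_params, a swapping saver writes the averaged weights into the checkpoint in place of the raw ones, so those entries differ between runs by design. A minimal sketch of the underlying TF1 EMA mechanism (not T2R's swapping-saver wiring):

# Keeps a smoothed copy of each variable; the shadow copy is stored under
# '<var_name>/ExponentialMovingAverage' in the checkpoint.
ema = tf.train.ExponentialMovingAverage(decay=0.999)
maintain_averages_op = ema.apply(tf.trainable_variables())
# ema.average(var) returns the shadow variable with the smoothed value.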