Example 1
    def test_with_mock_training(self):
        model_dir = self.create_tempdir().full_path
        mock_t2r_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
            device_type='tpu',
            use_avg_model_params=True)

        mock_input_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
        export_dir = os.path.join(model_dir, _EXPORT_DIR)
        hook_builder = async_export_hook_builder.AsyncExportHookBuilder(
            export_dir=export_dir,
            create_export_fn=async_export_hook_builder.default_create_export_fn
        )

        gin.parse_config('tf.contrib.tpu.TPUConfig.iterations_per_loop=1')
        gin.parse_config('tf.estimator.RunConfig.save_checkpoints_steps=1')

        # We optimize our network.
        train_eval.train_eval_model(t2r_model=mock_t2r_model,
                                    input_generator_train=mock_input_generator,
                                    train_hook_builders=[hook_builder],
                                    model_dir=model_dir,
                                    max_train_steps=_MAX_STEPS)
        self.assertNotEmpty(tf.io.gfile.listdir(model_dir))
        self.assertNotEmpty(tf.io.gfile.listdir(export_dir))
        for exported_model_dir in tf.io.gfile.listdir(export_dir):
            self.assertNotEmpty(
                tf.io.gfile.listdir(
                    os.path.join(export_dir, exported_model_dir)))
        predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
            export_dir=export_dir)
        self.assertTrue(predictor.restore())
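These snippets are test methods taken out of their module, so the imports and module-level constants they depend on are not shown. A sketch of what that preamble presumably looks like; the tensor2robot module paths and the constant values are assumptions (only a batch size of 2 and a train-step count of 3 are implied by the assertions in the examples):

import functools
import os

import gin
import numpy as np
import tensorflow.compat.v1 as tf

# Assumed tensor2robot module layout.
from tensor2robot.hooks import async_export_hook_builder
from tensor2robot.input_generators import default_input_generator
from tensor2robot.predictors import exported_savedmodel_predictor
from tensor2robot.preprocessors import noop_preprocessor
from tensor2robot.utils import mocks
from tensor2robot.utils import tensorspec_utils
from tensor2robot.utils import train_eval

# _BATCH_SIZE and the step counts are implied by the assertions below
# (logit shape (2, 1), global_step == 3); the remaining values are
# illustrative guesses.
_BATCH_SIZE = 2
_MAX_STEPS = 3
_MAX_TRAIN_STEPS = 3
_MAX_EVAL_STEPS = 2
_EXPORT_DIR = 'export_dir'
_BATCH_SIZES_FOR_EXPORT = [128]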
Example 2
    def test_predictor_load_final_model(self):
        input_generator = default_input_generator.DefaultRandomInputGenerator(
            batch_size=_BATCH_SIZE)
        model_dir = self.create_tempdir().full_path
        mock_model = mocks.MockT2RModel()
        train_eval.train_eval_model(
            t2r_model=mock_model,
            input_generator_train=input_generator,
            input_generator_eval=input_generator,
            max_train_steps=_MAX_TRAIN_STEPS,
            eval_steps=_MAX_EVAL_STEPS,
            model_dir=model_dir,
            create_exporters_fn=train_eval.create_default_exporters)
        export_dir = os.path.join(model_dir, 'export', 'latest_exporter_numpy')
        final_export_dir = sorted(
            tf.io.gfile.glob(os.path.join(export_dir, '*')), reverse=True)[0]
        predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
            export_dir=final_export_dir)
        predictor.restore()
        self.assertGreater(predictor.model_version, 0)
        self.assertEqual(predictor.global_step, 3)
        ref_feature_spec = mock_model.preprocessor.get_in_feature_specification(
            tf.estimator.ModeKeys.PREDICT)
        tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                      ref_feature_spec)
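Example 2 selects the newest export by reverse-sorting the timestamped export subdirectories and taking the first entry. Assuming, as the test does, that each export subdirectory is named by its numeric timestamp, the same choice can be written more directly:

# Equivalent selection of the most recent export, keyed on the numeric
# timestamp in the directory name.
candidate_dirs = tf.io.gfile.glob(os.path.join(export_dir, '*'))
final_export_dir = max(candidate_dirs, key=lambda d: int(os.path.basename(d)))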
Example 3
    def test_predictor_init_with_default_exporter(self, restore_model_option):
        input_generator = default_input_generator.DefaultRandomInputGenerator(
            batch_size=_BATCH_SIZE)
        model_dir = self.create_tempdir().full_path
        mock_model = mocks.MockT2RModel()
        train_eval.train_eval_model(
            t2r_model=mock_model,
            input_generator_train=input_generator,
            input_generator_eval=input_generator,
            max_train_steps=_MAX_TRAIN_STEPS,
            eval_steps=_MAX_EVAL_STEPS,
            model_dir=model_dir,
            create_exporters_fn=train_eval.create_default_exporters)

        predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
            export_dir=os.path.join(model_dir, 'export',
                                    'latest_exporter_numpy'),
            restore_model_option=restore_model_option)
        if (restore_model_option ==
            exported_savedmodel_predictor.RestoreOptions.RESTORE_SYNCHRONOUSLY):
            predictor.restore()
        self.assertGreater(predictor.model_version, 0)
        self.assertEqual(predictor.global_step, 3)
        ref_feature_spec = mock_model.preprocessor.get_in_feature_specification(
            tf.estimator.ModeKeys.PREDICT)
        tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                      ref_feature_spec)
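Example 3 receives restore_model_option as a test parameter, but the excerpt does not show how it is supplied. A minimal sketch, assuming absl.testing.parameterized is used to run the test once per option and using only the RestoreOptions value that appears above (the test class name is hypothetical):

from absl.testing import parameterized

class PredictorTest(parameterized.TestCase):  # hypothetical class name

    @parameterized.parameters(
        exported_savedmodel_predictor.RestoreOptions.RESTORE_SYNCHRONOUSLY,
        # ...other RestoreOptions values would be listed here as well.
    )
    def test_predictor_init_with_default_exporter(self, restore_model_option):
        # Body as in Example 3 above.
        pass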
Example 4
  def test_with_mock_training(self):
    model_dir = self.create_tempdir().full_path
    mock_t2r_model = mocks.MockT2RModel(
        preprocessor_cls=noop_preprocessor.NoOpPreprocessor, device_type='cpu')

    mock_input_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
    default_create_export_fn = functools.partial(
        async_export_hook_builder.default_create_export_fn,
        batch_sizes_for_export=_BATCH_SIZES_FOR_EXPORT)
    export_dir = os.path.join(model_dir, _EXPORT_DIR)
    hook_builder = async_export_hook_builder.AsyncExportHookBuilder(
        export_dir=export_dir, create_export_fn=default_create_export_fn)

    # We optimize our network.
    train_eval.train_eval_model(
        t2r_model=mock_t2r_model,
        input_generator_train=mock_input_generator,
        train_hook_builders=[hook_builder],
        model_dir=model_dir,
        max_train_steps=_MAX_STEPS)
    self.assertNotEmpty(tf.io.gfile.listdir(model_dir))
    self.assertNotEmpty(tf.io.gfile.listdir(export_dir))
    for exported_model_dir in tf.io.gfile.listdir(export_dir):
      self.assertNotEmpty(
          tf.io.gfile.listdir(os.path.join(export_dir, exported_model_dir)))
    predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
        export_dir=export_dir)
    self.assertTrue(predictor.restore())
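The functools.partial call in Example 4 only pre-binds batch_sizes_for_export, so the hook builder can later invoke the export function with its remaining arguments. The same pattern in isolation (a generic illustration, not the tensor2robot API):

import functools

def export(export_dir, batch_sizes_for_export):
    # Stand-in for the real export logic.
    print('Exporting to %s with batch sizes %s' %
          (export_dir, batch_sizes_for_export))

# Pre-bind the keyword argument; callers only need to supply export_dir.
create_export_fn = functools.partial(export, batch_sizes_for_export=[128])
create_export_fn('/tmp/export')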
Example 5
    def test_predictor(self):
        input_generator = default_input_generator.DefaultRandomInputGenerator(
            batch_size=_BATCH_SIZE)
        model_dir = self.create_tempdir().full_path
        mock_model = mocks.MockT2RModel()
        train_eval.train_eval_model(
            t2r_model=mock_model,
            input_generator_train=input_generator,
            input_generator_eval=input_generator,
            max_train_steps=_MAX_TRAIN_STEPS,
            eval_steps=_MAX_EVAL_STEPS,
            model_dir=model_dir,
            create_exporters_fn=train_eval.create_default_exporters)

        predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
            export_dir=os.path.join(model_dir, 'export',
                                    'latest_exporter_numpy'))
        with self.assertRaises(ValueError):
            predictor.get_feature_specification()
        with self.assertRaises(ValueError):
            predictor.predict({'does_not_matter': np.zeros(1)})
        with self.assertRaises(ValueError):
            _ = predictor.model_version
        self.assertTrue(predictor.restore())
        self.assertGreater(predictor.model_version, 0)
        ref_feature_spec = mock_model.preprocessor.get_in_feature_specification(
            tf.estimator.ModeKeys.PREDICT)
        tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                      ref_feature_spec)
        features = tensorspec_utils.make_random_numpy(ref_feature_spec,
                                                      batch_size=_BATCH_SIZE)
        predictions = predictor.predict(features)
        self.assertLen(predictions, 1)
        self.assertEqual(predictions['logit'].shape, (2, 1))

    def test_predictor_with_async_hook(self):
        model_dir = self.create_tempdir().full_path
        default_create_export_fn = functools.partial(
            async_export_hook_builder.default_create_export_fn,
            batch_sizes_for_export=_BATCH_SIZES_FOR_EXPORT)
        export_dir = os.path.join(model_dir, _EXPORT_DIR)
        hook_builder = async_export_hook_builder.AsyncExportHookBuilder(
            export_dir=export_dir, create_export_fn=default_create_export_fn)
        input_generator = default_input_generator.DefaultRandomInputGenerator(
            batch_size=_BATCH_SIZE)
        mock_model = mocks.MockT2RModel()
        train_eval.train_eval_model(t2r_model=mock_model,
                                    input_generator_train=input_generator,
                                    train_hook_builders=[hook_builder],
                                    max_train_steps=_MAX_TRAIN_STEPS,
                                    model_dir=model_dir)

        predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
            export_dir=os.path.join(model_dir, _EXPORT_DIR))
        with self.assertRaises(ValueError):
            predictor.get_feature_specification()
        with self.assertRaises(ValueError):
            predictor.predict({'does_not_matter': np.zeros(1)})
        with self.assertRaises(ValueError):
            _ = predictor.model_version
        self.assertEqual(predictor.global_step, -1)
        self.assertTrue(predictor.restore())
        self.assertGreater(predictor.model_version, 0)
        # NOTE: The async hook builder will export the global step.
        self.assertEqual(predictor.global_step, 3)
        ref_feature_spec = mock_model.preprocessor.get_in_feature_specification(
            tf.estimator.ModeKeys.PREDICT)
        tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                      ref_feature_spec)
        features = tensorspec_utils.make_random_numpy(ref_feature_spec,
                                                      batch_size=_BATCH_SIZE)
        predictions = predictor.predict(features)
        self.assertLen(predictions, 1)
        self.assertCountEqual(sorted(predictions.keys()), ['logit'])
        self.assertEqual(predictions['logit'].shape, (2, 1))

    def test_predictor_timeout(self):
        predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
            export_dir='/random/path/which/does/not/exist', timeout=1)
        self.assertFalse(predictor.restore())
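Outside of a test, the predictor workflow exercised in Example 5 reduces to: point ExportedSavedModelPredictor at an export directory, restore, and pass it a feature dictionary that matches the feature specification. A condensed sketch, where saved_model_dir is a placeholder for an existing export directory:

predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
    export_dir=saved_model_dir)
if not predictor.restore():
  raise RuntimeError('No exported SavedModel found in %s.' % saved_model_dir)

# Random inputs matching the predictor's feature spec are enough for a
# smoke test.
feature_spec = predictor.get_feature_specification()
features = tensorspec_utils.make_random_numpy(feature_spec, batch_size=2)
predictions = predictor.predict(features)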