예제 #1
0
 def test_predictor_raises(self):
     """restore() must raise ValueError when no checkpoint_dir was given."""
     # The predictor is built from the model alone; without a checkpoint_dir
     # there is nothing to restore from, so restore() is expected to raise.
     no_dir_predictor = checkpoint_predictor.CheckpointPredictor(
         t2r_model=mocks.MockT2RModel())
     self.assertRaises(ValueError, no_dir_predictor.restore)
예제 #2
0
    def test_predictor(self):
        """Trains a mock model, then restores and queries it via the predictor."""
        input_generator = default_input_generator.DefaultRandomInputGenerator(
            batch_size=_BATCH_SIZE)
        model_dir = self.create_tempdir().full_path
        mock_model = mocks.MockT2RModel()
        train_eval.train_eval_model(t2r_model=mock_model,
                                    input_generator_train=input_generator,
                                    max_train_steps=_MAX_TRAIN_STEPS,
                                    model_dir=model_dir)

        predictor = checkpoint_predictor.CheckpointPredictor(
            t2r_model=mock_model, checkpoint_dir=model_dir, use_gpu=False)
        # Before restore(): predict must fail and version/step report -1.
        with self.assertRaises(ValueError):
            predictor.predict({'does_not_matter': np.zeros(1)})
        self.assertEqual(predictor.model_version, -1)
        self.assertEqual(predictor.global_step, -1)

        self.assertTrue(predictor.restore())
        self.assertGreater(predictor.model_version, 0)
        # Expect the restored step to match the training budget used above,
        # derived from the constant rather than a duplicated literal.
        self.assertEqual(predictor.global_step, _MAX_TRAIN_STEPS)

        ref_feature_spec = mock_model.preprocessor.get_in_feature_specification(
            tf.estimator.ModeKeys.PREDICT)
        tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                      ref_feature_spec)
        features = tensorspec_utils.make_random_numpy(ref_feature_spec,
                                                      batch_size=_BATCH_SIZE)
        predictions = predictor.predict(features)
        self.assertLen(predictions, 1)
        # assertCountEqual is already order-insensitive; no sort needed.
        self.assertCountEqual(predictions.keys(), ['logit'])
        # The leading dimension follows the batch size of the fed features.
        self.assertEqual(predictions['logit'].shape, (_BATCH_SIZE, 1))
예제 #3
0
 def test_predictor_timeout(self):
     """restore() returns a falsy value for a non-existent checkpoint_dir."""
     missing_dir = '/random/path/which/does/not/exist'
     # timeout=1 keeps the test fast; presumably restore() waits at most this
     # long for a checkpoint before giving up (confirm in CheckpointPredictor).
     timeout_predictor = checkpoint_predictor.CheckpointPredictor(
         t2r_model=mocks.MockT2RModel(),
         checkpoint_dir=missing_dir,
         timeout=1)
     restored = timeout_predictor.restore()
     self.assertFalse(restored)
예제 #4
0
 def test_regression_maml_policy_interface(self):
   """Runs the shared policy-interface check on a MAML regression policy."""
   maml_model = pose_env_maml_models.PoseEnvRegressionModelMAML(
       base_model=pose_env_models.PoseEnvRegressionModel(),
       preprocessor_cls=preprocessors.FixedLenMetaExamplePreprocessor)
   maml_predictor = checkpoint_predictor.CheckpointPredictor(
       t2r_model=maml_model)
   # init_randomly() is used instead of restoring a checkpoint, which is
   # presumably why the interface check below is told restore=False.
   maml_predictor.init_randomly()
   maml_policy = meta_policies.MAMLRegressionPolicy(
       maml_model, predictor=maml_predictor)
   self._test_policy_interface(maml_policy, restore=False)