def test_predictor_raises(self):
    """restore() must raise ValueError when no checkpoint_dir was given."""
    model = mocks.MockT2RModel()
    # No checkpoint_dir has been set, so calling restore() raises.
    predictor = checkpoint_predictor.CheckpointPredictor(t2r_model=model)
    with self.assertRaises(ValueError):
        predictor.restore()
def test_predictor(self):
    """Trains a mock model, restores its checkpoint and runs a prediction.

    Checks that:
      * predict() raises ValueError before restore() has been called;
      * model_version and global_step report -1 until restored;
      * after restore(), the predictor exposes the preprocessor's PREDICT
        in-feature specification;
      * predict() returns exactly one output, 'logit', with one row per
        example in the batch.
    """
    input_generator = default_input_generator.DefaultRandomInputGenerator(
        batch_size=_BATCH_SIZE)
    model_dir = self.create_tempdir().full_path
    mock_model = mocks.MockT2RModel()
    train_eval.train_eval_model(
        t2r_model=mock_model,
        input_generator_train=input_generator,
        max_train_steps=_MAX_TRAIN_STEPS,
        model_dir=model_dir)

    predictor = checkpoint_predictor.CheckpointPredictor(
        t2r_model=mock_model, checkpoint_dir=model_dir, use_gpu=False)
    # Nothing has been restored yet, so predicting must fail.
    with self.assertRaises(ValueError):
        predictor.predict({'does_not_matter': np.zeros(1)})
    self.assertEqual(predictor.model_version, -1)
    self.assertEqual(predictor.global_step, -1)

    self.assertTrue(predictor.restore())
    self.assertGreater(predictor.model_version, 0)
    # NOTE(review): assumes train_eval_model stops exactly at max_train_steps;
    # the original hard-coded 3 here, which must track _MAX_TRAIN_STEPS.
    self.assertEqual(predictor.global_step, _MAX_TRAIN_STEPS)

    ref_feature_spec = mock_model.preprocessor.get_in_feature_specification(
        tf.estimator.ModeKeys.PREDICT)
    tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                  ref_feature_spec)
    features = tensorspec_utils.make_random_numpy(
        ref_feature_spec, batch_size=_BATCH_SIZE)
    predictions = predictor.predict(features)
    self.assertLen(predictions, 1)
    # assertCountEqual already ignores ordering; no need to sort first.
    self.assertCountEqual(predictions.keys(), ['logit'])
    # One logit per example in the batch built above.
    self.assertEqual(predictions['logit'].shape, (_BATCH_SIZE, 1))
def test_predictor_timeout(self):
    """restore() yields a falsy result when checkpoint_dir does not exist."""
    missing_dir = '/random/path/which/does/not/exist'
    # NOTE(review): timeout units are not visible here (presumably seconds).
    predictor = checkpoint_predictor.CheckpointPredictor(
        t2r_model=mocks.MockT2RModel(),
        checkpoint_dir=missing_dir,
        timeout=1)
    self.assertFalse(predictor.restore())
def test_regression_maml_policy_interface(self):
    """Runs the shared policy-interface checks on a MAML regression policy."""
    base_model = pose_env_models.PoseEnvRegressionModel()
    maml_model = pose_env_maml_models.PoseEnvRegressionModelMAML(
        base_model=base_model,
        preprocessor_cls=preprocessors.FixedLenMetaExamplePreprocessor)
    predictor = checkpoint_predictor.CheckpointPredictor(t2r_model=maml_model)
    # The predictor is set up via init_randomly() rather than restore(),
    # which is why the interface check below is told restore=False.
    predictor.init_randomly()
    policy = meta_policies.MAMLRegressionPolicy(maml_model, predictor=predictor)
    self._test_policy_interface(policy, restore=False)