def test_invalid_logit_fn_results_dict(self):
  """A dict result whose values are not all Tensors must raise ValueError.

  NOTE(review): a second method with this exact name appears later in this
  chunk. If both are defined in the same test class, this definition is
  shadowed by the later one and never runs -- consider removing one of them.
  """
  features = {'f1': constant_op.constant([[2., 3.]]), 'f2': 'some string'}

  def invalid_logit_fn(features):
    # 'head2' ends up mapped to a plain Python str rather than a Tensor.
    return {'head1': features['f1'], 'head2': features['f2']}

  expected_message = ('logit_fn should return a Tensor or a dictionary mapping '
                      'strings to Tensors')
  with self.assertRaisesRegexp(ValueError, expected_message):
    logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode',
                            'fake_params', 'fake_config')
def test_invalid_logit_fn_results_dict(self):
  """call_logit_fn rejects a dict result that holds a non-Tensor value."""
  def invalid_logit_fn(features):
    # Both features are passed straight through; since 'f2' is a str, the
    # returned dictionary is not a mapping of strings to Tensors.
    return dict(head1=features['f1'], head2=features['f2'])

  features = dict(f1=constant_op.constant([[2., 3.]]), f2='some string')
  pattern = ('logit_fn should return a Tensor or a dictionary mapping '
             'strings to Tensors')
  with self.assertRaisesRegexp(ValueError, pattern):
    logit_fns.call_logit_fn(
        invalid_logit_fn, features, 'fake_mode', 'fake_params', 'fake_config')
def test_should_return_tensor(self):
  """call_logit_fn rejects a result that is neither a Tensor nor a dict.

  The logit_fn below returns a bare Python float. That is not a Tensor, and
  it is not a dictionary mapping strings to Tensors, so call_logit_fn is
  expected to raise ValueError with its standard message.
  """
  def invalid_logit_fn(features, params):
    del features  # Unused; only the params-derived scalar is returned.
    # A plain Python float: an invalid logit_fn result.
    return params['input_multiplier']

  features = {
      'f1': constant_op.constant([[2., 3.]]),
      'f2': constant_op.constant([[4., 5.]])
  }
  params = {'learning_rate': 0.001, 'input_multiplier': 2.0}
  with self.assertRaisesRegexp(
      ValueError, 'logit_fn should return a Tensor or a dictionary mapping '
      'strings to Tensors'):
    logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode', params,
                            'fake_config')
def test_invalid_logit_fn_results(self):
  """A list of Tensors is not an accepted logit_fn result."""
  features = {
      'f1': constant_op.constant([[2., 3.]]),
      'f2': constant_op.constant([[4., 5.]]),
  }
  params = {'learning_rate': 0.001, 'input_multiplier': 2.0}

  def invalid_logit_fn(features, params):
    scale = params['input_multiplier']
    # A list (rather than a Tensor or a dict of Tensors) should be rejected.
    return [features[name] * scale for name in ('f1', 'f2')]

  with self.assertRaisesRegexp(
      ValueError,
      'logit_fn should return a Tensor or a dictionary mapping '
      'strings to Tensors'):
    logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode', params,
                            'fake_config')
def test_simple_call_multi_logit_fn(self):
  """A dict of Tensors returned by logit_fn comes back unchanged per head."""
  def dummy_logit_fn(features):
    return {'head1': features['f1'], 'head2': features['f2']}

  features = {
      'f1': constant_op.constant([[2., 3.]]),
      'f2': constant_op.constant([[4., 5.]]),
  }
  result = logit_fns.call_logit_fn(dummy_logit_fn, features,
                                   model_fn.ModeKeys.TRAIN, 'fake_params',
                                   'fake_config')
  # Expected values, checked in head order.
  expected_by_head = (('head1', [[2., 3.]]), ('head2', [[4., 5.]]))
  with session.Session():
    for head, expected in expected_by_head:
      self.assertAllClose(expected, result[head].eval())
def test_simple_call_logit_fn(self):
  """call_logit_fn forwards `mode`; under EVAL the dummy fn selects 'f2'."""
  def dummy_logit_fn(features, mode):
    key = 'f1' if mode == model_fn.ModeKeys.TRAIN else 'f2'
    return features[key]

  features = {
      'f1': constant_op.constant([[2., 3.]]),
      'f2': constant_op.constant([[4., 5.]]),
  }
  eval_mode = model_fn.ModeKeys.EVAL
  logit_fn_result = logit_fns.call_logit_fn(
      dummy_logit_fn, features, eval_mode, 'fake_params', 'fake_config')
  with session.Session():
    self.assertAllClose([[4., 5.]], logit_fn_result.eval())