def test_invalid_logit_fn_results_dict(self):
    """A dict result containing a non-Tensor value must be rejected.

    The logit_fn returns a dict whose 'head2' entry is a plain Python
    string instead of a Tensor; the test expects call_logit_fn to raise
    ValueError with the message checked below.
    """

    def invalid_logit_fn(features):
        return {'head1': features['f1'], 'head2': features['f2']}

    # 'f2' is deliberately not a Tensor. tf.constant matches the other
    # tests in this file (previously constant_op.constant).
    features = {'f1': tf.constant([[2., 3.]]), 'f2': 'some string'}
    # assertRaisesRegex is the canonical name; assertRaisesRegexp was a
    # deprecated alias and is gone from unittest as of Python 3.12.
    with self.assertRaisesRegex(
        ValueError, 'logit_fn should return a Tensor or a dictionary mapping '
        'strings to Tensors'):
        model_fn.call_logit_fn(invalid_logit_fn, features, 'fake_mode',
                               'fake_params', 'fake_config')
def test_invalid_logit_fn_results(self):
    """A list result (neither a Tensor nor a dict) must be rejected.

    The logit_fn declares `params`, so call_logit_fn is given a real
    params dict; the function then returns a Python list of Tensors, which
    the test expects to trigger a ValueError.
    """

    def invalid_logit_fn(features, params):
        return [
            features['f1'] * params['input_multiplier'],
            features['f2'] * params['input_multiplier']
        ]

    features = {'f1': tf.constant([[2., 3.]]), 'f2': tf.constant([[4., 5.]])}
    params = {'learning_rate': 0.001, 'input_multiplier': 2.0}
    # assertRaisesRegex is the canonical name; assertRaisesRegexp was a
    # deprecated alias and is gone from unittest as of Python 3.12.
    with self.assertRaisesRegex(
        ValueError, 'logit_fn should return a Tensor or a dictionary mapping '
        'strings to Tensors'):
        model_fn.call_logit_fn(invalid_logit_fn, features, 'fake_mode', params,
                               'fake_config')
def test_simple_call_multi_logit_fn(self):
    """A dict of per-head Tensors is returned with each head's values intact."""

    # The parameter must stay named `features`: call_logit_fn supplies it by
    # keyword. The u'' prefix keeps the original's unicode-key spelling.
    def two_head_logit_fn(features):
        return {u'head1': features['f1'], 'head2': features['f2']}

    inputs = {
        'f1': tf.constant([[2., 3.]]),
        'f2': tf.constant([[4., 5.]]),
    }
    per_head = model_fn.call_logit_fn(two_head_logit_fn, inputs,
                                      ModeKeys.TRAIN, 'fake_params',
                                      'fake_config')
    expectations = (
        ('head1', [[2., 3.]]),
        ('head2', [[4., 5.]]),
    )
    with self.cached_session():
        for head_name, expected in expectations:
            self.assertAllClose(expected, self.evaluate(per_head[head_name]))
def test_simple_call_logit_fn(self):
    """The mode passed to call_logit_fn reaches a logit_fn that declares it."""

    # Parameter names `features` and `mode` are significant: call_logit_fn
    # matches them by name. Returning a different feature per mode lets the
    # assertion reveal which mode was actually forwarded.
    def dummy_logit_fn(features, mode):
        chosen_key = 'f1' if mode == ModeKeys.TRAIN else 'f2'
        return features[chosen_key]

    inputs = {
        'f1': tf.constant([[2., 3.]]),
        'f2': tf.constant([[4., 5.]]),
    }
    logits = model_fn.call_logit_fn(dummy_logit_fn, inputs, ModeKeys.EVAL,
                                    'fake_params', 'fake_config')
    with self.cached_session():
        # EVAL is not TRAIN, so the 'f2' feature is expected back.
        self.assertAllClose([[4., 5.]], self.evaluate(logits))