Example #1
0
  def test_invalid_logit_fn_results_dict(self):
    """Checks that a dict result with a non-Tensor value raises ValueError."""

    def invalid_logit_fn(features):
      # Passes both features through unchanged, so 'head2' carries whatever
      # non-Tensor value the caller placed under 'f2'.
      return {'head1': features['f1'], 'head2': features['f2']}

    # 'f2' is a plain Python string rather than a Tensor, which makes the
    # returned dictionary an invalid logit_fn result.
    features = {'f1': constant_op.constant([[2., 3.]]), 'f2': 'some string'}
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias, which
    # was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(
        ValueError, 'logit_fn should return a Tensor or a dictionary mapping '
        'strings to Tensors'):
      model_fn.call_logit_fn(invalid_logit_fn, features, 'fake_mode',
                             'fake_params', 'fake_config')
Example #2
0
  def test_invalid_logit_fn_results(self):
    """Checks that a logit_fn returning a list raises ValueError."""

    def invalid_logit_fn(features, params):
      # Returns a list of scaled features; the expected error below shows that
      # a list is not an accepted return type.
      return [
          features['f1'] * params['input_multiplier'],
          features['f2'] * params['input_multiplier']
      ]

    features = {'f1': tf.constant([[2., 3.]]), 'f2': tf.constant([[4., 5.]])}
    params = {'learning_rate': 0.001, 'input_multiplier': 2.0}
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias, which
    # was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(
        ValueError, 'logit_fn should return a Tensor or a dictionary mapping '
        'strings to Tensors'):
      model_fn.call_logit_fn(invalid_logit_fn, features, 'fake_mode', params,
                             'fake_config')
Example #3
0
  def test_simple_call_multi_logit_fn(self):
    """Checks that a dict-valued logit_fn result is returned head by head."""

    def multi_head_logit_fn(features):
      # Maps each head name to one of the input features, unchanged.
      return {u'head1': features['f1'], 'head2': features['f2']}

    first_feature = tf.constant([[2., 3.]])
    second_feature = tf.constant([[4., 5.]])
    features = {'f1': first_feature, 'f2': second_feature}
    logits_by_head = model_fn.call_logit_fn(
        multi_head_logit_fn, features, ModeKeys.TRAIN, 'fake_params',
        'fake_config')

    # Each head should evaluate to the feature value it was mapped from.
    expected_by_head = (('head1', [[2., 3.]]), ('head2', [[4., 5.]]))
    with self.cached_session():
      for head_name, expected in expected_by_head:
        self.assertAllClose(expected, self.evaluate(logits_by_head[head_name]))
Example #4
0
  def test_simple_call_logit_fn(self):
    """Checks that `mode` reaches a logit_fn that declares it as an argument."""

    def mode_dependent_logit_fn(features, mode):
      # TRAIN selects 'f1'; every other mode selects 'f2'.
      feature_key = 'f1' if mode == ModeKeys.TRAIN else 'f2'
      return features[feature_key]

    features = {'f1': tf.constant([[2., 3.]]), 'f2': tf.constant([[4., 5.]])}
    eval_logits = model_fn.call_logit_fn(
        mode_dependent_logit_fn, features, ModeKeys.EVAL, 'fake_params',
        'fake_config')

    # EVAL was requested, so the result must be the 'f2' values.
    with self.cached_session():
      self.assertAllClose([[4., 5.]], self.evaluate(eval_logits))