示例#1
0
  def test_invalid_logit_fn_results_dict(self):
    """call_logit_fn must reject a dict whose values are not all Tensors.

    The logit_fn below maps 'head2' to a plain Python string, so a
    ValueError naming the accepted return types is expected.
    """

    def invalid_logit_fn(features):
      return {'head1': features['f1'], 'head2': features['f2']}

    # 'f2' is deliberately a str rather than a Tensor.
    features = {'f1': constant_op.constant([[2., 3.]]), 'f2': 'some string'}
    # assertRaisesRegex: the `assertRaisesRegexp` spelling is a deprecated
    # unittest alias that was removed in Python 3.12.
    with self.assertRaisesRegex(
        ValueError, 'logit_fn should return a Tensor or a dictionary mapping '
                    'strings to Tensors'):
      logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode',
                              'fake_params', 'fake_config')
示例#2
0
    def test_invalid_logit_fn_results_dict(self):
        """call_logit_fn must reject a dict whose values are not all Tensors.

        'head2' is mapped to a plain Python string, so a ValueError naming
        the accepted return types is expected.
        """
        def invalid_logit_fn(features):
            return {'head1': features['f1'], 'head2': features['f2']}

        features = {
            'f1': constant_op.constant([[2., 3.]]),
            # Deliberately not a Tensor.
            'f2': 'some string'
        }
        # assertRaisesRegexp is a deprecated unittest alias removed in
        # Python 3.12; assertRaisesRegex is the supported name.
        with self.assertRaisesRegex(
                ValueError,
                'logit_fn should return a Tensor or a dictionary mapping '
                'strings to Tensors'):
            logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode',
                                    'fake_params', 'fake_config')
  def test_should_return_tensor(self):
    """Expects call_logit_fn to raise when logit_fn returns a dict of Tensors.

    NOTE(review): test_simple_call_multi_logit_fn in this file expects a dict
    of Tensors to be accepted, and the other invalid-result tests match
    'logit_fn should return a Tensor or a dictionary ...'. This test appears
    to target an older call_logit_fn that only accepted a single Tensor and
    reported 'model_fn should return a Tensor'. Confirm it still applies to
    the current implementation.
    """

    def invalid_logit_fn(features, params):
      return {
          'tensor1': features['f1'] * params['input_multiplier'],
          'tensor2': features['f2'] * params['input_multiplier']
      }
    features = {
        'f1': constant_op.constant([[2., 3.]]),
        'f2': constant_op.constant([[4., 5.]])
    }
    params = {'learning_rate': 0.001, 'input_multiplier': 2.0}
    # assertRaisesRegexp is a deprecated unittest alias removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, 'model_fn should return a Tensor'):
      logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode', params,
                              'fake_config')
示例#4
0
  def test_invalid_logit_fn_results(self):
    """call_logit_fn must reject a logit_fn that returns a list.

    Both list elements are Tensors (constants scaled by
    params['input_multiplier']), but the container is a list. A ValueError
    naming the accepted return types is expected.
    """

    def invalid_logit_fn(features, params):
      return [
          features['f1'] * params['input_multiplier'],
          features['f2'] * params['input_multiplier']
      ]

    features = {
        'f1': constant_op.constant([[2., 3.]]),
        'f2': constant_op.constant([[4., 5.]])
    }
    params = {'learning_rate': 0.001, 'input_multiplier': 2.0}
    # assertRaisesRegexp is a deprecated unittest alias removed in Python 3.12.
    with self.assertRaisesRegex(
        ValueError, 'logit_fn should return a Tensor or a dictionary mapping '
                    'strings to Tensors'):
      logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode', params,
                              'fake_config')
示例#5
0
    def test_invalid_logit_fn_results(self):
        """call_logit_fn must reject a logit_fn that returns a list.

        The list holds two scaled feature Tensors. A list is not one of the
        accepted return types, so a ValueError is expected.
        """
        def invalid_logit_fn(features, params):
            return [
                features['f1'] * params['input_multiplier'],
                features['f2'] * params['input_multiplier']
            ]

        features = {
            'f1': constant_op.constant([[2., 3.]]),
            'f2': constant_op.constant([[4., 5.]])
        }
        params = {'learning_rate': 0.001, 'input_multiplier': 2.0}
        # assertRaisesRegexp is a deprecated unittest alias removed in
        # Python 3.12; assertRaisesRegex is the supported name.
        with self.assertRaisesRegex(
                ValueError,
                'logit_fn should return a Tensor or a dictionary mapping '
                'strings to Tensors'):
            logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode',
                                    params, 'fake_config')
示例#6
0
    def test_simple_call_multi_logit_fn(self):
        """Checks call_logit_fn on a logit_fn returning one Tensor per head.

        The result is indexed by head name. Each entry is evaluated in a
        Session and compared with the feature values it was taken from.
        """
        features = {
            'f1': constant_op.constant([[2., 3.]]),
            'f2': constant_op.constant([[4., 5.]])
        }

        def two_head_logit_fn(features):
            # Each head maps straight to one raw feature Tensor.
            return dict(head1=features['f1'], head2=features['f2'])

        logits_by_head = logit_fns.call_logit_fn(
            two_head_logit_fn, features, model_fn.ModeKeys.TRAIN,
            'fake_params', 'fake_config')

        expectations = (('head1', [[2., 3.]]), ('head2', [[4., 5.]]))
        with session.Session():
            for head_name, expected in expectations:
                self.assertAllClose(expected,
                                    logits_by_head[head_name].eval())
 def test_simple_call_logit_fn(self):
     """Checks call_logit_fn with a logit_fn that branches on `mode`.

     The logit_fn returns 'f1' only in TRAIN mode. Calling with EVAL is
     therefore expected to yield the 'f2' values.
     """
     features = {
         'f1': constant_op.constant([[2., 3.]]),
         'f2': constant_op.constant([[4., 5.]])
     }

     def mode_aware_logit_fn(features, mode):
         return (features['f1'] if mode == model_fn.ModeKeys.TRAIN
                 else features['f2'])

     eval_logits = logit_fns.call_logit_fn(
         mode_aware_logit_fn, features, model_fn.ModeKeys.EVAL,
         'fake_params', 'fake_config')
     with session.Session():
         self.assertAllClose([[4., 5.]], eval_logits.eval())
示例#8
0
  def test_simple_call_multi_logit_fn(self):
    """Checks call_logit_fn on a logit_fn that returns one Tensor per head.

    The returned dict is indexed by head name. Each entry is evaluated
    against the values of the feature that head was built from.
    """
    f1_values = [[2., 3.]]
    f2_values = [[4., 5.]]
    features = {
        'f1': constant_op.constant(f1_values),
        'f2': constant_op.constant(f2_values)
    }

    def dummy_logit_fn(features):
      heads = {}
      heads['head1'] = features['f1']
      heads['head2'] = features['f2']
      return heads

    result = logit_fns.call_logit_fn(
        dummy_logit_fn, features, model_fn.ModeKeys.TRAIN, 'fake_params',
        'fake_config')
    with session.Session():
      self.assertAllClose(f1_values, result['head1'].eval())
      self.assertAllClose(f2_values, result['head2'].eval())