Example #1
0
  def test_get_layer_policy(self):
    """get_layer_policy reports the dtype policy each Dense layer was given."""
    # Without an explicit dtype the test expects a 'float32' policy
    # (assumes the global policy is left at its default here).
    default_layer = core.Dense(4)
    self.assertEqual(
        get_layer_policy.get_layer_policy(default_layer).name, 'float32')

    # A Policy object passed as `dtype` must come back as the very same object.
    mixed = policy.Policy('mixed_float16')
    mixed_layer = core.Dense(4, dtype=mixed)
    self.assertIs(get_layer_policy.get_layer_policy(mixed_layer), mixed)

    # A plain dtype string yields a policy carrying that name.
    float64_layer = core.Dense(4, dtype='float64')
    self.assertEqual(
        get_layer_policy.get_layer_policy(float64_layer).name, 'float64')
Example #2
0
 def test_mixed_policies_(self, strategy_fn):
   """Checks a MultiplyLayer under each mixed policy inside `strategy_fn()`.

   For 'mixed_float16' and 'mixed_bfloat16' the layer's variables must stay
   float32 while its compute dtype (and output dtype) is the low-precision
   type. Policy-related attributes are checked both before and after the
   layer's first call.

   Args:
     strategy_fn: Zero-argument callable returning an object with a
       `scope()` context manager (presumably a tf.distribute strategy).
   """
   strategy = strategy_fn()
   for compute_dtype in ('float16', 'bfloat16'):
     inputs = tf.constant([1.])
     mixed_name = 'mixed_' + compute_dtype
     with strategy.scope(), policy.policy_scope(mixed_name):
       mul_layer = mp_test_util.MultiplyLayer(assert_type=compute_dtype)

       # Before the first call.
       self.assertEqual(mul_layer.dtype, tf.float32)
       self.assertEqual(
           get_layer_policy.get_layer_policy(mul_layer).name, mixed_name)

       outputs = mul_layer(inputs)

       # After the first call: variable stays float32, output is low precision.
       self.assertEqual(mul_layer.v.dtype, tf.float32)
       self.assertEqual(outputs.dtype, compute_dtype)
       self.assertEqual(mul_layer.dtype_policy.name, mixed_name)
       self.assertIsInstance(mul_layer.dtype_policy, policy.Policy)
       self.assertEqual(mul_layer.compute_dtype, compute_dtype)
       self.assertEqual(mul_layer.dtype, tf.float32)
       self.assertEqual(mul_layer.variable_dtype, tf.float32)
       self.assertEqual(
           get_layer_policy.get_layer_policy(mul_layer).name, mixed_name)

       self.evaluate(tf.compat.v1.global_variables_initializer())
       # Input is 1.; the expected output 1. presumably relies on the layer's
       # variable initializing to 1 -- confirm in mp_test_util.
       self.assertEqual(self.evaluate(outputs), 1.)
Example #3
0
    def _test_saving(self, model, dataset, save_format, use_regularizer):
        """Round-trips `model` through save/load and checks mixed-precision state.

        Saves `model` to a temp dir and reloads it. Checks the MultiplyLayer
        variable value, runs one more `fit` on the reloaded model and checks
        the updated value. Then reloads the saved copy and confirms that it
        still holds the pre-`fit` value and uses a 'mixed_float16' `Policy`.

        Args:
          model: Model containing exactly one layer whose class name contains
            'MultiplyLayer'.
          dataset: Data passed to `model.fit`.
          save_format: Forwarded to `model.save`.
          use_regularizer: If true, the expected per-`fit` decrease of the
            variable is 3 * 2**-14 instead of 2**-14 (presumably the extra
            2 * 2**-14 comes from a regularizer -- confirm against callers).
        """

        def find_multiply_layer(m):
            # Unpacking into a 1-tuple raises ValueError unless exactly one
            # layer matches.
            (match,) = (candidate for candidate in m.layers
                        if 'MultiplyLayer' in candidate.__class__.__name__)
            return match

        # Expected decrease of the variable per `fit`; all terms are exact
        # binary fractions, so the float arithmetic below is exact.
        decrement = 2**-14
        if use_regularizer:
            decrement += 2 * 2**-14

        # Save and load model, asserting variable does not change
        save_path = os.path.join(self.get_temp_dir(), 'model')
        model.save(save_path, save_format=save_format)
        model = save.load_model(save_path)
        layer = find_multiply_layer(model)
        expected = 1 - decrement
        self.assertEqual(backend.eval(layer.v), expected)

        # Continue training, and assert variable is correct value
        model.fit(dataset)
        new_expected = expected - decrement
        self.assertEqual(backend.eval(layer.v), new_expected)

        # Load saved model again, and assert variable is previous value
        model = save.load_model(save_path)
        layer = find_multiply_layer(model)
        self.assertEqual(backend.eval(layer.v), expected)

        # Ensure various dtype-related aspects of the layer are correct
        self.assertEqual(layer.dtype, 'float32')
        self.assertEqual(
            get_layer_policy.get_layer_policy(layer).name, 'mixed_float16')
        self.assertEqual(layer.v.dtype, 'float32')
        self.assertEqual(layer(np.ones((2, 1))).dtype, 'float16')

        # Loading a model always loads with a v2 Policy, even if saved with a
        # PolicyV1. Exact type identity (not isinstance) is intended here.
        self.assertIs(type(model.dtype_policy), policy.Policy)
        self.assertEqual(layer.get_config()['dtype'], {
            'class_name': 'Policy',
            'config': {
                'name': 'mixed_float16'
            }
        })
Example #4
0
 def test_error(self):
     """get_layer_policy rejects a non-layer argument with a ValueError."""
     not_a_layer = 1
     expected_message = 'get_policy can only be called on a layer, but got: 1'
     with self.assertRaisesRegex(ValueError, expected_message):
         get_layer_policy.get_layer_policy(not_a_layer)