def test_get_layer_policy(self):
  """get_layer_policy reports the dtype policy each Dense layer was built with.

  Three cases are covered:
  - no dtype argument;
  - an explicit Policy object, which must come back as the same object;
  - a plain dtype string.
  """
  lookup = get_layer_policy.get_layer_policy

  # With no dtype argument, the reported policy is expected to be 'float32'.
  default_dense = core.Dense(4)
  self.assertEqual(lookup(default_dense).name, 'float32')

  # A Policy object passed as dtype is returned as-is (identity, not equality).
  mixed = policy.Policy('mixed_float16')
  mixed_dense = core.Dense(4, dtype=mixed)
  self.assertIs(lookup(mixed_dense), mixed)

  # A dtype string yields a policy carrying that same name.
  float64_dense = core.Dense(4, dtype='float64')
  self.assertEqual(lookup(float64_dense).name, 'float64')
def test_mixed_policies_(self, strategy_fn):
  """Checks MultiplyLayer under the 'mixed_float16' and 'mixed_bfloat16' policies.

  Args:
    strategy_fn: zero-argument callable returning an object with a `scope()`
      context manager. It is presumably supplied by a parameterization
      decorator outside this view (TODO confirm).
  """
  dist_strategy = strategy_fn()
  for compute_dtype in ('float16', 'bfloat16'):
    inputs = tf.constant([1.])
    mixed_name = 'mixed_' + compute_dtype
    with dist_strategy.scope(), policy.policy_scope(mixed_name):
      mul_layer = mp_test_util.MultiplyLayer(assert_type=compute_dtype)
      # Before the first call, `dtype` is float32 and the policy comes from
      # the enclosing policy_scope.
      self.assertEqual(mul_layer.dtype, tf.float32)
      self.assertEqual(
          get_layer_policy.get_layer_policy(mul_layer).name, mixed_name)

      outputs = mul_layer(inputs)
      # The variable stays float32 while the output uses the compute dtype.
      self.assertEqual(mul_layer.v.dtype, tf.float32)
      self.assertEqual(outputs.dtype, compute_dtype)

      # Every policy-related attribute agrees after the call as well.
      self.assertEqual(mul_layer.dtype_policy.name, mixed_name)
      self.assertIsInstance(mul_layer.dtype_policy, policy.Policy)
      self.assertEqual(mul_layer.compute_dtype, compute_dtype)
      self.assertEqual(mul_layer.dtype, tf.float32)
      self.assertEqual(mul_layer.variable_dtype, tf.float32)
      self.assertEqual(
          get_layer_policy.get_layer_policy(mul_layer).name, mixed_name)

      # Initialize variables, then check the numeric result of the call.
      self.evaluate(tf.compat.v1.global_variables_initializer())
      self.assertEqual(self.evaluate(outputs), 1.)
def _test_saving(self, model, dataset, save_format, use_regularizer):
  """Round-trips `model` through save/load and checks its MultiplyLayer state.

  Args:
    model: model to save. It must contain exactly one layer whose class name
      contains 'MultiplyLayer'; the 1-tuple unpacking below raises otherwise.
    dataset: data passed to `fit` on the first reloaded model.
    save_format: forwarded unchanged to `model.save`.
    use_regularizer: when true, every expected variable value carries an
      extra decrement of `2 * 2**-14`.
  """
  step = 2**-14

  def find_multiply_layer(loaded_model):
    # Unpacking into a 1-tuple enforces exactly one matching layer.
    (match,) = (candidate for candidate in loaded_model.layers
                if 'MultiplyLayer' in candidate.__class__.__name__)
    return match

  # Save and reload: the variable must keep the value it had at save time.
  checkpoint_path = os.path.join(self.get_temp_dir(), 'model')
  model.save(checkpoint_path, save_format=save_format)
  restored = save.load_model(checkpoint_path)
  restored_layer = find_multiply_layer(restored)
  saved_value = 1 - step
  if use_regularizer:
    saved_value -= 2 * step
  self.assertEqual(backend.eval(restored_layer.v), saved_value)

  # Keep training the reloaded model. After one `fit` call the variable must
  # drop by another 2**-14 (plus 2 * 2**-14 with a regularizer).
  restored.fit(dataset)
  trained_value = saved_value - step
  if use_regularizer:
    trained_value -= 2 * step
  self.assertEqual(backend.eval(restored_layer.v), trained_value)

  # A second load from the same path must see the saved value, not the
  # trained one.
  reloaded = save.load_model(checkpoint_path)
  reloaded_layer = find_multiply_layer(reloaded)
  self.assertEqual(backend.eval(reloaded_layer.v), saved_value)

  # dtype-related attributes of the reloaded layer.
  self.assertEqual(reloaded_layer.dtype, 'float32')
  self.assertEqual(
      get_layer_policy.get_layer_policy(reloaded_layer).name, 'mixed_float16')
  self.assertEqual(reloaded_layer.v.dtype, 'float32')
  self.assertEqual(reloaded_layer(np.ones((2, 1))).dtype, 'float16')

  # Loading a model always loads with a v2 Policy, even if saved with a
  # PolicyV1.
  self.assertEqual(type(reloaded.dtype_policy), policy.Policy)
  expected_dtype_config = {
      'class_name': 'Policy',
      'config': {
          'name': 'mixed_float16'
      }
  }
  self.assertEqual(reloaded_layer.get_config()['dtype'], expected_dtype_config)
def test_error(self):
  """get_layer_policy rejects a non-layer argument with a ValueError."""
  # NOTE(review): the message names `get_policy`, not `get_layer_policy`. It
  # must match the library's actual error text, so it is kept verbatim.
  expected_message = 'get_policy can only be called on a layer, but got: 1'
  self.assertRaisesRegex(ValueError, expected_message,
                         get_layer_policy.get_layer_policy, 1)