コード例 #1
0
    def test_get_layer_policy(self):
        """Check which policy `get_layer_policy` reports for Dense layers.

        Three constructions are covered: no dtype argument, a `Policy`
        instance passed as dtype, and a plain dtype string.
        """
        lookup = get_layer_policy.get_layer_policy

        # Layer built without a dtype: the policy name should be 'float32'.
        default_layer = core.Dense(4)
        self.assertEqual(lookup(default_layer).name, 'float32')

        # Layer built with a Policy object: that very object should come back.
        mixed_policy = policy.Policy('mixed_float16')
        mixed_layer = core.Dense(4, dtype=mixed_policy)
        self.assertIs(lookup(mixed_layer), mixed_policy)

        # Layer built with a dtype string: the policy name should match it.
        float64_layer = core.Dense(4, dtype='float64')
        self.assertEqual(lookup(float64_layer).name, 'float64')
コード例 #2
0
 def test_mixed_policies_(self, strategy_fn):
     """Run a MultiplyLayer under each 'mixed_*' policy and check its dtypes.

     Args:
         strategy_fn: Zero-argument callable returning the strategy whose
             scope the layer is created and called in.
     """
     strategy = strategy_fn()
     for compute_dtype in ('float16', 'bfloat16'):
         inputs = constant_op.constant([1.])
         policy_name = 'mixed_' + compute_dtype
         with strategy.scope(), policy.policy_scope(policy_name):
             layer = mp_test_util.MultiplyLayer(assert_type=compute_dtype)

             # Checks made before the layer has been called.
             self.assertEqual(layer.dtype, dtypes.float32)
             self.assertEqual(
                 get_layer_policy.get_layer_policy(layer).name,
                 policy_name)

             outputs = layer(inputs)

             # The variable is expected in float32 while the output is
             # expected in the policy's lower-precision dtype.
             self.assertEqual(layer.v.dtype, dtypes.float32)
             self.assertEqual(outputs.dtype, compute_dtype)

             # The same policy/dtype properties are re-checked after the call.
             self.assertEqual(layer.dtype_policy.name, policy_name)
             self.assertIsInstance(layer.dtype_policy, policy.Policy)
             self.assertEqual(layer.compute_dtype, compute_dtype)
             self.assertEqual(layer.dtype, dtypes.float32)
             self.assertEqual(layer.variable_dtype, dtypes.float32)
             self.assertEqual(
                 get_layer_policy.get_layer_policy(layer).name,
                 policy_name)

             self.evaluate(variables.global_variables_initializer())
             self.assertEqual(self.evaluate(outputs), 1.)
コード例 #3
0
    def _test_saving(self, model, dataset, save_format, use_regularizer):
        """Save, reload and retrain `model`, checking the MultiplyLayer state.

        Args:
            model: Model to save; it must contain exactly one layer whose
                class name contains 'MultiplyLayer'.
            dataset: Data passed to `fit` on the reloaded model.
            save_format: Forwarded to `model.save` as `save_format`.
            use_regularizer: When truthy, each expected value is lowered by
                an additional `2 * 2**-14`.
        """

        def find_multiply_layer(loaded_model):
            # Unpacking fails unless exactly one layer matches.
            matches = [
                candidate for candidate in loaded_model.layers
                if 'MultiplyLayer' in candidate.__class__.__name__
            ]
            (only_match,) = matches
            return only_match

        # Expected change of the variable per training pass.
        # NOTE(review): the starting value of 1 - step presumably reflects a
        # training pass done by the caller before saving -- confirm.
        step = 2**-14
        penalty = 2 * step if use_regularizer else 0
        expected = 1 - step - penalty
        new_expected = expected - step - penalty

        # Save and load model, asserting variable does not change
        save_path = os.path.join(self.get_temp_dir(), 'model')
        model.save(save_path, save_format=save_format)
        model = save.load_model(save_path)
        layer = find_multiply_layer(model)
        self.assertEqual(backend.eval(layer.v), expected)

        # Continue training, and assert variable is correct value
        model.fit(dataset)
        self.assertEqual(backend.eval(layer.v), new_expected)

        # Load saved model again, and assert variable is previous value
        model = save.load_model(save_path)
        layer = find_multiply_layer(model)
        self.assertEqual(backend.eval(layer.v), expected)

        # Ensure various dtype-related aspects of the layer are correct
        self.assertEqual(layer.dtype, 'float32')
        reported_policy = get_layer_policy.get_layer_policy(layer)
        self.assertEqual(reported_policy.name, 'mixed_float16')
        self.assertEqual(layer.v.dtype, 'float32')
        self.assertEqual(layer(np.ones((2, 1))).dtype, 'float16')

        # Loading a model always loads with a v2 Policy, even if saved with a
        # PolicyV1.
        self.assertEqual(type(model.dtype_policy), policy.Policy)
        expected_dtype_config = {
            'class_name': 'Policy',
            'config': {'name': 'mixed_float16'},
        }
        self.assertEqual(layer.get_config()['dtype'], expected_dtype_config)
コード例 #4
0
 def test_error(self):
     """Passing a non-layer to `get_layer_policy` should raise ValueError."""
     expected_message = (
         'get_policy can only be called on a layer, but got: 1')
     with self.assertRaisesRegex(ValueError, expected_message):
         get_layer_policy.get_layer_policy(1)