Пример #1
0
    def test_get_layer_policy(self):
        """Check get_layer_policy on Dense layers built three different ways.

        Covers a layer created with no dtype argument, one created with a
        `Policy` instance as its dtype, and one created with a dtype string.
        """
        lookup = get_layer_policy.get_layer_policy

        # No dtype argument: the reported policy must be named 'float32'.
        default_layer = core.Dense(4)
        self.assertEqual(lookup(default_layer).name, 'float32')

        # A Policy instance passed as dtype must come back as the very same
        # object, not merely an equal one.
        mixed_policy = policy.Policy('mixed_float16')
        mixed_layer = core.Dense(4, dtype=mixed_policy)
        self.assertIs(lookup(mixed_layer), mixed_policy)

        # A dtype string must yield a policy carrying that same name.
        float64_layer = core.Dense(4, dtype='float64')
        self.assertEqual(lookup(float64_layer).name, 'float64')
Пример #2
0
    def _test_saving(self, model, dataset, save_format, use_regularizer):
        """Save `model`, reload it, and check the MultiplyLayer variable.

        Args:
            model: Model to save. Once reloaded it must contain exactly one
                layer whose class name contains 'MultiplyLayer'.
            dataset: Passed to `fit` on the reloaded model.
            save_format: Forwarded to `model.save`.
            use_regularizer: When truthy, every expected value is lowered by
                an additional 2**-14.
        """
        step = 2**-14
        # NOTE(review): presumably one `fit` call moves the variable by
        # `step` and the regularizer adds a second `step` -- confirm against
        # the caller that builds `model`.
        drop_per_stage = 2 * step if use_regularizer else step

        def reload():
            """Load the saved model and return it with its MultiplyLayer."""
            loaded = save.load_model(save_path)
            # Unpacking a one-element list fails unless exactly one layer
            # matches, just like the single-target tuple unpacking would.
            [multiply_layer] = [
                candidate for candidate in loaded.layers
                if 'MultiplyLayer' in candidate.__class__.__name__
            ]
            return loaded, multiply_layer

        # Save and load model, asserting variable does not change
        save_path = os.path.join(self.get_temp_dir(), 'model')
        model.save(save_path, save_format=save_format)
        model, layer = reload()
        expected = 1 - drop_per_stage
        self.assertEqual(backend.eval(layer.v), expected)

        # Continue training, and assert variable is correct value
        model.fit(dataset)
        new_expected = expected - drop_per_stage
        self.assertEqual(backend.eval(layer.v), new_expected)

        # Load saved model again, and assert variable is previous value
        model, layer = reload()
        self.assertEqual(backend.eval(layer.v), expected)

        # Ensure various dtype-related aspects of the layer are correct
        restored_policy = get_layer_policy.get_layer_policy(layer)
        self.assertEqual(layer.dtype, 'float32')
        self.assertEqual(restored_policy.name, 'mixed_float16')
        self.assertEqual(layer.v.dtype, 'float32')
        self.assertEqual(layer(np.ones((2, 1))).dtype, 'float16')
Пример #3
0
 def test_mixed_policies_(self, strategy_fn):
     """Run a MultiplyLayer under the 'mixed_float16'/'mixed_bfloat16' policies.

     For each policy the layer is created and called inside the scope of the
     strategy returned by `strategy_fn()`. The layer dtype and variable dtype
     must be float32, the output dtype must match the policy's suffix, and the
     output must evaluate to 1.

     Args:
         strategy_fn: Zero-argument callable returning an object whose
             `scope()` is used as a context manager.
     """
     def assert_float32_layer_with_policy(layer, expected_policy_name):
         # Checked both before and after the layer is called.
         self.assertEqual(layer.dtype, dtypes.float32)
         attached = get_layer_policy.get_layer_policy(layer)
         self.assertEqual(attached.name, expected_policy_name)

     for dtype_name in ('float16', 'bfloat16'):
         inputs = constant_op.constant([1.])
         policy_name = 'mixed_' + dtype_name
         with strategy_fn().scope(), policy.policy_scope(policy_name):
             layer = mp_test_util.MultiplyLayer(assert_type=dtype_name)
             assert_float32_layer_with_policy(layer, policy_name)

             outputs = layer(inputs)
             self.assertEqual(layer.v.dtype, dtypes.float32)
             self.assertEqual(outputs.dtype, dtype_name)
             assert_float32_layer_with_policy(layer, policy_name)

             self.evaluate(variables.global_variables_initializer())
             self.assertEqual(self.evaluate(outputs), 1.)
Пример #4
0
 def test_error(self):
     """Check that get_layer_policy rejects an argument that is not a layer.

     Passing the integer 1 must raise a ValueError whose message matches
     'get_policy can only be called on a layer, but got: 1'.
     """
     # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
     # which was removed from unittest in Python 3.12.
     with self.assertRaisesRegex(
             ValueError,
             'get_policy can only be called on a layer, but got: 1'):
         get_layer_policy.get_layer_policy(1)