def test_layer(self, layer_cls, init_args, variable_settings, input_shape,
               input_dtype=np.float32):
    """Compare a layer built with replicated layouts against a plain one.

    Args:
        layer_cls: Layer class to instantiate.
        init_args: Keyword arguments for `layer_cls`; copied, not mutated.
        variable_settings: Mapping from a variable's attribute name to the
            rank passed to `dtensor.Layout.replicated` for that variable.
        input_shape: Shape of the random input fed to both layers.
        input_dtype: Dtype the random input is cast to.
    """
    # Every listed variable gets a replicated `<name>_layout` argument.
    layout_aware_args = init_args.copy()
    layout_aware_args.update({
        name + '_layout': dtensor.Layout.replicated(self.mesh, rank)
        for name, rank in variable_settings.items()
    })
    layout_aware_layer = layer_cls(**layout_aware_args)

    inputs = np.random.randn(*input_shape).astype(input_dtype)
    input_layout = dtensor.Layout.replicated(self.mesh, len(input_shape))
    d_inputs = dtensor.copy_to_mesh(inputs, input_layout)
    d_output = layout_aware_layer(d_inputs)

    # With layouts supplied, each variable must be a DVariable.
    for name in variable_settings:
        self.assertIsInstance(
            getattr(layout_aware_layer, name), dtensor.DVariable)

    # The output is expected to be fully replicated at its own rank.
    replicated_output_layout = dtensor.Layout.replicated(
        self.mesh, d_output.shape.rank)
    self.assertEqual(
        dtensor.fetch_layout(d_output), replicated_output_layout)

    # Same layer without layouts must yield the same values.
    # NOTE(review): this only holds if the layout-aware layer above was
    # built under the same seed -- presumably set in setUp; confirm.
    tf_utils.set_random_seed(1337)
    plain_layer = layer_cls(**init_args)
    plain_output = plain_layer(inputs)
    self.assertAllClose(d_output, plain_output)

    # Without layouts, none of the variables may be a DVariable.
    for name in variable_settings:
        self.assertNotIsInstance(
            getattr(plain_layer, name), dtensor.DVariable)
def test_random_value_initializer(self, initializer_cls, init_args):
    """Check a random initializer called with a DTensor layout.

    Verifies that the call fails before a global seed is set, that the
    result has the requested shape and layout once a seed is set, and
    that re-seeding with the same value reproduces the same result.

    Args:
        initializer_cls: Initializer class to instantiate.
        init_args: Keyword arguments for `initializer_cls`.
    """
    unsharded_2d = dtensor.Layout(
        [dtensor.UNSHARDED, dtensor.UNSHARDED], self.mesh)
    target_shape = (4, 4)
    initializer = initializer_cls(**init_args)

    # Without the keras global seed the call is expected to raise.
    with self.assertRaisesRegex(ValueError, "set the global seed"):
        initializer(shape=target_shape, layout=unsharded_2d)

    try:
        tf_utils.set_random_seed(1337)
        first_value = initializer(shape=target_shape, layout=unsharded_2d)
        self.assertEqual(first_value.shape, target_shape)
        self.assertEqual(unsharded_2d, dtensor.fetch_layout(first_value))

        # Re-seeding identically and building a fresh initializer must
        # reproduce the first result.
        tf_utils.set_random_seed(1337)
        fresh_initializer = initializer_cls(**init_args)
        second_value = fresh_initializer(
            shape=target_shape, layout=unsharded_2d)
        self.assertAllClose(first_value, second_value)
    finally:
        # Clear the keras global generator so tests that check for its
        # absence are not affected by this one.
        backend._SEED_GENERATOR.generator = None
def test_static_value_initializer(self, initializer_cls, init_args):
    """Check a static initializer with and without a DTensor layout.

    The value produced with a layout must have the requested shape and
    layout, and must match the value produced without a layout.

    Args:
        initializer_cls: Initializer class to instantiate.
        init_args: Keyword arguments for `initializer_cls`.
    """
    unsharded_2d = dtensor.Layout(
        [dtensor.UNSHARDED, dtensor.UNSHARDED], self.mesh)
    target_shape = (4, 4)
    initializer = initializer_cls(**init_args)

    layout_value = initializer(shape=target_shape, layout=unsharded_2d)
    plain_value = initializer(shape=target_shape)

    self.assertEqual(layout_value.shape, target_shape)
    self.assertEqual(unsharded_2d, dtensor.fetch_layout(layout_value))
    self.assertAllClose(layout_value, plain_value)
def test_conv2d_layer_with_layout(self):
    """Check Conv2D built with kernel/bias layouts against a plain Conv2D.

    Uses `self.layout_4d` and `self.layout_1d`, which are defined outside
    this method (presumably in the test fixture).
    """
    layout_conv = layers.Conv2D(
        32,
        kernel_size=(3, 3),
        kernel_layout=self.layout_4d,
        bias_layout=self.layout_1d,
    )

    raw_inputs = np.random.randint(low=0, high=4, size=[10, 28, 28, 1])
    inputs = tf.constant(raw_inputs, dtype=tf.float32)
    d_inputs = dtensor.copy_to_mesh(inputs, self.layout_4d)
    d_output = layout_conv(d_inputs)

    # With layouts supplied, both weights must be DVariables.
    for weight in (layout_conv.kernel, layout_conv.bias):
        self.assertIsInstance(weight, dtensor.DVariable)
    self.assertEqual(dtensor.fetch_layout(d_output), self.layout_4d)

    # A Conv2D without layouts must yield the same values.
    # NOTE(review): relies on the first layer having been built under the
    # same seed -- presumably set in setUp; confirm.
    tf_utils.set_random_seed(1337)
    plain_conv = layers.Conv2D(32, kernel_size=(3, 3))
    plain_output = plain_conv(inputs)
    self.assertAllClose(d_output, plain_output)
def test_dense_layer_with_layout(self):
    """Check Dense built with kernel/bias layouts against a plain Dense.

    Uses `self.layout_2d` and `self.layout_1d`, which are defined outside
    this method (presumably in the test fixture).
    """
    layout_dense = layers.Dense(
        10, kernel_layout=self.layout_2d, bias_layout=self.layout_1d)

    raw_inputs = np.random.randint(low=0, high=4, size=[32, 8])
    inputs = tf.constant(raw_inputs, dtype=tf.float32)
    replicated_2d = dtensor.Layout.replicated(self.mesh, rank=2)
    d_inputs = dtensor.copy_to_mesh(inputs, replicated_2d)
    d_output = layout_dense(d_inputs)

    # With layouts supplied, both weights must be DVariables.
    for weight in (layout_dense.kernel, layout_dense.bias):
        self.assertIsInstance(weight, dtensor.DVariable)

    # The output is expected to be unsharded on both dimensions.
    unsharded_2d = dtensor.Layout(
        [dtensor.UNSHARDED, dtensor.UNSHARDED], self.mesh)
    self.assertEqual(dtensor.fetch_layout(d_output), unsharded_2d)

    # A Dense without layouts must yield the same values.
    # NOTE(review): relies on the first layer having been built under the
    # same seed -- presumably set in setUp; confirm.
    tf_utils.set_random_seed(1337)
    plain_dense = layers.Dense(10)
    plain_output = plain_dense(inputs)
    self.assertAllClose(d_output, plain_output)