def test_downsampling_non_reversible_step(self, distribution):
  bsz, h, w, c = 8, 32, 32, 32
  filters = 64
  strides = 2

  input_tensor = tf.random.uniform(shape=[bsz, h, w, c])
  with distribution.scope():
    f = nn_blocks.ResidualInner(
        filters=filters // 2, strides=strides, batch_norm_first=True)
    g = nn_blocks.ResidualInner(
        filters=filters // 2, strides=1, batch_norm_first=True)
    test_layer = nn_blocks.ReversibleLayer(f, g)
    test_layer.build(input_tensor.shape)
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

  @tf.function
  def step_fn():
    with tf.GradientTape() as tape:
      output = test_layer(input_tensor, training=True)
    grads = tape.gradient(output, test_layer.trainable_variables)
    # Check that applying the gradients through the optimizer works.
    optimizer.apply_gradients(zip(grads, test_layer.trainable_variables))
    return output

  replica_output = distribution.run(step_fn)
  outputs = distribution.experimental_local_results(replica_output)

  # Assert forward pass output shape.
  expected_output_shape = [bsz, h // strides, w // strides, filters]
  for output in outputs:
    self.assertEqual(expected_output_shape, output.shape.as_list())
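# Reference for the shape assertion above: with 'same' padding, a stride-s
# convolution maps a spatial dimension h to ceil(h / s). The helper below is a
# hypothetical illustration of that arithmetic, not part of nn_blocks.
def _expected_downsampled_shape(bsz, h, w, filters, strides):
  """Returns the output shape of a 'same'-padded, stride-s conv block."""
  # -(-h // s) is integer ceiling division.
  return [bsz, -(-h // strides), -(-w // strides), filters]

# e.g. _expected_downsampled_shape(8, 32, 32, 64, 2) == [8, 16, 16, 64]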
def test_reversible_step(self, distribution):
  # Reversible layers satisfy: (a) strides == 1, (b) in_filter == out_filter.
  bsz, h, w, c = 8, 32, 32, 32
  filters = c
  strides = 1

  input_tensor = tf.random.uniform(shape=[bsz, h, w, c])
  with distribution.scope():
    f = nn_blocks.ResidualInner(
        filters=filters // 2, strides=strides, batch_norm_first=False)
    g = nn_blocks.ResidualInner(
        filters=filters // 2, strides=1, batch_norm_first=False)
    test_layer = nn_blocks.ReversibleLayer(f, g)
    test_layer(input_tensor, training=False)  # Build layer and init weights.
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

  @tf.function
  def step_fn():
    with tf.GradientTape() as tape:
      output = test_layer(input_tensor, training=True)
    grads = tape.gradient(output, test_layer.trainable_variables)
    # Check that applying the gradients through the optimizer works.
    optimizer.apply_gradients(zip(grads, test_layer.trainable_variables))
    return output

  @tf.function
  def fwd():
    test_layer(input_tensor)

  distribution.run(fwd)  # Initialize variables under the strategy.
  prev_variables = tf.identity_n(test_layer.trainable_variables)
  replica_output = distribution.run(step_fn)
  outputs = distribution.experimental_local_results(replica_output)

  # Assert that the training step updated the variables.
  for v0, v1 in zip(prev_variables, test_layer.trainable_variables):
    self.assertNotAllEqual(v0, v1)

  # Assert forward pass output shape.
  expected_output_shape = [bsz, h // strides, w // strides, filters]
  for output in outputs:
    self.assertEqual(expected_output_shape, output.shape.as_list())
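# Background for the constraints noted in test_reversible_step: a RevNet-style
# reversible block splits the channels, couples the halves with f and g, and
# can reconstruct its input exactly from its output, so activations need not
# be stored for backprop. This only holds when strides == 1 and the channel
# count is preserved. A minimal self-contained sketch of the coupling follows;
# it illustrates the idea and is not the nn_blocks.ReversibleLayer internals.
def _reversible_coupling_demo():
  x = tf.random.uniform([8, 32, 32, 32])
  x1, x2 = tf.split(x, 2, axis=-1)
  # Stand-ins for f and g; any shape-preserving maps work here.
  f = lambda t: tf.nn.relu(t)
  g = lambda t: 0.5 * t

  # Forward coupling: y1 = x1 + f(x2), y2 = x2 + g(y1).
  y1 = x1 + f(x2)
  y2 = x2 + g(y1)

  # Inverse: recover the input from the output alone.
  x2_rec = y2 - g(y1)
  x1_rec = y1 - f(x2_rec)
  x_rec = tf.concat([x1_rec, x2_rec], axis=-1)
  tf.debugging.assert_near(x, x_rec)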
def test_shape(self, distribution):
  bsz, h, w, c = 8, 32, 32, 32
  filters = 64
  strides = 2

  input_tensor = tf.random.uniform(shape=[bsz, h, w, c])
  with distribution.scope():
    test_layer = nn_blocks.ResidualInner(filters, strides)

  output = test_layer(input_tensor)
  expected_output_shape = [bsz, h // strides, w // strides, filters]
  self.assertEqual(expected_output_shape, output.shape.as_list())
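# These test methods all accept a `distribution` argument. A common way to
# supply it is TF's strategy-combinations test utilities, with each test
# decorated as @combinations.generate(distribution_strategy_combinations()).
# The sketch below shows that assumed wiring; the helper name and strategy
# list are illustrative, not taken from this file, and the imports would sit
# at the top of a real test module.
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations

def distribution_strategy_combinations():
  """Strategies to parameterize the tests above over."""
  return combinations.combine(
      distribution=[
          strategy_combinations.default_strategy,
          strategy_combinations.one_device_strategy_gpu,
      ])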