def test_manual_gradients_correctness(self, distribution):
  bsz, h, w, c = 8, 32, 32, 32
  filters = c
  strides = 1
  input_tensor = tf.random.uniform(shape=[bsz, h, w, c * 4])  # bottleneck

  with distribution.scope():
    f_manual = nn_blocks.BottleneckResidualInner(
        filters=filters // 2, strides=strides, batch_norm_first=False)
    g_manual = nn_blocks.BottleneckResidualInner(
        filters=filters // 2, strides=1, batch_norm_first=False)
    manual_grad_layer = nn_blocks.ReversibleLayer(f_manual, g_manual)
    manual_grad_layer(input_tensor, training=False)  # init weights

    f_auto = nn_blocks.BottleneckResidualInner(
        filters=filters // 2, strides=strides, batch_norm_first=False)
    g_auto = nn_blocks.BottleneckResidualInner(
        filters=filters // 2, strides=1, batch_norm_first=False)
    auto_grad_layer = nn_blocks.ReversibleLayer(
        f_auto, g_auto, manual_grads=False)
    auto_grad_layer(input_tensor)  # init weights
    # Clone all weights (tf.keras.layers.Layer has no .clone()).
    auto_grad_layer._f.set_weights(manual_grad_layer._f.get_weights())
    auto_grad_layer._g.set_weights(manual_grad_layer._g.get_weights())

  @tf.function
  def manual_fn():
    with tf.GradientTape() as tape:
      output = manual_grad_layer(input_tensor, training=True)
    grads = tape.gradient(output, manual_grad_layer.trainable_variables)
    return grads

  @tf.function
  def auto_fn():
    with tf.GradientTape() as tape:
      output = auto_grad_layer(input_tensor, training=True)
    grads = tape.gradient(output, auto_grad_layer.trainable_variables)
    return grads

  manual_grads = distribution.run(manual_fn)
  auto_grads = distribution.run(auto_fn)

  # Assert that the manually computed gradients are close to those from
  # autograd.
  for manual_grad, auto_grad in zip(manual_grads, auto_grads):
    self.assertAllClose(
        distribution.experimental_local_results(manual_grad),
        distribution.experimental_local_results(auto_grad),
        atol=5e-3,
        rtol=5e-3)

  # Verify that the BN moving mean and variance are correct.
  for manual_var, auto_var in zip(manual_grad_layer.non_trainable_variables,
                                  auto_grad_layer.non_trainable_variables):
    self.assertAllClose(manual_var, auto_var)
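
# The manual-gradient path above can only match autograd because the standard
# RevNet additive coupling reconstructs its inputs exactly from its outputs,
# so intermediate activations need not be kept for the backward pass. The
# helper below is an illustrative sketch of that identity only; it is not part
# of the test suite, and `f_fn`/`g_fn` are arbitrary stand-ins rather than the
# real nn_blocks inner blocks. The coupling assumed here is
# y1 = x1 + f(x2), y2 = x2 + g(y1), hence x2 = y2 - g(y1), x1 = y1 - f(x2).
def _reversible_roundtrip_sketch():
  f_fn = lambda t: tf.nn.relu(t + 1.0)
  g_fn = lambda t: 0.5 * t
  x1 = tf.random.uniform(shape=[2, 4])
  x2 = tf.random.uniform(shape=[2, 4])
  # Forward coupling.
  y1 = x1 + f_fn(x2)
  y2 = x2 + g_fn(y1)
  # Inverse: recover the inputs from the outputs without storing them.
  x2_rec = y2 - g_fn(y1)
  x1_rec = y1 - f_fn(x2_rec)
  tf.debugging.assert_near(x1, x1_rec)
  tf.debugging.assert_near(x2, x2_rec)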
def test_downsampling_non_reversible_step(self, distribution):
  bsz, h, w, c = 8, 32, 32, 32
  filters = 64
  strides = 2
  input_tensor = tf.random.uniform(shape=[bsz, h, w, c])

  with distribution.scope():
    f = nn_blocks.ResidualInner(
        filters=filters // 2, strides=strides, batch_norm_first=True)
    g = nn_blocks.ResidualInner(
        filters=filters // 2, strides=1, batch_norm_first=True)
    test_layer = nn_blocks.ReversibleLayer(f, g)
    test_layer.build(input_tensor.shape)
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

  @tf.function
  def step_fn():
    with tf.GradientTape() as tape:
      output = test_layer(input_tensor, training=True)
    grads = tape.gradient(output, test_layer.trainable_variables)
    # Test that applying gradients with the optimizer works.
    optimizer.apply_gradients(zip(grads, test_layer.trainable_variables))
    return output

  replica_output = distribution.run(step_fn)
  outputs = distribution.experimental_local_results(replica_output)

  # Assert forward pass shape.
  expected_output_shape = [bsz, h // strides, w // strides, filters]
  for output in outputs:
    self.assertEqual(expected_output_shape, output.shape.as_list())
def test_reversible_step(self, distribution):
  # Reversible layers satisfy: (a) strides = 1, (b) in_filter = out_filter.
  bsz, h, w, c = 8, 32, 32, 32
  filters = c
  strides = 1
  input_tensor = tf.random.uniform(shape=[bsz, h, w, c])

  with distribution.scope():
    f = nn_blocks.ResidualInner(
        filters=filters // 2, strides=strides, batch_norm_first=False)
    g = nn_blocks.ResidualInner(
        filters=filters // 2, strides=1, batch_norm_first=False)
    test_layer = nn_blocks.ReversibleLayer(f, g)
    test_layer(input_tensor, training=False)  # init weights
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

  @tf.function
  def step_fn():
    with tf.GradientTape() as tape:
      output = test_layer(input_tensor, training=True)
    grads = tape.gradient(output, test_layer.trainable_variables)
    # Test that applying gradients with the optimizer works.
    optimizer.apply_gradients(zip(grads, test_layer.trainable_variables))
    return output

  @tf.function
  def fwd():
    test_layer(input_tensor)

  distribution.run(fwd)  # Initialize variables.
  prev_variables = tf.identity_n(test_layer.trainable_variables)
  replica_output = distribution.run(step_fn)
  outputs = distribution.experimental_local_results(replica_output)

  # Assert that variable values have changed after the update.
  for v0, v1 in zip(prev_variables, test_layer.trainable_variables):
    self.assertNotAllEqual(v0, v1)

  # Assert forward pass shape.
  expected_output_shape = [bsz, h // strides, w // strides, filters]
  for output in outputs:
    self.assertEqual(expected_output_shape, output.shape.as_list())
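
# Why the reversible step above requires strides == 1 and
# in_filter == out_filter: the coupling splits the input in half along the
# channel axis, and each half is only added to, so neither the spatial size
# nor the channel count may change. The sketch below is illustrative only and
# uses plain Conv2D layers as stand-ins for the ResidualInner blocks.
def _reversible_shape_sketch():
  c = 32
  x = tf.random.uniform(shape=[8, 32, 32, c])
  # Each inner block maps c // 2 channels to c // 2 channels at stride 1, so
  # the concatenated output keeps the input's shape, which is what makes the
  # inverse computation possible.
  f = tf.keras.layers.Conv2D(filters=c // 2, kernel_size=3, padding='same')
  g = tf.keras.layers.Conv2D(filters=c // 2, kernel_size=3, padding='same')
  x1, x2 = tf.split(x, num_or_size_splits=2, axis=-1)
  y1 = x1 + f(x2)
  y2 = x2 + g(y1)
  y = tf.concat([y1, y2], axis=-1)
  assert y.shape == x.shape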
def _block_group(self,
                 inputs: tf.Tensor,
                 filters: int,
                 strides: int,
                 inner_block_fn: Callable[..., tf.keras.layers.Layer],
                 block_repeats: int,
                 batch_norm_first: bool,
                 name: str = 'revblock_group') -> tf.Tensor:
  """Creates one group of reversible blocks for the RevNet model.

  Args:
    inputs: A `tf.Tensor` of size `[batch, height, width, channels]`.
    filters: An `int` number of filters for the first convolution of the
      layer.
    strides: An `int` stride to use for the first convolution of the layer.
      If greater than 1, this block group will downsample the input.
    inner_block_fn: Either `nn_blocks.ResidualInner` or
      `nn_blocks.BottleneckResidualInner`.
    block_repeats: An `int` number of blocks contained in this block group.
    batch_norm_first: A `bool` that specifies whether to apply batch
      normalization and the activation layer before feeding into the
      convolution layers.
    name: A `str` name for the block group.

  Returns:
    The output `tf.Tensor` of the block group.
  """
  x = inputs
  for i in range(block_repeats):
    is_first_block = i == 0
    # Only the first residual layer in the block group gets downsampled.
    curr_strides = strides if is_first_block else 1
    f = inner_block_fn(
        filters=filters // 2,
        strides=curr_strides,
        batch_norm_first=batch_norm_first and is_first_block,
        kernel_regularizer=self._kernel_regularizer)
    g = inner_block_fn(
        filters=filters // 2,
        strides=1,
        batch_norm_first=batch_norm_first and is_first_block,
        kernel_regularizer=self._kernel_regularizer)
    x = nn_blocks.ReversibleLayer(f, g)(x)

  return tf.identity(x, name=name)
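
# A hypothetical sketch (not from the model code) of how `_block_group` might
# be chained to build successive RevNet stages: each stage in `specs` picks
# its own filter count and block count, and only the first block of a stage
# downsamples via `strides`. The `specs` values, the `model` receiver, and the
# helper name are assumptions for illustration only.
def _build_revnet_stages_sketch(model, x: tf.Tensor) -> tf.Tensor:
  specs = [(128, 1, 3), (256, 2, 3), (512, 2, 3)]  # (filters, strides, repeats)
  for i, (filters, strides, repeats) in enumerate(specs):
    x = model._block_group(
        inputs=x,
        filters=filters,
        strides=strides,
        inner_block_fn=nn_blocks.ResidualInner,
        block_repeats=repeats,
        batch_norm_first=True,
        name=f'revblock_group_{i}')
  return x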