def test_manual_gradients_correctness(self, distribution):
  """Verifies manually-computed reversible gradients match autograd.

  Builds two `ReversibleLayer` instances with identical weights — one
  computing gradients manually, one relying on autograd — runs both under
  the given distribution strategy, and checks that per-variable gradients
  and the batch-norm moving statistics agree.

  Args:
    distribution: a `tf.distribute.Strategy` supplied by the test harness.
  """
  bsz, h, w, c = 8, 32, 32, 32
  filters = c
  strides = 1
  input_tensor = tf.random.uniform(shape=[bsz, h, w, c * 4])  # bottleneck
  with distribution.scope():
    f_manual = nn_blocks.BottleneckResidualInner(
        filters=filters // 2, strides=strides, batch_norm_first=False)
    g_manual = nn_blocks.BottleneckResidualInner(
        filters=filters // 2, strides=1, batch_norm_first=False)
    manual_grad_layer = nn_blocks.ReversibleLayer(f_manual, g_manual)
    manual_grad_layer(input_tensor, training=False)  # init weights

    f_auto = nn_blocks.BottleneckResidualInner(
        filters=filters // 2, strides=strides, batch_norm_first=False)
    g_auto = nn_blocks.BottleneckResidualInner(
        filters=filters // 2, strides=1, batch_norm_first=False)
    auto_grad_layer = nn_blocks.ReversibleLayer(
        f_auto, g_auto, manual_grads=False)
    # Pass training=False explicitly, matching the manual-gradient layer's
    # init call above, so neither layer updates BN statistics during
    # weight initialization.
    auto_grad_layer(input_tensor, training=False)  # init weights

    # Clone all weights (tf.keras.layers.Layer has no .clone())
    auto_grad_layer._f.set_weights(manual_grad_layer._f.get_weights())
    auto_grad_layer._g.set_weights(manual_grad_layer._g.get_weights())

  @tf.function
  def manual_fn():
    with tf.GradientTape() as tape:
      output = manual_grad_layer(input_tensor, training=True)
    grads = tape.gradient(output, manual_grad_layer.trainable_variables)
    return grads

  @tf.function
  def auto_fn():
    with tf.GradientTape() as tape:
      output = auto_grad_layer(input_tensor, training=True)
    grads = tape.gradient(output, auto_grad_layer.trainable_variables)
    return grads

  manual_grads = distribution.run(manual_fn)
  auto_grads = distribution.run(auto_fn)

  # Assert gradients calculated manually are close to that from autograd.
  for manual_grad, auto_grad in zip(manual_grads, auto_grads):
    self.assertAllClose(
        distribution.experimental_local_results(manual_grad),
        distribution.experimental_local_results(auto_grad),
        atol=5e-3,
        rtol=5e-3)

  # Verify that BN moving mean and variance is correct.
  for manual_var, auto_var in zip(
      manual_grad_layer.non_trainable_variables,
      auto_grad_layer.non_trainable_variables):
    self.assertAllClose(manual_var, auto_var)
def test_shape(self, distribution):
  """Checks BottleneckResidualInner output shape under a strategy scope."""
  batch, height, width, channels = 8, 32, 32, 32
  num_filters = 64
  stride = 2
  inputs = tf.random.uniform(shape=[batch, height, width, channels])
  with distribution.scope():
    block = nn_blocks.BottleneckResidualInner(num_filters, stride)
  features = block(inputs)
  # Bottleneck blocks expand channels 4x while downsampling spatially.
  expected = [batch, height // stride, width // stride, num_filters * 4]
  self.assertEqual(expected, features.shape.as_list())