def test_call_channels_first(self):
    """Check `RevBlock.__call__` on NCHW (`channels_first`) inputs.

    GPU-only; skipped when no GPU is available. For strides 1 and 2, the
    training-mode and inference-mode outputs must have the same, expected
    shape and must not be numerically close to each other.
    """
    if not tf.test.is_gpu_available():
        self.skipTest("GPU not available")

    with tf.device("/gpu:0"):  # Default NCHW format
        input_shape = (128, 8, 8)
        inputs = tf.random_normal(shape=(16, ) + input_shape)

        # (strides, expected output shape) pairs: stride 1 keeps the spatial
        # size, stride 2 halves it.
        cases = (
            ((1, 1), (16, 128, 8, 8)),
            ((2, 2), [16, 128, 4, 4]),
        )
        for strides, expected_shape in cases:
            block = blocks.RevBlock(
                n_res=3, filters=128, strides=strides,
                input_shape=input_shape)
            train_out = block(inputs, training=True)
            eval_out = block(inputs, training=False)
            self.assertEqual(train_out.shape, eval_out.shape)
            self.assertEqual(eval_out.shape, expected_shape)
            self.assertNotAllClose(train_out, eval_out)
def test_call_channels_last(self):
    """Check `RevBlock.__call__` on NHWC (`channels_last`) inputs on CPU.

    For strides 1 and 2, the training-mode and inference-mode outputs must
    have the same, expected shape and must not be numerically close.
    """
    with tf.device("/cpu:0"):  # NHWC format
        input_shape = (8, 8, 128)
        inputs = tf.random_normal(shape=(16, ) + input_shape)

        # (strides, expected output shape) pairs.
        cases = (
            ((1, 1), (16, 8, 8, 128)),
            ((2, 2), (16, 4, 4, 128)),
        )
        for strides, expected_shape in cases:
            block = blocks.RevBlock(
                n_res=3, filters=128, strides=strides,
                input_shape=input_shape, data_format="channels_last")
            train_out = block(inputs, training=True)
            eval_out = block(inputs, training=False)
            self.assertEqual(train_out.shape, eval_out.shape)
            self.assertEqual(eval_out.shape, expected_shape)
            self.assertNotAllClose(train_out, eval_out)
def test_backward_grads_and_vars_channels_last(self):
    """Shape/type check for `RevBlock.backward_grads_and_vars` in NHWC format.

    Runs on CPU. For strides 1 and 2, random `x`, `y` and `dy` are fed to the
    block; the returned input gradient must have `x`'s shape and the returned
    gradients and variables must be Python lists.
    """
    with tf.device("/cpu:0"):  # NHWC format
        input_shape = (224, 224, 32)
        data_shape = (16, ) + input_shape
        x = tf.random_normal(shape=data_shape)

        # (strides, shape used for the random y and dy tensors)
        cases = (
            ((1, 1), data_shape),
            ((2, 2), (16, 112, 112, 32)),
        )
        for strides, out_shape in cases:
            y = tf.random_normal(shape=out_shape)
            dy = tf.random_normal(shape=out_shape)
            block = blocks.RevBlock(
                n_res=3, filters=32, strides=strides,
                input_shape=input_shape, data_format="channels_last")
            dx, grads, vars_ = block.backward_grads_and_vars(x, y, dy)
            self.assertEqual(dx.shape, x.shape)
            self.assertIsInstance(grads, list)
            self.assertIsInstance(vars_, list)
def test_backward_grads_and_vars_channels_first(self):
    """Shape/type check for `RevBlock.backward_grads_and_vars` in NCHW format.

    GPU-only; skipped when no GPU is available. For strides 1 and 2, random
    `x`, `y` and `dy` are fed to the block; the returned input gradient must
    have `x`'s shape and the returned gradients and variables must be lists.
    """
    if not tf.test.is_gpu_available():
        self.skipTest("GPU not available")

    with tf.device("/gpu:0"):  # Default NCHW format
        input_shape = (32, 224, 224)
        data_shape = (16, ) + input_shape
        x = tf.random_normal(shape=data_shape)

        # (strides, shape used for the random y and dy tensors)
        cases = (
            ((1, 1), data_shape),
            ((2, 2), (16, 32, 112, 112)),
        )
        for strides, out_shape in cases:
            y = tf.random_normal(shape=out_shape)
            dy = tf.random_normal(shape=out_shape)
            block = blocks.RevBlock(
                n_res=3, filters=32, strides=strides,
                input_shape=input_shape)
            dx, grads, vars_ = block.backward_grads_and_vars(x, y, dy)
            self.assertEqual(dx.shape, x.shape)
            self.assertIsInstance(grads, list)
            self.assertIsInstance(vars_, list)
def test_backward_grads_and_vars_channels_first_matches_tape(self):
    """Compare `RevBlock.backward_grads_and_vars` with `tf.GradientTape` (NCHW).

    GPU-only; skipped when no GPU is available. For strides 1 and 2, the
    input and weight gradients reconstructed by the block must be close to
    the gradients the tape computes for the same upstream gradient `dy`,
    both element-wise (`assertAllClose`) and via `_check_grad_angle`.

    This method used to be named `test_backward_grads_and_vars_channels_first`,
    a name another test in this class already uses. A second `def` with the
    same name replaces the first in the class body, so one of the two tests
    was silently never run; the distinct name keeps both.
    """
    if not tf.test.is_gpu_available():
        self.skipTest("GPU not available")

    with tf.device("/gpu:0"):  # Default NCHW format
        input_shape = (128, 8, 8)
        data_shape = (16, ) + input_shape

        # (strides, shape of the upstream gradient dy for those strides)
        cases = (
            ((1, 1), data_shape),
            ((2, 2), (16, 128, 4, 4)),
        )
        for strides, dy_shape in cases:
            x = tf.random_normal(shape=data_shape, dtype=tf.float64)
            dy = tf.random_normal(shape=dy_shape, dtype=tf.float64)
            block = blocks.RevBlock(
                n_res=3, filters=128, strides=strides,
                input_shape=input_shape, fused=False, dtype=tf.float64)

            with tf.GradientTape() as tape:
                tape.watch(x)
                y = block(x, training=True)

            # Compute grads from reconstruction
            dx, dw, vars_ = block.backward_grads_and_vars(
                x, y, dy, training=True)

            # Compute true grads
            grads = tape.gradient(y, [x] + vars_, output_gradients=dy)
            dx_true, dw_true = grads[0], grads[1:]

            self.assertAllClose(dx_true, dx)
            self.assertAllClose(dw_true, dw)
            self._check_grad_angle(dx_true, dx)
            self._check_grad_angle(dw_true, dw)
def test_backward_grads_with_nativepy(self):
    """Check `RevBlock.backward_grads` on split `(x1, x2)` tuple inputs.

    GPU-only; skipped when no GPU is available. The input is split in half
    along axis 1 and fed to the block as a tuple. The input gradient
    reconstructed by `backward_grads` must match the gradient computed by
    `tf.GradientTape` to within an absolute 1e-5 in every element.
    """
    if not tf.test.is_gpu_available():
        self.skipTest("GPU not available")

    with tf.device("/gpu:0"):  # Default NCHW format, as in the other GPU tests
        input_shape = (128, 8, 8)
        data_shape = (16, ) + input_shape
        x = tf.random_normal(shape=data_shape, dtype=tf.float64)
        dy = tf.random_normal(shape=data_shape, dtype=tf.float64)
        dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1)
        block = blocks.RevBlock(
            n_res=3, filters=128, strides=(1, 1), input_shape=input_shape,
            fused=False, dtype=tf.float64)

        with tf.GradientTape() as tape:
            tape.watch(x)
            x1, x2 = tf.split(x, num_or_size_splits=2, axis=1)
            y1, y2 = block((x1, x2), training=True)
            y = tf.concat((y1, y2), axis=1)

        # Compute true grads
        dx_true = tape.gradient(y, x, output_gradients=dy)

        # Compute grads from reconstruction
        (dx1, dx2), _ = block.backward_grads(
            x=(x1, x2), y=(y1, y2), dy=(dy1, dy2), training=True)
        dx = tf.concat((dx1, dx2), axis=1)

        # Reduce on-device instead of iterating the tensor in Python.
        # Use a unittest assertion rather than a bare `assert`, which is
        # stripped under `python -O`. `reduce_all(diff < thres)` fails on
        # NaN, just as the original element-wise check did.
        thres = 1e-5
        diff_abs = tf.abs(dx - dx_true)
        self.assertTrue(
            tf.reduce_all(diff_abs < thres).numpy(),
            msg="max |dx - dx_true| = %s; expected every element < %s" %
            (tf.reduce_max(diff_abs).numpy(), thres))
def _construct_intermediate_blocks(self):
    """Build the stack of `RevBlock`s that follows the initial block.

    The first RevBlock's input shape is computed from `config.input_shape`:
    it is downsampled by `config.init_stride`, and by a further factor of 2
    when `config.init_max_pool` is set. Its channel count is
    `config.init_filters`. Each later block's input shape is the previous
    block's input shape downsampled by that block's stride, with its `filters`
    as channels.

    Returns:
      A `tf.contrib.checkpoint.List` containing `config.n_rev_blocks`
      `blocks.RevBlock` instances, in order.

    Raises:
      ValueError: if any entry of `config.filters` used here is odd.
    """
    data_format = self.config.data_format

    def _downsampled_shape(shape, filters, stride):
        # Per-example shape after a stage with `filters` output channels and
        # the given spatial stride. Layout is (C, H, W) for channels_first
        # and (H, W, C) otherwise. Floor division matches the original
        # precomputation. NOTE(review): confirm it agrees with RevBlock's
        # padding for odd spatial sizes.
        if data_format == "channels_first":
            height, width = shape[1], shape[2]
            return (filters, height // stride, width // stride)
        height, width = shape[0], shape[1]
        return (height // stride, width // stride, filters)

    # Precompute input shape after initial block
    init_stride = self.config.init_stride
    if self.config.init_max_pool:
        init_stride *= 2
    input_shape = _downsampled_shape(
        self.config.input_shape, self.config.init_filters, init_stride)

    # Aggregate intermediate blocks
    block_list = tf.contrib.checkpoint.List()
    for i in range(self.config.n_rev_blocks):
        # RevBlock configurations
        n_res = self.config.n_res[i]
        filters = self.config.filters[i]
        if filters % 2 != 0:
            # The block splits its channels into two equal halves.
            raise ValueError(
                "Number of output filters must be even to ensure "
                "correct partitioning of channels")
        stride = self.config.strides[i]

        # Add block
        rev_block = blocks.RevBlock(
            n_res,
            filters,
            (stride, stride),
            input_shape,
            batch_norm_first=(i != 0),  # Only skip on first block
            data_format=data_format,
            bottleneck=self.config.bottleneck,
            fused=self.config.fused,
            dtype=self.config.dtype)
        block_list.append(rev_block)

        # Precompute input shape for the next block
        input_shape = _downsampled_shape(input_shape, filters, stride)

    return block_list