Example #1
    def test_call_channels_first(self):
        """Test `call` function with `channels_first` data format."""
        if not tf.test.is_gpu_available():
            self.skipTest("GPU not available")

        with tf.device("/gpu:0"):  # Default NCHW format
            input_shape = (128, 8, 8)
            data_shape = (16,) + input_shape
            x = tf.random_normal(shape=data_shape)

            # Stride 1
            block = blocks.RevBlock(n_res=3,
                                    filters=128,
                                    strides=(1, 1),
                                    input_shape=input_shape)
            y_tr, y_ev = block(x, training=True), block(x, training=False)
            self.assertEqual(y_tr.shape, y_ev.shape)
            self.assertEqual(y_ev.shape, (16, 128, 8, 8))
            self.assertNotAllClose(y_tr, y_ev)

            # Stride 2
            block = blocks.RevBlock(n_res=3,
                                    filters=128,
                                    strides=(2, 2),
                                    input_shape=input_shape)
            y_tr, y_ev = block(x, training=True), block(x, training=False)
            self.assertEqual(y_tr.shape, y_ev.shape)
            self.assertEqual(y_ev.shape, (16, 128, 4, 4))
            self.assertNotAllClose(y_tr, y_ev)
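
The shape assertions above are plain stride arithmetic: the channel dimension is set by `filters`, and each spatial dimension is divided by the stride. A minimal sketch of that bookkeeping, covering both the NCHW case here and the NHWC case in the next example (the helper name is ours, not part of `blocks`):

def expected_output_shape(batch, filters, h, w, stride,
                          data_format="channels_first"):
    # Channels come from `filters`; spatial dims shrink by the stride.
    if data_format == "channels_first":
        return (batch, filters, h // stride, w // stride)
    return (batch, h // stride, w // stride, filters)

assert expected_output_shape(16, 128, 8, 8, 2) == (16, 128, 4, 4)
assert expected_output_shape(16, 128, 8, 8, 2, "channels_last") == (16, 4, 4, 128)
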
Example #2
    def test_call_channels_last(self):
        """Test `call` function with `channels_last` data format."""
        with tf.device("/cpu:0"):  # NHWC format
            input_shape = (8, 8, 128)
            data_shape = (16,) + input_shape
            x = tf.random_normal(shape=data_shape)

            # Stride 1
            block = blocks.RevBlock(n_res=3,
                                    filters=128,
                                    strides=(1, 1),
                                    input_shape=input_shape,
                                    data_format="channels_last")
            y_tr, y_ev = block(x, training=True), block(x, training=False)
            self.assertEqual(y_tr.shape, y_ev.shape)
            self.assertEqual(y_ev.shape, (16, 8, 8, 128))
            self.assertNotAllClose(y_tr, y_ev)

            # Stride 2
            block = blocks.RevBlock(n_res=3,
                                    filters=128,
                                    strides=(2, 2),
                                    input_shape=input_shape,
                                    data_format="channels_last")
            y_tr, y_ev = block(x, training=True), block(x, training=False)
            self.assertEqual(y_tr.shape, y_ev.shape)
            self.assertEqual(y_ev.shape, (16, 4, 4, 128))
            self.assertNotAllClose(y_tr, y_ev)
Example #3
    def test_backward_grads_and_vars_channels_last(self):
        """Test `backward` function with `channels_last` data format."""
        with tf.device("/cpu:0"):  # NHWC format
            input_shape = (224, 224, 32)
            data_shape = (16,) + input_shape
            x = tf.random_normal(shape=data_shape)

            # Stride 1
            y = tf.random_normal(shape=data_shape)
            dy = tf.random_normal(shape=data_shape)
            block = blocks.RevBlock(n_res=3,
                                    filters=32,
                                    strides=(1, 1),
                                    input_shape=input_shape,
                                    data_format="channels_last")
            dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy)
            self.assertEqual(dy.shape, x.shape)
            self.assertTrue(isinstance(grads, list))
            self.assertTrue(isinstance(vars_, list))

            # Stride 2
            y = tf.random_normal(shape=(16, 112, 112, 32))
            dy = tf.random_normal(shape=(16, 112, 112, 32))
            block = blocks.RevBlock(n_res=3,
                                    filters=32,
                                    strides=(2, 2),
                                    input_shape=input_shape,
                                    data_format="channels_last")
            dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy)
            self.assertEqual(dy.shape, x.shape)
            self.assertTrue(isinstance(grads, list))
            self.assertTrue(isinstance(vars_, list))
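
`backward_grads_and_vars` recomputes the block's inputs from its outputs instead of keeping activations in memory, which is the point of a reversible block: given `x`, `y`, and the output gradient `dy`, it returns the input gradient together with matched lists of variable gradients and variables. A minimal sketch of how those lists would typically be consumed, assuming a TF-1.x-era eager setup like these tests; the optimizer choice is ours, not taken from this file:

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
optimizer.apply_gradients(zip(grads, vars_))  # grads[i] pairs with vars_[i]
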
Example #4
    def test_backward_grads_and_vars_channels_first(self):
        """Test `backward` function with `channels_first` data format."""
        if not tf.test.is_gpu_available():
            self.skipTest("GPU not available")

        with tf.device("/gpu:0"):  # Default NCHW format
            input_shape = (32, 224, 224)
            data_shape = (16,) + input_shape
            x = tf.random_normal(shape=data_shape)

            # Stride 1
            y = tf.random_normal(shape=data_shape)
            dy = tf.random_normal(shape=data_shape)
            block = blocks.RevBlock(n_res=3,
                                    filters=32,
                                    strides=(1, 1),
                                    input_shape=input_shape)
            dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy)
            self.assertEqual(dy.shape, x.shape)
            self.assertTrue(isinstance(grads, list))
            self.assertTrue(isinstance(vars_, list))

            # Stride 2
            y = tf.random_normal(shape=(16, 32, 112, 112))
            dy = tf.random_normal(shape=(16, 32, 112, 112))
            block = blocks.RevBlock(n_res=3,
                                    filters=32,
                                    strides=(2, 2),
                                    input_shape=input_shape)
            dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy)
            self.assertEqual(dy.shape, x.shape)
            self.assertTrue(isinstance(grads, list))
            self.assertTrue(isinstance(vars_, list))
Example #5
    def test_backward_grads_and_vars_channels_first(self):
        """Test `backward` function with `channels_first` data format."""
        if not tf.test.is_gpu_available():
            self.skipTest("GPU not available")

        with tf.device("/gpu:0"):  # Default NCHW format
            # Stride 1
            input_shape = (128, 8, 8)
            data_shape = (16,) + input_shape
            x = tf.random_normal(shape=data_shape, dtype=tf.float64)
            dy = tf.random_normal(shape=data_shape, dtype=tf.float64)
            block = blocks.RevBlock(n_res=3,
                                    filters=128,
                                    strides=(1, 1),
                                    input_shape=input_shape,
                                    fused=False,
                                    dtype=tf.float64)
            with tf.GradientTape() as tape:
                tape.watch(x)
                y = block(x, training=True)
            # Compute grads from reconstruction
            dx, dw, vars_ = block.backward_grads_and_vars(x,
                                                          y,
                                                          dy,
                                                          training=True)
            # Compute true grads
            grads = tape.gradient(y, [x] + vars_, output_gradients=dy)
            dx_true, dw_true = grads[0], grads[1:]
            self.assertAllClose(dx_true, dx)
            self.assertAllClose(dw_true, dw)
            self._check_grad_angle(dx_true, dx)
            self._check_grad_angle(dw_true, dw)

            # Stride 2
            x = tf.random_normal(shape=data_shape, dtype=tf.float64)
            dy = tf.random_normal(shape=(16, 128, 4, 4), dtype=tf.float64)
            block = blocks.RevBlock(n_res=3,
                                    filters=128,
                                    strides=(2, 2),
                                    input_shape=input_shape,
                                    fused=False,
                                    dtype=tf.float64)
            with tf.GradientTape() as tape:
                tape.watch(x)
                y = block(x, training=True)
            # Compute grads from reconstruction
            dx, dw, vars_ = block.backward_grads_and_vars(x,
                                                          y,
                                                          dy,
                                                          training=True)
            # Compute true grads
            grads = tape.gradient(y, [x] + vars_, output_gradients=dy)
            dx_true, dw_true = grads[0], grads[1:]
            self.assertAllClose(dx_true, dx)
            self.assertAllClose(dw_true, dw)
            self._check_grad_angle(dx_true, dx)
            self._check_grad_angle(dw_true, dw)
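
`_check_grad_angle` belongs to the test class but is not shown in these excerpts. A plausible sketch, assuming it verifies that the reconstructed gradient points in nearly the same direction as the autodiff gradient via cosine similarity; the body and tolerance are our guess, not the repository's code:

    def _check_grad_angle(self, g_true, g, atol=1e-2):
        # Hypothetical helper: flatten the gradients (concatenating if
        # given lists) and require their cosine similarity to be near 1.
        def flatten(t):
            parts = t if isinstance(t, (list, tuple)) else [t]
            return tf.concat([tf.reshape(v, [-1]) for v in parts], axis=0)
        v1, v2 = flatten(g_true), flatten(g)
        cos = tf.reduce_sum(v1 * v2) / (tf.norm(v1) * tf.norm(v2))
        self.assertAllClose(cos, 1.0, atol=atol)
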
Example #6
    def test_backward_grads_with_nativepy(self):
        if not tf.test.is_gpu_available():
            self.skipTest("GPU not available")

        input_shape = (128, 8, 8)
        data_shape = (16,) + input_shape
        x = tf.random_normal(shape=data_shape, dtype=tf.float64)
        dy = tf.random_normal(shape=data_shape, dtype=tf.float64)
        dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1)
        block = blocks.RevBlock(n_res=3,
                                filters=128,
                                strides=(1, 1),
                                input_shape=input_shape,
                                fused=False,
                                dtype=tf.float64)
        with tf.GradientTape() as tape:
            tape.watch(x)
            x1, x2 = tf.split(x, num_or_size_splits=2, axis=1)
            y1, y2 = block((x1, x2), training=True)
            y = tf.concat((y1, y2), axis=1)

        # Compute true grads
        dx_true = tape.gradient(y, x, output_gradients=dy)

        # Compute grads from reconstruction
        (dx1, dx2), _ = block.backward_grads(x=(x1, x2),
                                             y=(y1, y2),
                                             dy=(dy1, dy2),
                                             training=True)
        dx = tf.concat((dx1, dx2), axis=1)

        thres = 1e-5
        diff_abs = tf.reshape(abs(dx - dx_true), [-1])
        assert all(diff_abs < thres)
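
The manual threshold check at the end does the same job as the `assertAllClose` helper used in the other examples; an equivalent one-liner, with the tolerance mirroring `thres`:

        self.assertAllClose(dx, dx_true, rtol=0., atol=1e-5)
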
Example #7
    def _construct_intermediate_blocks(self):
        # Precompute input shape after initial block
        stride = self.config.init_stride
        if self.config.init_max_pool:
            stride *= 2
        if self.config.data_format == "channels_first":
            w, h = self.config.input_shape[1], self.config.input_shape[2]
            input_shape = (self.config.init_filters, w // stride, h // stride)
        else:
            w, h = self.config.input_shape[0], self.config.input_shape[1]
            input_shape = (w // stride, h // stride, self.config.init_filters)

        # Aggregate intermediate blocks
        block_list = tf.contrib.checkpoint.List()
        for i in range(self.config.n_rev_blocks):
            # RevBlock configurations
            n_res = self.config.n_res[i]
            filters = self.config.filters[i]
            if filters % 2 != 0:
                raise ValueError(
                    "Number of output filters must be even to ensure"
                    "correct partitioning of channels")
            stride = self.config.strides[i]
            strides = (stride, stride)

            # Add block
            rev_block = blocks.RevBlock(
                n_res,
                filters,
                strides,
                input_shape,
                batch_norm_first=(i != 0),  # Only skip on first block
                data_format=self.config.data_format,
                bottleneck=self.config.bottleneck,
                fused=self.config.fused,
                dtype=self.config.dtype)
            block_list.append(rev_block)

            # Precompute input shape for the next block
            if self.config.data_format == "channels_first":
                w, h = input_shape[1], input_shape[2]
                input_shape = (filters, w // stride, h // stride)
            else:
                w, h = input_shape[0], input_shape[1]
                input_shape = (w // stride, h // stride, filters)

        return block_list
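
The shape bookkeeping in this method is plain integer division of the spatial dimensions by each block's stride. A standalone trace of that loop for `channels_first`, using hypothetical config values for a 32x32 input (the real values come from `self.config`):

init_filters, init_stride, init_max_pool = 32, 1, True
filters_list, strides_list = [32, 64, 112], [1, 2, 2]

stride = init_stride * (2 if init_max_pool else 1)
shape = (init_filters, 32 // stride, 32 // stride)  # after the initial block
for filters, s in zip(filters_list, strides_list):
    shape = (filters, shape[1] // s, shape[2] // s)
    print(shape)  # (32, 16, 16), then (64, 8, 8), then (112, 4, 4)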