Ejemplo n.º 1
0
    def test_InverseFlow(self):
        """Check that `InverseFlow` exposes a wrapped flow in reverse.

        A compiled `_MyFlow` is wrapped, and the test verifies that:

        * `original_flow` and `invert()` both give back the wrapped object;
        * the wrapper reports 2 event dims for `x`, 1 for `y`, and is
          explicitly invertible;
        * `flow_standard_check` passes with output ``(x - 1) * 0.5`` and a
          log-determinant of ``-4 * log(2)`` per batch element;
        * wrapping a `Linear` layer, or a `_MyFlow` whose
          `explicitly_invertible` attribute is set to False, raises
          `TypeError`.
        """
        not_invertible_msg = '`flow` must be an explicitly invertible flow'

        wrapped = tk.layers.jit_compile(_MyFlow())
        inverse_flow = InverseFlow(wrapped)

        # the wrapper must keep the very same object, not a copy of it
        self.assertIs(inverse_flow.original_flow, wrapped)
        self.assertIs(inverse_flow.invert(), wrapped)

        # the remaining checks run against the compiled wrapper
        inverse_flow = tk.layers.jit_compile(inverse_flow)
        self.assertEqual(inverse_flow.get_x_event_ndims(), 2)
        self.assertEqual(inverse_flow.get_y_event_ndims(), 1)
        self.assertTrue(inverse_flow.is_explicitly_invertible())

        # expected result of the inverse direction
        x = T.random.randn([2, 3, 4, 1])
        shifted_and_scaled = (x - 1.) * 0.5
        expected_y = T.reshape(shifted_and_scaled, [2, 3, 4])
        expected_log_det = -T.full([2, 3], math.log(2.) * 4)
        input_log_det = T.random.randn([2, 3])

        flow_standard_check(
            self, inverse_flow, x, expected_y, expected_log_det,
            input_log_det,
        )

        # a layer that is not a flow is rejected
        with pytest.raises(TypeError, match=not_invertible_msg):
            _ = InverseFlow(tk.layers.Linear(5, 3))

        # a flow flagged as not explicitly invertible is rejected as well
        non_invertible = _MyFlow()
        non_invertible.explicitly_invertible = False
        with pytest.raises(TypeError, match=not_invertible_msg):
            _ = InverseFlow(tk.layers.jit_compile(non_invertible))
Ejemplo n.º 2
0
    def test_batch_norm(self):
        """Compare `BatchNorm*` layer outputs with a manual computation.

        For 0 to 3 spatial dimensions the matching layer class is looked up
        on `tk.layers`, constructed, and JIT-compiled.  After one forward
        pass following ``set_train_mode(layer)``, the layer is evaluated
        twice: once after ``set_train_mode(layer, False)`` and once after
        ``set_eval_mode(layer)``.  Both outputs must match the value derived
        from the layer's `running_mean`, `running_var`, `weight` and `bias`.
        Finally, an input built with an empty batch shape must raise.

        Raises:
            RuntimeError: if `T.backend_name` is not ``'PyTorch'``; the
                manual computation is only written for that backend.
        """
        eps = T.EPSILON
        param_names = ('weight', 'bias', 'running_mean', 'running_var')

        for spatial_ndims in (0, 1, 2, 3):
            # `BatchNorm` for no spatial dims, `BatchNorm{N}d` otherwise
            layer_name = (f'BatchNorm{spatial_ndims}d' if spatial_ndims
                          else 'BatchNorm')
            layer = getattr(tk.layers, layer_name)(
                5, momentum=0.1, epsilon=eps)
            self.assertIn('BatchNorm', repr(layer))
            self.assertTrue(tk.layers.is_batch_norm(layer))
            layer = tk.layers.jit_compile(layer)

            # input with batch shape [3] and 5 channels
            spatial_shape = [6, 7, 8][:spatial_ndims]
            x = T.random.randn(make_conv_shape([3], 5, spatial_shape))

            # one pass after enabling train mode, then two evaluations that
            # leave train mode through the two available helpers
            set_train_mode(layer)
            _ = layer(x)
            set_train_mode(layer, False)
            out_train_flag_off = layer(x)
            set_train_mode(layer, True)
            set_eval_mode(layer)
            out_eval_mode = layer(x)

            # manually compute the expected output
            if T.backend_name != 'PyTorch':
                raise RuntimeError()

            # parameters are reshaped to `[-1, 1, ..., 1]` (one trailing 1
            # per spatial dim) before being combined with `x`
            param_shape = [-1] + [1] * spatial_ndims
            weight, bias, running_mean, running_var = [
                T.reshape(getattr(layer, name), param_shape)
                for name in param_names
            ]
            normalized = (x - running_mean) / T.sqrt(running_var + eps)
            expected = normalized * weight + bias

            # both evaluation outputs must agree with the manual result
            for actual in (out_train_flag_off, out_eval_mode):
                assert_allclose(actual, expected, rtol=1e-4, atol=1e-6)

            # an input built with an empty batch shape must be rejected
            with pytest.raises(Exception,
                               match='torch|only supports .d input'):
                _ = layer(
                    T.random.randn(make_conv_shape([], 5, spatial_shape))
                )
Ejemplo n.º 3
0
 def plot_samples(epoch=None):
     """Sample images from `vae` and save them as a single image grid.

     The logits of ``vae.p(n_z=100)['x']`` are passed through a sigmoid,
     scaled to ``[0, 255]``, cast to `uint8` and reshaped to
     ``[-1, 28, 28]``.  The result is written as a 10 x 10 grid to
     ``plotting/{epoch}.png``, resolved through `exp.abspath`.

     Args:
         epoch: Value used in the output file name.  When omitted
             (``None``), `loop.epoch` is used.
     """
     # Test against None explicitly: `epoch or loop.epoch` would silently
     # replace an explicitly requested epoch 0 with `loop.epoch`.
     if epoch is None:
         epoch = loop.epoch

     # sample in eval mode and without tracking gradients
     with tk.layers.scoped_eval_mode(vae), T.no_grad():
         logits = vae.p(n_z=100)['x'].distribution.logits
         images = T.reshape(
             T.cast(T.clip(T.nn.sigmoid(logits) * 255., 0., 255.), dtype=T.uint8),
             [-1, 28, 28],
         )
     utils.save_images_collection(
         images=T.to_numpy(images),
         filename=exp.abspath(f'plotting/{epoch}.png'),
         grid_size=(10, 10),
     )
Ejemplo n.º 4
0
    def test_call(self):
        """Exercise calling a compiled `_MyFlow` and its shape validation.

        The test verifies that:

        * the flow reports 1 event dim for `x`, 2 for `y`, and is
          explicitly invertible;
        * `flow_standard_check` passes with output ``x * 2 + 1`` (reshaped
          to ``[2, 3, 4, 1]``) and a log-determinant of ``4 * log(2)``;
        * inputs with too few dimensions are rejected in both directions;
        * an `input_log_det` of shape ``[2, 4]`` is rejected in both
          directions;
        * `_MyBadFlow` fails with a shape/size error in both directions.
        """
        too_few_dims_msg = '`input` is required to be at least .*d'
        bad_log_det_msg = 'The shape of `input_log_det` is not expected'
        shape_or_size_msg = '(shape|size)'

        my_flow = tk.layers.jit_compile(_MyFlow())
        self.assertEqual(my_flow.get_x_event_ndims(), 1)
        self.assertEqual(my_flow.get_y_event_ndims(), 2)
        self.assertEqual(my_flow.is_explicitly_invertible(), True)

        # test call
        x = T.random.randn([2, 3, 4])
        scaled_and_shifted = x * 2. + 1.
        expected_y = T.reshape(scaled_and_shifted, [2, 3, 4, 1])
        expected_log_det = T.full([2, 3], math.log(2.) * 4)
        input_log_det = T.random.randn([2, 3])

        flow_standard_check(
            self, my_flow, x, expected_y, expected_log_det, input_log_det,
        )

        # test input shape error: a scalar for the forward direction, a
        # 1-d tensor for the inverse direction
        with pytest.raises(Exception, match=too_few_dims_msg):
            _ = my_flow(T.random.randn([]))
        with pytest.raises(Exception, match=too_few_dims_msg):
            _ = my_flow(T.random.randn([3]), inverse=True)

        # test input_log_det shape error ([2, 4] instead of [2, 3])
        with pytest.raises(Exception, match=bad_log_det_msg):
            _ = my_flow(x, T.random.randn([2, 4]))
        with pytest.raises(Exception, match=bad_log_det_msg):
            _ = my_flow(expected_y, T.random.randn([2, 4]), inverse=True)

        # test output_log_det shape error
        bad_flow = tk.layers.jit_compile(_MyBadFlow())
        with pytest.raises(Exception, match=shape_or_size_msg):
            _ = bad_flow(x)
        with pytest.raises(Exception, match=shape_or_size_msg):
            _ = bad_flow(x, inverse=True)
Ejemplo n.º 5
0
def check_invertible_linear(
        ctx,
        spatial_ndims: int,
        invertible_linear_factory,
        linear_factory,
        strict: bool,
):
    """Check an invertible linear flow against an ordinary linear layer.

    The flow produced by `invertible_linear_factory` must give the same
    output as a bias-free layer from `linear_factory` that is initialised
    with the flow's own matrix, and must report the log-determinant derived
    from that matrix.

    Args:
        ctx: Object providing `assertIn`; also forwarded to
            `flow_standard_check`.
        spatial_ndims: Number of spatial dimensions of the test input.
        invertible_linear_factory: Called as
            ``factory(num_features, strict=strict)`` to build the flow.
        linear_factory: Called to build the reference layer; receives
            ``kernel_size=1`` when `spatial_ndims` is positive.
        strict: Forwarded to `invertible_linear_factory`.
    """
    num_features = 4
    batch_shape = [2]
    spatial_shape = [5, 6, 7][:spatial_ndims]
    input_shape = make_conv_shape(batch_shape, num_features, spatial_shape)
    x = T.random.randn(input_shape)

    # construct the layer; the repr is inspected before JIT compilation
    flow = invertible_linear_factory(num_features, strict=strict)
    ctx.assertIn(f'num_features={num_features}', repr(flow))
    flow = tk.layers.jit_compile(flow)

    # derive the expected answer from the flow's own matrix
    weight, log_det = flow.invertible_matrix(
        inverse=False, compute_log_det=True)

    # the matrix gets one trailing axis of size 1 per spatial dimension
    # before it is used as the initial weight of the reference layer
    linear_kwargs = {'kernel_size': 1} if spatial_ndims > 0 else {}
    kernel_shape = T.shape(weight) + [1] * spatial_ndims
    reference = linear_factory(
        num_features,
        num_features,
        weight_init=T.reshape(weight, kernel_shape),
        use_bias=False,
        **linear_kwargs
    )

    x_flatten, front_shape = T.flatten_to_ndims(x, spatial_ndims + 2)
    expected_y = T.unflatten_from_ndims(reference(x_flatten), front_shape)

    # expand the log-det over the spatial shape, sum it up, then expand the
    # total to the batch shape
    total_log_det = T.reduce_sum(T.expand(log_det, spatial_shape))
    expected_log_det = T.expand(total_log_det, batch_shape)

    # check the invertible layer
    input_log_det = T.random.randn(batch_shape)
    flow_standard_check(
        ctx, flow, x, expected_y, expected_log_det, input_log_det)
Ejemplo n.º 6
0
    def do_check(batch_shape, scale_type, initialized, dtype):
        """Check one configuration of the flow class under test.

        Relies on `cls`, `ctx`, `num_features`, `spatial_ndims` and
        `channel_axis` from the enclosing scope.

        Args:
            batch_shape: Batch shape of the random test input.
            scale_type: Passed as `scale` to `cls`; ``'exp'`` selects
                `ExpScale` for the expected result, anything else
                `LinearScale`.
            initialized: Passed to `cls`.  When false, the initialisation
                behaviour is checked; otherwise the first call must return
                the input unchanged.
            dtype: Passed to `cls` and to the random input generation.
        """
        spatial_shape = [6, 7, 8][:spatial_ndims]
        x = T.random.randn(
            make_conv_shape(batch_shape, num_features, spatial_shape),
            dtype=dtype,
        )

        # check construct: the repr is inspected before JIT compilation
        flow = cls(num_features, scale=scale_type, initialized=initialized,
                   dtype=dtype)
        expected_fragments = (
            f'num_features={num_features}',
            f'axis={-(spatial_ndims + 1)}',
            f'scale_type={scale_type!r}',
        )
        for fragment in expected_fragments:
            ctx.assertIn(fragment, repr(flow))
        flow = tk.layers.jit_compile(flow)

        if initialized:
            # the first call must return the input unchanged
            y, _ = flow(x, compute_log_det=False)
            assert_allclose(y, x, rtol=1e-4, atol=1e-6)
        else:
            # check initialize
            # must initialize with sufficient data: an input built with an
            # empty batch shape is rejected
            with pytest.raises(Exception,
                               match='with at least .* dimensions'):
                _ = flow(
                    T.random.randn(
                        make_conv_shape([], num_features, spatial_shape),
                        dtype=dtype,
                    )
                )

            # must initialize with inverse = False
            with pytest.raises(Exception,
                               match='`ActNorm` must be initialized with '
                               '`inverse = False`'):
                _ = flow(x, inverse=True)

            # do initialize: afterwards the output must have zero mean and
            # unit variance over every axis except the channel axis
            y, _ = flow(x, compute_log_det=False)
            reduce_axes = [
                a for a in range(-T.rank(y), 0) if a != channel_axis
            ]
            y_mean, y_var = T.calculate_mean_and_var(y, axis=reduce_axes)
            assert_allclose(
                y_mean, T.zeros([num_features]), rtol=1e-4, atol=1e-6)
            assert_allclose(
                y_var, T.ones([num_features]), rtol=1e-4, atol=1e-6)

        # prepare for the expected result
        if scale_type == 'exp':
            scale_obj = ExpScale()
        else:
            scale_obj = LinearScale()

        # without channel-last layout, the parameters get one trailing axis
        # of size 1 per spatial dimension
        aligned_shape = (
            [num_features] if T.IS_CHANNEL_LAST
            else [num_features] + [1] * spatial_ndims
        )
        bias = T.reshape(flow.bias, aligned_shape)
        pre_scale = T.reshape(flow.pre_scale, aligned_shape)

        expected_y, expected_log_det = scale_obj(
            x + bias,
            pre_scale,
            event_ndims=(spatial_ndims + 1),
            compute_log_det=True,
        )

        input_log_det = T.random.randn(T.shape(expected_log_det))
        flow_standard_check(
            ctx, flow, x, expected_y, expected_log_det, input_log_det)