Example #1
0
 def test_ConcatRecorder(self):
     """Record two random tensors with a ``ConcatRecorder`` and check the
     result (via ``standard_recorder_check``) against their concatenation
     along axis 0."""
     # The two tensors differ only in their first dimension, so they can be
     # concatenated along axis 0.
     head = T.random.randn([2, 4])
     tail = T.random.randn([3, 4])
     recorded = [head, tail]
     recorder = ConcatRecorder()
     expected = T.concat(recorded, axis=0)
     standard_recorder_check(self, recorder, recorded, expected)
Example #2
0
    def test_push_to_stack(self):
        """Check that ``record_tensor`` only feeds recorder managers that
        have been pushed onto the stack, and that the stack is popped again
        when the ``push_to_stack()`` context exits.

        ``rec1`` is configured for ``'name1'`` and ``'name2'``, while ``rec2``
        is configured only for ``'name1'``. The assertions below expect each
        manager's ``get_all()`` to contain only the names it was configured
        with.
        """
        rec1 = RecorderManager({
            'name1': ConcatRecorder,
            'name2': ConcatRecorder,
        })
        rec2 = RecorderManager({'name1': ConcatRecorder})
        tensors = {
            'name1': [T.random.randn([2, 4]), T.random.randn([3, 4])],
            'name2': [T.random.randn([5]), T.random.randn([6])],
        }
        answers = {}
        for key, parts in tensors.items():
            answers[key] = T.concat(parts, axis=0)

        def populate():
            # Reset both managers, then push every tensor through the
            # global ``record_tensor`` entry point.
            for manager in (rec1, rec2):
                manager.clear()
            for name, parts in tensors.items():
                for tensor in parts:
                    record_tensor(name, tensor)

        def check(records, expected_names):
            # ``records`` must hold exactly ``expected_names``, each equal to
            # the concatenation of the tensors recorded under that name.
            self.assertEqual(len(records), len(expected_names))
            self.assertEqual(sorted(expected_names), sorted(records))
            for name in expected_names:
                assert_allclose(records[name], answers[name])

        # Nothing on the stack: recording must not reach either manager.
        self.assertFalse(has_recorder_manager())
        populate()
        self.assertEqual(rec1.get_all(), {})
        self.assertEqual(rec2.get_all(), {})

        # A single manager on the stack; the context yields the manager.
        self.assertFalse(has_recorder_manager())
        with rec1.push_to_stack() as pushed:
            self.assertIs(pushed, rec1)
            self.assertTrue(has_recorder_manager())
            populate()
        self.assertFalse(has_recorder_manager())
        check(rec1.get_all(), ['name1', 'name2'])
        check(rec2.get_all(), [])

        # Two managers on the stack: each keeps the names it tracks.
        self.assertFalse(has_recorder_manager())
        with rec1.push_to_stack(), rec2.push_to_stack():
            self.assertTrue(has_recorder_manager())
            populate()
        self.assertFalse(has_recorder_manager())
        check(rec1.get_all(), ['name1', 'name2'])
        check(rec2.get_all(), ['name1'])
Example #3
0
    def do_check(secondary, scale_type):
        """Build a flow with ``cls`` and compare it against a manual
        computation of the expected output and log-determinant.

        The input is split along the channel axis into ``n1`` and ``n2``
        channels. One part is passed through unchanged and fed to
        ``shift_and_pre_scale``. The other part is shifted by ``shift`` and
        then scaled using ``pre_scale``. When ``secondary`` is true, the roles
        of the two parts are swapped, and the outputs are swapped back before
        concatenation so the channel order is preserved.

        Relies on names from the enclosing scope: ``batch_shape``,
        ``num_features``, ``spatial_ndims``, ``cls``, ``ctx``,
        ``sigmoid_scale_bias``, ``shift_and_pre_scale_1`` and
        ``shift_and_pre_scale_2``.

        Args:
            secondary: Whether the second channel part (rather than the
                first) is the pass-through part. It also selects
                ``shift_and_pre_scale_2`` instead of
                ``shift_and_pre_scale_1``.
            scale_type: One of the names ``'exp'``, ``'sigmoid'``,
                ``'linear'``; one of the classes ``ExpScale``,
                ``SigmoidScale``, ``LinearScale``; or a ``Scale`` instance /
                JIT layer, which is used as-is.

        Raises:
            ValueError: If ``scale_type`` is none of the above.
        """
        # Spatial sizes 6, 7, 8 are truncated to ``spatial_ndims`` axes.
        x = T.random.randn(
            make_conv_shape(batch_shape, num_features, [6, 7,
                                                        8][:spatial_ndims]))
        # Odd channel counts put the extra channel into the second part.
        n1, n2 = (num_features // 2), (num_features - num_features // 2)

        # construct the instance
        shift_and_pre_scale = (shift_and_pre_scale_2
                               if secondary else shift_and_pre_scale_1)
        flow = cls(shift_and_pre_scale,
                   scale=scale_type,
                   secondary=secondary,
                   sigmoid_scale_bias=sigmoid_scale_bias)
        ctx.assertIn(f'secondary={secondary}', repr(flow))
        flow = tk.layers.jit_compile(flow)

        # obtain the expected output
        channel_axis = get_channel_axis(spatial_ndims)
        x1, x2 = T.split(x, [n1, n2], axis=channel_axis)
        if secondary:
            # From here on, x1 is always the pass-through part.
            x1, x2 = x2, x1

        y1 = x1
        shift, pre_scale = shift_and_pre_scale(x1)
        # Mirror how ``cls`` is expected to resolve its ``scale`` argument.
        if scale_type == 'exp' or scale_type is ExpScale:
            scale = ExpScale()
        elif scale_type == 'sigmoid' or scale_type is SigmoidScale:
            scale = SigmoidScale(pre_scale_bias=sigmoid_scale_bias)
        elif scale_type == 'linear' or scale_type is LinearScale:
            scale = LinearScale()
        elif isinstance(scale_type,
                        Scale) or tk.layers.is_jit_layer(scale_type):
            scale = scale_type
        else:
            raise ValueError(f'Invalid value for `scale`: {scale_type}')
        # event_ndims covers the channel axis plus the spatial axes, i.e.
        # every non-batch axis of ``x``.
        y2, log_det = scale(x2 + shift,
                            pre_scale,
                            event_ndims=spatial_ndims + 1,
                            compute_log_det=True)

        if secondary:
            # Restore the original channel order of the two parts.
            y1, y2 = y2, y1
        expected_y = T.concat([y1, y2], axis=channel_axis)
        # Only the transformed part contributes to the log-determinant.
        expected_log_det = log_det

        # now check the flow
        flow_standard_check(ctx, flow, x, expected_y, expected_log_det,
                            T.random.randn(batch_shape))
Example #4
0
    def test_record_and_clear(self):
        """Check ``record``, ``get``, ``iter_all`` and ``get_all`` of a
        ``RecorderManager`` whose recorders come from the ``ConcatRecorder``
        default factory.

        It also checks that ``clear()`` returns the manager to its empty
        state, by running the same checks before and after clearing.
        """
        rm = RecorderManager(default_factory=ConcatRecorder)
        tensors = {
            'name1': [T.random.randn([2, 4]), T.random.randn([3, 4])],
            'name2': [T.random.randn([5]), T.random.randn([6])],
        }
        answers = {name: T.concat(parts, axis=0)
                   for name, parts in tensors.items()}

        def check(records, expected_names):
            # ``iter_all`` yields (name, value) pairs; normalize to a dict.
            if not isinstance(records, dict):
                records = dict((k, v) for k, v in records)
            self.assertEqual(len(records), len(expected_names))
            self.assertEqual(sorted(expected_names), sorted(records))
            for name in expected_names:
                assert_allclose(records[name], answers[name])

        def only_name1(name):
            return name == 'name1'

        def exercise():
            # The manager must start out empty.
            for name in tensors:
                self.assertIsNone(rm.get(name))
            self.assertEqual(list(rm.iter_all()), [])
            self.assertEqual(rm.get_all(), {})

            # Record everything, then verify the concatenated results.
            for name, parts in tensors.items():
                for tensor in parts:
                    rm.record(name, tensor)
            for name in tensors:
                assert_allclose(rm.get(name), answers[name])

            check(rm.iter_all(), ['name1', 'name2'])
            check(rm.get_all(), ['name1', 'name2'])

            # With a name filter, only the matching entries are returned.
            check(rm.iter_all(only_name1), ['name1'])
            check(rm.get_all(only_name1), ['name1'])

        exercise()
        rm.clear()
        exercise()
Example #5
0
            def f(modules, self_module, self_weight, use_bias, normalizer,
                  activation, merge_mode):
                """Build a layer from ``cls`` with the given options and
                compare its output on every tensor in ``inputs`` against a
                reference computed with dense NumPy adjacency products.

                Relies on names from the enclosing scope: ``n_partitions``,
                ``in_adj``, ``add_out_shape``, ``feature_axis``, ``inputs``,
                ``value_ndims`` and ``cls``.

                Args:
                    modules: Per-adjacency sub-modules. With
                        ``n_partitions == 1`` only ``modules[0]`` is passed
                        to ``cls`` (as ``module=``).
                    self_module: Optional module applied directly to the
                        input. Its output is merged with the neighbour
                        outputs.
                    self_weight: Forwarded to ``cls``.
                    use_bias: Whether to create a ``SimpleParamStore`` bias.
                    normalizer: Optional callable applied after the bias.
                    activation: Optional activation, either a class (which is
                        instantiated here) or a callable used as-is.
                    merge_mode: ``'add'`` sums the branch outputs;
                        ``'concat'`` concatenates them along
                        ``feature_axis``.
                """
                if n_partitions == 1:
                    adj = in_adj[0]
                else:
                    adj = in_adj[:len(modules)]

                # With 'concat', the feature axis grows by one slice per
                # branch (each module, plus the self module if present).
                out_shape = list(add_out_shape)
                if merge_mode == 'concat':
                    out_shape[feature_axis] *= len(modules) + int(
                        self_module is not None)

                bias_store = (SimpleParamStore(out_shape,
                                               initializer=tk.init.normal)
                              if use_bias else None)
                layer_kwargs = dict(self_module=self_module,
                                    self_weight=self_weight,
                                    bias_store=bias_store,
                                    normalizer=normalizer,
                                    activation=activation,
                                    merge_mode=merge_mode)

                layer = jit_compile(
                    cls(module=modules[0], **layer_kwargs) if n_partitions ==
                    1 else cls(modules=modules, **layer_kwargs))
                # Build the reference activation the same way for both forms.
                if isinstance(activation, type):
                    activation_layer = activation()
                else:
                    activation_layer = activation

                for x in inputs:
                    # test errors
                    # Passing one adjacency too few must be rejected.
                    if len(modules) > 1:
                        with pytest.raises(
                                Exception,
                                match=r'`adj` is expected to have .* element'
                                r'\(s\), but got .*'):
                            _ = layer(x, in_adj[:len(modules) - 1])

                    # Dropping the leading axis of a minimal-rank input must
                    # be rejected.
                    if T.rank(x) == value_ndims + 1:
                        with pytest.raises(
                                Exception,
                                match='`input` is expected to be at least .*d'
                        ):
                            _ = layer(x[0], adj)

                    # obtain valid output
                    y = layer(x, adj)
                    self.assertEqual(T.shape(y),
                                     T.shape(x)[:-value_ndims] + out_shape)

                    # compute the expected output
                    # Apply ``m`` to ``x`` with all leading axes beyond
                    # ``value_ndims + 1`` flattened, then restore them.
                    def g(m, x):
                        m_out, m_front = T.flatten_to_ndims(x, value_ndims + 1)
                        m_out = m(m_out)
                        m_out = T.unflatten_from_ndims(m_out, m_front)
                        return m_out

                    # Neighbour branch: dense(adj) @ x, with x flattened to
                    # 2-D, then reshaped back to x's shape and passed through
                    # the module. ``zip`` stops after len(modules)
                    # adjacencies, matching ``adj`` above.
                    # NOTE(review): the ``50`` is hard-coded; presumably it is
                    # the node count of ``in_adj`` / leading dim of ``x``.
                    # Confirm, or derive it from the adjacency shape.
                    outs = []
                    for m, a in zip(modules, in_adj):
                        m_out = T.as_tensor(
                            np.reshape(
                                np.dot(T.sparse.to_numpy(a),
                                       T.to_numpy(x).reshape([50, -1])),
                                x.shape))
                        outs.append(g(m, m_out))

                    # NOTE(review): ``self_weight`` is not used in this
                    # reference. Confirm the tested values make it
                    # irrelevant to the self branch.
                    if self_module is not None:
                        outs.append(g(self_module, x))

                    # NOTE(review): any other ``merge_mode`` leaves ``out``
                    # unbound, so the lines below would raise NameError.
                    if merge_mode == 'add':
                        out = T.add_n(outs)
                    elif merge_mode == 'concat':
                        out = T.concat(outs, axis=feature_axis)

                    # Same post-processing order as the layer options:
                    # bias, then normalizer, then activation.
                    if bias_store is not None:
                        out = out + bias_store()
                    if normalizer is not None:
                        out = normalizer(out)
                    if activation is not None:
                        out = activation_layer(out)

                    # assert the output is expected
                    assert_allclose_(y, out)
Example #6
0
def check_split_flow(ctx, spatial_ndims: int, num_features: int, cls,
                     x_sections, left, right, **kwargs):
    """Check a split flow class against a manually composed expectation.

    The random input is split by ``x_sections`` along ``x_axis``. ``left`` is
    applied to the first part and ``right`` (or nothing) to the second part.
    The results are concatenated along ``y_axis`` and compared with the flow
    built by ``cls``. The argument-validation errors of ``cls`` are checked
    afterwards.

    Args:
        ctx: The test case (used for ``assertIn`` and passed to
            ``flow_standard_check``).
        spatial_ndims: Number of spatial dimensions of the test input.
        num_features: Number of channels of the test input.
        cls: The split flow class under test.
        x_sections: Section sizes used to split the input.
        left: Flow applied to the first section.
        right: Flow applied to the second section.
        **kwargs: Forwarded to ``cls``. ``y_sections``, ``x_axis`` and
            ``y_axis`` are also read here. They default to ``x_sections``,
            the channel axis, and ``x_axis`` respectively.
    """
    batch_shape = [2]
    y_sections = kwargs.get("y_sections", x_sections)
    x_axis = kwargs.get("x_axis", get_channel_axis(spatial_ndims))
    y_axis = kwargs.get("y_axis", x_axis)

    spatial_shape = [6, 7, 8][:spatial_ndims]
    x = T.random.randn(
        make_conv_shape(batch_shape, num_features, spatial_shape))
    input_log_det = T.random.randn(batch_shape)

    # Without `right`: the second part passes through untouched. This case
    # is only checked when input and output share the same axis.
    if y_axis == x_axis:
        flow = cls(x_sections, left, None, **kwargs)
        expected_fragments = (
            f'x_sections={x_sections}',
            f'y_sections={y_sections}',
            f'x_axis={x_axis}',
            f'y_axis={y_axis}',
        )
        for fragment in expected_fragments:
            ctx.assertIn(fragment, repr(flow))
        flow = tk.layers.jit_compile(flow)

        part1, part2 = T.split(x, x_sections, axis=x_axis)
        out1, expected_log_det = left(part1, compute_log_det=True)
        expected_y = T.concat([out1, part2], axis=y_axis)
        flow_standard_check(ctx, flow, x, expected_y, expected_log_det,
                            input_log_det)

    # With `right`: the log-determinant accumulates through both flows.
    flow = tk.layers.jit_compile(cls(x_sections, left, right, **kwargs))
    part1, part2 = T.split(x, x_sections, axis=x_axis)
    out1, expected_log_det = left(part1, compute_log_det=True)
    out2, expected_log_det = right(part2,
                                   input_log_det=expected_log_det,
                                   compute_log_det=True)
    expected_y = T.concat([out1, out2], axis=y_axis)
    flow_standard_check(ctx, flow, x, expected_y, expected_log_det,
                        input_log_det)

    # Invalid section arguments must raise ValueError.
    x_sections_error = ('`x_sections` must be a sequence of '
                        'two positive integers')
    y_sections_error = ('`y_sections` must be None or a sequence of '
                        'two positive integers')
    invalid_section_cases = [
        (x_sections_error, ([1, 2, 3], left), {}),
        (x_sections_error, ([-1, 2], left), {}),
        (y_sections_error, ([2, 3], left, right), {'y_sections': [1, 2, 3]}),
        (y_sections_error, ([2, 3], left, right), {'y_sections': [-1, 2]}),
    ]
    for message, args, extra_kwargs in invalid_section_cases:
        with pytest.raises(ValueError, match=message):
            _ = cls(*args, **extra_kwargs)

    # Non-flow sub-layers must raise TypeError.
    with pytest.raises(TypeError, match='`left` is not a flow'):
        _ = cls([2, 3], tk.layers.Linear(2, 3))

    with pytest.raises(TypeError, match='`right` is not a flow'):
        _ = cls([2, 3], left, tk.layers.Linear(2, 3))
Example #7
0
    def test_embedding(self):
        """Check ``Embedding`` / ``Embedding{1,2,3}d``: weight shape and a
        loose bound on its per-channel mean, forward output (against
        ``T.embedding``), gradients, the ``freeze=True`` option, and
        constructor errors.

        ``n_samples`` is a module-level constant defined outside this method.
        """
        n_channels = 3
        n_embeddings = n_samples

        for spatial_ndims in (0, 1, 2, 3):
            # Per-embedding weight shape: n_channels plus up to 3 spatial
            # axes of sizes 4, 5, 6.
            w_shape = make_conv_shape([], n_channels, [4, 5,
                                                       6][:spatial_ndims])
            w_size = int(np.prod(w_shape))

            layer = getattr(tk.layers,
                            (f'Embedding{spatial_ndims}d' if spatial_ndims > 0
                             else 'Embedding'))(n_embeddings, w_shape)
            weight = layer.weight

            # check the weight
            self.assertEqual(T.shape(weight), [n_embeddings] + w_shape)
            # Average over every axis except the channel axis: the last axis
            # when channel-last, otherwise axis 1 (axis 0 indexes the
            # embeddings).
            reduce_axis = list(range(len(w_shape) + 1))
            reduce_axis.pop(-1 if T.IS_CHANNEL_LAST else 1)
            w_mean = np.average(T.to_numpy(weight), axis=tuple(reduce_axis))
            # The denominator is the number of elements averaged per channel.
            # NOTE(review): this only upper-bounds the signed mean (negative
            # means always pass). 3/sqrt(N) looks like a 3-sigma bound for a
            # unit-variance initializer; confirm before tightening it to
            # ``np.abs(w_mean)``.
            np.testing.assert_array_less(
                w_mean, 3. / np.sqrt(n_samples * w_size / n_channels))

            # check the output
            layer = jit_compile(layer)
            weight_array = T.to_numpy(weight)

            for in_shape in ([7], [7, 8]):
                indices = T.random.randint(0, n_samples, in_shape)
                # Repeat the first 3 rows so some indices occur more than
                # once; their gradient contributions must then accumulate.
                indices = T.concat([indices, indices[:3]], axis=0)

                # check the output
                output = layer(indices)
                assert_allclose(output, T.embedding(weight, indices))

                # check the grad
                # d/dw sum(out**2) adds 2 * w[idx] once per occurrence of
                # idx. Only checked for 0/1 spatial dims here (the reason is
                # not visible in this test).
                if spatial_ndims in (0, 1):
                    out_sum = T.reduce_sum(output**2)
                    [grad] = T.grad([out_sum], [weight])
                    expected_grad = np.zeros(T.shape(weight))
                    for idx in T.to_numpy(indices).reshape([-1]):
                        expected_grad[idx] += 2. * weight_array[idx]
                    assert_allclose(grad, expected_grad)

            # test the constructor error
            # Dropping one dimension of ``w_shape`` must be rejected.
            if spatial_ndims > 0:
                with pytest.raises(
                        ValueError,
                        match=f'`embedding_size` must be a int list '
                        f'with {spatial_ndims + 1} elements'):
                    _ = getattr(tk.layers,
                                f'Embedding{spatial_ndims}d')(n_embeddings,
                                                              w_shape[:-1])

        # test no grad
        layer = Embedding(n_embeddings, n_channels, freeze=True)
        weight = layer.weight
        self.assertEqual(T.shape(weight), [n_embeddings, n_channels])

        layer = jit_compile(layer)
        indices = T.random.randint(0, n_samples, [7, 8])
        output = layer(indices)
        assert_allclose(output, T.embedding(weight, indices))

        out_sum = T.reduce_sum(output**2)
        # Deliberately best-effort: computing a gradient with respect to a
        # frozen weight is allowed either to raise, or to return a
        # gradient that ``T.is_null_grad`` accepts.
        try:
            [grad] = T.grad([out_sum], [weight])
        except Exception:
            pass
        else:
            self.assertTrue(T.is_null_grad(weight, grad))

        # test errors
        with pytest.raises(ValueError,
                           match='`embedding_size` must not be empty'):
            _ = Embedding(n_embeddings, [])