Example #1
 def check_x(layer):
     y = layer(x)
     y_mean, y_var = T.calculate_mean_and_var(
         y, axis=[-T.rank(x)] + get_spatial_axis(spatial_ndims))
     if use_bias:
         assert_allclose(y_mean,
                         T.zeros_like(y_mean),
                         atol=1e-6,
                         rtol=1e-4)
     assert_allclose(y_var,
                     T.ones_like(y_var),
                     atol=1e-6,
                     rtol=1e-4)
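
The helper asserts that the layer output is normalized to zero mean (when a bias is present) and unit variance, with statistics reduced over the batch and spatial axes so that one value remains per channel. A rough NumPy equivalent of the reduction, assuming a rank-4 channels-first tensor and that `get_spatial_axis` returns the trailing spatial axes (both are assumptions; the backend-agnostic code above does not fix the layout):

import numpy as np

y = np.random.randn(8, 16, 6, 6)                 # e.g. an NCHW layer output
axes = (-4, -2, -1)                              # batch + spatial; channel axis kept
y_mean, y_var = y.mean(axis=axes), y.var(axis=axes)
assert y_mean.shape == y_var.shape == (16,)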
Example #2
 def _transform(self,
                input: Tensor,
                input_log_det: Optional[Tensor],
                inverse: bool,
                compute_log_det: bool) -> Tuple[Tensor, Optional[Tensor]]:
     # identity transform: pass the input through unchanged
     output = input
     output_log_det = input_log_det
     # the log-det of an identity map is zero; materialize it only
     # when requested and not already provided by the caller
     if compute_log_det and output_log_det is None:
         output_log_det = zeros_like(output)
     return output, output_log_det
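
This `_transform` is effectively an identity flow: the tensor passes through unchanged, and the log-determinant term is created only on demand. A minimal standalone sketch of the same contract, using plain NumPy stand-ins (`zeros_like` above is assumed to be the backend's equivalent of `np.zeros_like`; the `inverse` flag is irrelevant for an identity map and is dropped here):

from typing import Optional, Tuple

import numpy as np


def identity_transform(x: np.ndarray,
                       log_det: Optional[np.ndarray],
                       compute_log_det: bool
                       ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    # |det(J)| == 1 for the identity map, hence a zero log-det contribution
    if compute_log_det and log_det is None:
        log_det = np.zeros_like(x)
    return x, log_det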
Example #3
 def check_x(layer):
     y = layer(x)
     y_mean, y_var = T.calculate_mean_and_var(y,
                                              axis=T.int_range(
                                                  -T.rank(x),
                                                  -1))
     if use_bias:
         assert_allclose(y_mean,
                         T.zeros_like(y_mean),
                         atol=1e-6,
                         rtol=1e-4)
     assert_allclose(y_var,
                     T.ones_like(y_var),
                     atol=1e-6,
                     rtol=1e-4)
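
In this variant the statistics are reduced over every axis except the last. Assuming `T.int_range(start, end)` is half-open like Python's `range` (an assumption; only the call site is shown), the NumPy equivalent of the axis selection, reusing `y` from the sketch under Example #1, is:

axes = tuple(range(-y.ndim, -1))   # every axis except the trailing feature axis
y_mean, y_var = y.mean(axis=axes), y.var(axis=axes)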
Example #4
    def test_construct(self):
        mean = np.random.randn(3, 4)
        logstd = np.random.randn(2, 3, 4)
        std = np.exp(logstd)

        for dtype in float_dtypes:
            mean_t = T.as_tensor(mean, dtype=dtype)
            std_t = T.as_tensor(std, dtype=dtype)
            logstd_t = T.as_tensor(logstd, dtype=dtype)
            mutual_params = {'std': std_t, 'logstd': logstd_t}

            # construct from mean & std/logstd
            for key, val in mutual_params.items():
                other_key = [k for k in mutual_params if k != key][0]
                normal = _MyBaseNormal(mean=mean_t,
                                       event_ndims=1,
                                       **{key: val})
                self.assertEqual(normal.continuous, True)
                self.assertEqual(normal.reparameterized, True)
                self.assertEqual(normal.min_event_ndims, 0)
                self.assertEqual(normal.event_ndims, 1)
                self.assertIs(normal.mean, mean_t)
                self.assertIs(getattr(normal, key), val)
                assert_allclose(getattr(normal, other_key),
                                mutual_params[other_key],
                                rtol=1e-4)
                self.assertEqual(normal._mutual_params, {key: val})

                # mean and std/logstd must have the same dtype
                for other_dtype in float_dtypes:
                    if other_dtype != dtype:
                        other_val = T.cast(val, other_dtype)
                        with pytest.raises(ValueError,
                                           match=f'`{key}.dtype` != `mean.'
                                           f'dtype`: {other_dtype} vs '
                                           f'{dtype}'):
                            _ = _MyBaseNormal(mean=mean_t, **{key: other_val})

            # must specify either std or logstd, but not both
            with pytest.raises(ValueError,
                               match='Either `std` or `logstd` must be '
                               'specified, but not both.'):
                _ = _MyBaseNormal(mean=mean_t, std=std_t, logstd=logstd_t)

            with pytest.raises(ValueError,
                               match='Either `std` or `logstd` must be '
                               'specified, but not both.'):
                _ = _MyBaseNormal(mean=mean_t, std=None, logstd=None)

            # NaN inputs must be rejected when validate_tensors=True
            with pytest.raises(Exception,
                               match='Infinity or NaN value encountered'):
                _ = _MyBaseNormal(mean=T.as_tensor(np.nan, dtype=dtype),
                                  logstd=logstd_t,
                                  validate_tensors=True)

            for key, val in mutual_params.items():
                with pytest.raises(Exception,
                                   match='Infinity or NaN value encountered'):
                    _ = _MyBaseNormal(
                        mean=mean_t,
                        validate_tensors=True,
                        **{key: T.as_tensor(np.nan, dtype=dtype)})

            normal = _MyBaseNormal(mean=mean_t,
                                   std=T.zeros_like(std_t),
                                   validate_tensors=True)
            with pytest.raises(Exception,
                               match='Infinity or NaN value encountered'):
                _ = normal.logstd
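
The mutual-parameter behaviour under test is the usual bijection between `std` and `logstd`: whichever one is supplied, the other is derived, and supplying both or neither is rejected. A hedged standalone sketch of that logic, echoing the error messages the test matches against:

import numpy as np


def derive_std_logstd(std=None, logstd=None):
    if (std is None) == (logstd is None):
        raise ValueError('Either `std` or `logstd` must be '
                         'specified, but not both.')
    if std is None:
        std = np.exp(logstd)
    else:
        # log(0) -> -inf: this is the case `validate_tensors=True` rejects
        logstd = np.log(std)
    return std, logstd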
Example #5
        def do_test(low, high, dtype):
            # test with n_samples specified
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            t = T.random.truncated_normal(mean_t,
                                          std_t,
                                          n_samples=n_samples,
                                          low=low,
                                          high=high)
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())
            self.assertEqual(T.shape(t), [n_samples, 2, 3, 4])

            # test sample value range
            x = T.to_numpy(t)
            if low is not None:
                np.testing.assert_array_less(
                    (low * std + mean - 1e-7) * np.ones_like(x), x)
            if high is not None:
                np.testing.assert_array_less(
                    x,
                    np.ones_like(x) * high * std + mean + 1e-7)

            # test log_prob
            do_check_log_prob(given=t,
                              batch_ndims=len(x.shape),
                              Z_log_prob_fn=partial(
                                  T.random.truncated_normal_log_pdf,
                                  mean=mean_t,
                                  std=std_t,
                                  logstd=logstd_t,
                                  low=low,
                                  high=high,
                                  log_zero=log_zero,
                              ),
                              np_log_prob=log_prob(x, low, high))
            do_check_log_prob(
                given=t * 10.,  # most samples now fall outside [low, high]
                batch_ndims=len(x.shape),
                Z_log_prob_fn=partial(
                    T.random.truncated_normal_log_pdf,
                    mean=mean_t,
                    std=std_t,
                    logstd=logstd_t,
                    low=low,
                    high=high,
                    log_zero=log_zero,
                ),
                np_log_prob=log_prob(x * 10., low, high))

            # test with n_samples=None
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            t = T.random.truncated_normal(mean_t, std_t, low=low, high=high)
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())

            # test sample value range
            x = T.to_numpy(t)
            if low is not None:
                np.testing.assert_array_less(
                    (low * std + mean - 1e-7) * np.ones_like(x), x)
            if high is not None:
                np.testing.assert_array_less(
                    x,
                    np.ones_like(x) * high * std + mean + 1e-7)

            # test log_prob
            do_check_log_prob(given=t,
                              batch_ndims=len(x.shape),
                              Z_log_prob_fn=partial(
                                  T.random.truncated_normal_log_pdf,
                                  mean=mean_t,
                                  std=std_t,
                                  logstd=logstd_t,
                                  low=low,
                                  high=high,
                                  log_zero=log_zero,
                              ),
                              np_log_prob=log_prob(x, low, high))
            do_check_log_prob(
                given=t * 10.,  # most samples now fall outside [low, high]
                batch_ndims=len(x.shape),
                Z_log_prob_fn=partial(
                    T.random.truncated_normal_log_pdf,
                    mean=mean_t,
                    std=std_t,
                    logstd=logstd_t,
                    low=low,
                    high=high,
                    log_zero=log_zero,
                ),
                np_log_prob=log_prob(x * 10., low, high))

            # test reparameterized
            w = np.random.randn(2, 3, 4)

            w_t = T.requires_grad(T.as_tensor(w))
            mean_t = T.requires_grad(T.as_tensor(mean, dtype))
            std_t = T.requires_grad(T.as_tensor(std, dtype))
            t = w_t * T.random.truncated_normal(mean_t, std_t)
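            # with t = w * (mean + std * eps): dt/dmean = w and
            # dt/dstd = w * eps = (t - w * mean) / std, summed over the
            # broadcast axis below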
            [mean_grad, std_grad] = T.grad([t], [mean_t, std_t],
                                           [T.ones_like(t)])
            assert_allclose(mean_grad, w, rtol=1e-4)
            assert_allclose(std_grad,
                            np.sum(T.to_numpy((t - w_t * mean_t) / std_t),
                                   axis=0),
                            rtol=1e-4)

            # test not reparameterized
            w_t = T.requires_grad(T.as_tensor(w))
            mean_t = T.requires_grad(T.as_tensor(mean, dtype))
            std_t = T.requires_grad(T.as_tensor(std, dtype))
            t = w_t * T.random.truncated_normal(
                mean_t, std_t, reparameterized=False)
            [mean_grad, std_grad] = T.grad([t], [mean_t, std_t],
                                           [T.ones_like(t)],
                                           allow_unused=True)
            self.assertTrue(T.is_null_grad(mean_t, mean_grad))
            self.assertTrue(T.is_null_grad(std_t, std_grad))

            # given has lower rank than the params; broadcast to match them
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            assert_allclose(T.random.truncated_normal_log_pdf(
                T.float_scalar(0.),
                mean_t,
                std_t,
                logstd_t,
                low=low,
                high=high,
                log_zero=log_zero),
                            log_prob(0., low=low, high=high),
                            rtol=1e-4)

            # dtype mismatch
            with pytest.raises(Exception, match='`mean.dtype` != `std.dtype`'):
                _ = T.random.truncated_normal(T.as_tensor(mean, T.float32),
                                              T.as_tensor(std, T.float64),
                                              low=low,
                                              high=high)

            # check numerics
            mean_t = T.as_tensor(mean)
            std_t = T.zeros_like(mean_t)
            logstd_t = T.log(std_t)  # -inf everywhere, since std is zero
            t = T.random.normal(mean_t, std_t)
            with pytest.raises(Exception,
                               match='Infinity or NaN value encountered'):
                _ = T.random.truncated_normal_log_pdf(t,
                                                      mean_t,
                                                      std_t,
                                                      logstd_t,
                                                      validate_tensors=True)
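
For reference, the density being validated is the standard truncated normal: the normal density renormalized by the probability mass inside the truncation interval, with `log_zero` returned outside it. A NumPy/SciPy sketch under this test's convention that `low` and `high` are expressed in units of `std` from `mean` (the sample-range checks above use `low * std + mean`):

import numpy as np
from scipy.stats import norm


def np_truncated_normal_log_pdf(x, mean, std, low, high, log_zero=-1e7):
    z = (x - mean) / std
    log_pdf = norm.logpdf(z) - np.log(std)       # normal log-density of x
    # renormalize by the mass captured inside [low, high] (in std units)
    mass = ((norm.cdf(high) if high is not None else 1.0) -
            (norm.cdf(low) if low is not None else 0.0))
    log_pdf = log_pdf - np.log(mass)
    inside = np.ones_like(z, dtype=bool)
    if low is not None:
        inside &= (z >= low)
    if high is not None:
        inside &= (z <= high)
    return np.where(inside, log_pdf, log_zero)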
Example #6
    def test_normal(self):
        mean = np.random.randn(2, 3, 4)
        logstd = np.random.randn(3, 4)
        std = np.exp(logstd)

        def log_prob(given):
            # np.log(np.exp(-(given - mean) ** 2 / (2. * std ** 2)) /
            #        (np.sqrt(2 * np.pi) * std))
            return (-(given - mean)**2 * (0.5 * np.exp(-2. * logstd)) -
                    np.log(np.sqrt(2 * np.pi)) - logstd)

        # test n_samples by manually expanding the param shape
        for dtype in float_dtypes:
            # test sample dtype and shape
            mean_t = T.cast(T.expand(T.as_tensor(mean), [n_samples, 2, 3, 4]),
                            dtype)
            std_t = T.cast(T.expand(T.as_tensor(std), [n_samples, 1, 3, 4]),
                           dtype)
            logstd_t = T.cast(
                T.expand(T.as_tensor(logstd), [n_samples, 1, 3, 4]), dtype)
            t = T.random.normal(mean_t, std_t)
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())
            self.assertEqual(T.shape(t), [n_samples, 2, 3, 4])

            # test sample mean
            x = T.to_numpy(t)
            x_mean = np.mean(x, axis=0)
            np.testing.assert_array_less(
                np.abs(x_mean - mean),
                np.tile(np.expand_dims(5 * std / np.sqrt(n_samples), axis=0),
                        [2, 1, 1]))

            # test log_prob
            do_check_log_prob(given=t,
                              batch_ndims=len(x.shape),
                              Z_log_prob_fn=partial(T.random.normal_log_pdf,
                                                    mean=mean_t,
                                                    logstd=logstd_t),
                              np_log_prob=log_prob(x))

        # test with n_samples
        for dtype in float_dtypes:
            # test sample dtype and shape
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            t = T.random.normal(mean_t, std_t, n_samples=n_samples)
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())
            self.assertEqual(T.shape(t), [n_samples, 2, 3, 4])

            # test sample mean
            x = T.to_numpy(t)
            x_mean = np.mean(x, axis=0)
            np.testing.assert_array_less(
                np.abs(x_mean - mean),
                np.tile(np.expand_dims(5 * std / np.sqrt(n_samples), axis=0),
                        [2, 1, 1]))

            # test log_prob
            do_check_log_prob(given=t,
                              batch_ndims=len(x.shape),
                              Z_log_prob_fn=partial(T.random.normal_log_pdf,
                                                    mean=mean_t,
                                                    logstd=logstd_t),
                              np_log_prob=log_prob(x))

        # test without n_samples
        for dtype in float_dtypes:
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            t = T.random.normal(mean_t, std_t)
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())

            # test log_prob
            x = T.to_numpy(t)
            do_check_log_prob(given=t,
                              batch_ndims=len(x.shape),
                              Z_log_prob_fn=partial(T.random.normal_log_pdf,
                                                    mean=mean_t,
                                                    logstd=logstd_t),
                              np_log_prob=log_prob(x))

        # test reparameterized
        w = np.random.randn(2, 3, 4)

        for dtype in float_dtypes:
            w_t = T.requires_grad(T.as_tensor(w))
            mean_t = T.requires_grad(T.as_tensor(mean, dtype))
            std_t = T.requires_grad(T.as_tensor(std, dtype))
            t = w_t * T.random.normal(mean_t, std_t)
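            # with t = w * (mean + std * eps): dt/dmean = w and
            # dt/dstd = w * eps = (t - w * mean) / std; summing over the
            # broadcast axis 0 matches std's (3, 4) shape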
            [mean_grad, std_grad] = T.grad([t], [mean_t, std_t],
                                           [T.ones_like(t)])
            assert_allclose(mean_grad, w, rtol=1e-4)
            assert_allclose(std_grad,
                            np.sum(T.to_numpy((t - w_t * mean_t) / std_t),
                                   axis=0),
                            rtol=1e-4)

        # test not reparameterized
        for dtype in float_dtypes:
            w_t = T.requires_grad(T.as_tensor(w))
            mean_t = T.requires_grad(T.as_tensor(mean, dtype))
            std_t = T.requires_grad(T.as_tensor(std, dtype))
            t = w_t * T.random.normal(mean_t, std_t, reparameterized=False)
            [mean_grad, std_grad] = T.grad([t], [mean_t, std_t],
                                           [T.ones_like(t)],
                                           allow_unused=True)
            self.assertTrue(T.is_null_grad(mean_t, mean_grad))
            self.assertTrue(T.is_null_grad(std_t, std_grad))

        # given has lower rank than the params; broadcast to match them
        for dtype in float_dtypes:
            mean_t = T.as_tensor(mean, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            for val in (0., 1., -1.):
                assert_allclose(T.random.normal_log_pdf(
                    T.float_scalar(val), mean_t, logstd_t),
                                log_prob(val),
                                rtol=1e-4)

        # dtype mismatch
        with pytest.raises(Exception, match='`mean.dtype` != `std.dtype`'):
            _ = T.random.normal(T.as_tensor(mean, T.float32),
                                T.as_tensor(std, T.float64))

        # check numerics
        mean_t = T.as_tensor(mean)
        std_t = T.zeros_like(mean_t)
        logstd_t = T.log(std_t)  # -inf everywhere, since std is zero
        t = T.random.normal(mean_t, std_t)
        with pytest.raises(Exception,
                           match='Infinity or NaN value encountered'):
            _ = T.random.normal_log_pdf(t,
                                        mean_t,
                                        logstd_t,
                                        validate_tensors=True)
Example #7
    def test_data_dependent_initializer(self):
        data_init = _MyDataDependentInitializer([])

        # construct with the initializer
        data_init.watcher.clear()
        layer = tk.layers.Sequential(
            tk.layers.Linear(5, 3, data_init=data_init))
        self.assertEqual(data_init.watcher, [])
        x = T.random.randn([2, 5])
        y = T.random.randn([2, 5])
        _ = layer(x)
        _ = layer(y)
        self.assertListEqual(data_init.watcher, [(layer[0], [x])])

        # set_initialized(False) to re-enable the initializer
        data_init.watcher.clear()
        tk.init.set_initialized(layer, False)
        x = T.random.randn([2, 5])
        y = T.random.randn([2, 5])
        _ = layer(x)
        _ = layer(y)
        self.assertListEqual(data_init.watcher, [(layer[0], [x])])

        # set_initialized(True) to disable the data-init of a newly constructed layer
        data_init.watcher.clear()
        layer = tk.layers.Sequential(
            tk.layers.Linear(5, 3, data_init=data_init))
        tk.init.set_initialized(layer, True)
        x = T.random.randn([2, 5])
        _ = layer(x)
        self.assertListEqual(data_init.watcher, [])

        # remove the data-dependent initializers
        data_init.watcher.clear()
        layer = tk.layers.Sequential(
            tk.layers.Linear(5, 3, data_init=data_init))
        tk.init.remove_data_dependent_initializers(layer)
        tk.init.set_initialized(layer, False)
        x = T.random.randn([2, 5])
        _ = layer(x)
        self.assertListEqual(data_init.watcher, [])

        # `set_initialized` also affects layers that define their own
        # `set_initialized()` method, e.g., `ActNorm`
        x = T.random.randn([2, 3, 5])
        layer = tk.layers.jit_compile(tk.layers.ActNorm(5))
        self.assertFalse(layer.flow.initialized)

        tk.init.set_initialized(layer)
        self.assertTrue(layer.flow.initialized)
        assert_allclose(layer(x), x, rtol=1e-4, atol=1e-6)

        tk.init.set_initialized(layer, False)
        self.assertFalse(layer.flow.initialized)
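        # the first call after de-initialization triggers the data-dependent
        # initialization, which normalizes the output and re-sets `initialized`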
        y = layer(x)
        y_mean, y_var = T.calculate_mean_and_var(y, axis=[0, 1])
        assert_allclose(y_mean, T.zeros_like(y_mean), rtol=1e-4, atol=1e-6)
        assert_allclose(y_var, T.ones_like(y_var), rtol=1e-4, atol=1e-6)

        self.assertTrue(layer.flow.initialized)
        assert_allclose(layer(x), y, rtol=1e-4, atol=1e-6)
Example #8
def stepwise_average_check(ctx, factory, update_fn, get_fn):
    def clone_state(val):
        if isinstance(val, dict):
            return {k: clone_state(v) for k, v in val.items()}
        elif isinstance(val, list):
            return [clone_state(v) for v in val]
        elif isinstance(val, (T.Tensor, T.Variable)):
            return T.copy(val)
        elif isinstance(val, np.ndarray):
            return np.copy(val)
        else:
            return copy.copy(val)

    T.random.seed(1234)
    weights = [
        T.variable(shape=[4], initializer=tk.init.zeros, requires_grad=False),
        T.variable(shape=[3], initializer=tk.init.zeros, requires_grad=False),
    ]
    answers = [clone_state(w) for w in weights]
    inputs_1 = T.random.randn([7, 4])
    inputs_2 = T.random.randn([7, 3])

    # do a scan
    avg = factory(weights)
    the_states = []
    the_outputs = []
    num_updates = 0

    for batch_vals in zip(inputs_1, inputs_2):
        for weight, val in zip(weights, batch_vals):
            T.assign(weight, val)

        the_states.append(clone_state(avg.get_state_dict()))
        avg.update()

        with avg.temporarily_commit():
            the_outputs.extend(clone_state(w) for w in weights)
            for i, val in enumerate(batch_vals):
                answers[i] = update_fn(answers[i], val, num_updates)
            num_updates += 1
            for weight, ans in zip(weights, answers):
                assert_allclose(weight, get_fn(ans, num_updates), rtol=1e-4, atol=1e-6)

        for weight, val in zip(weights, batch_vals):
            assert_allclose(weight, val, rtol=1e-4, atol=1e-6)

    # test enabled = False
    avg = factory(weights, enabled=False)
    for x1, x2, state, output in zip(inputs_1, inputs_2, the_states, the_outputs):
        batch_vals = [x1, x2]
        for weight, val in zip(weights, batch_vals):
            T.assign(weight, val)
        avg.update()

    avg.commit()  # commit() should still overwrite the weights even when enabled is False
    for avg_val in avg.get_state_dict()['averages']:
        assert_allclose(avg_val, T.zeros_like(avg_val), rtol=1e-4, atol=1e-6)
    for weight in weights:
        assert_allclose(weight, T.zeros_like(weight), rtol=1e-4, atol=1e-6)

    # do another scan using backup states
    avg = factory(weights, enabled=False)
    avg.set_enabled(True)
    for x1, x2, state, output in zip(inputs_1, inputs_2, the_states, the_outputs):
        batch_vals = [x1, x2]
        for weight, val in zip(weights, batch_vals):
            T.assign(weight, val)

        avg.set_state_dict(state)
        avg.update()

        with avg.temporarily_commit():
            the_outputs.extend(clone_state(w) for w in weights)
        for weight, val in zip(weights, batch_vals):
            assert_allclose(weight, val, rtol=1e-4, atol=1e-6)

    # try setting a bad state
    avg = factory(weights)
    state = dict(avg.get_state_dict())
    state['averages'] = []
    with pytest.raises(ValueError, match='Bad state'):
        avg.set_state_dict(state)
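
A hedged sketch of the `update_fn` / `get_fn` pair this harness expects, using a plain cumulative mean as the illustrative averaging rule (the actual averagers under test, e.g. exponential moving averages, supply their own):

def mean_update_fn(acc, val, num_updates):
    # `num_updates` counts the updates applied before this one
    return acc + val


def mean_get_fn(acc, num_updates):
    # averaged value after `num_updates` updates have been applied
    return acc / float(max(num_updates, 1))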