Example #1
def q(self,
      x: T.Tensor,
      observed: Optional[Mapping[str, TensorOrData]] = None,
      n_z: Optional[int] = None) -> tk.BayesianNet:
    # Build the variational posterior q(z|x) as a BayesianNet.
    net = tk.BayesianNet(observed=observed)
    # Shared hidden features of x from which the q(z|x) parameters are computed.
    hx = self.hx_for_qz(T.cast(x, dtype=T.float32))
    z_mean = self.qz_mean(hx)
    z_logstd = self.qz_logstd(hx)
    # Add the diagonal Gaussian latent variable z, drawing `n_z` samples.
    z = net.add('z',
                tk.Normal(mean=z_mean, logstd=z_logstd, event_ndims=1),
                n_samples=n_z)
    return net
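The `z` variable added above is a diagonal Gaussian posterior parameterized by `mean` and `logstd`. As a rough, tensorkit-independent illustration, here is a minimal NumPy sketch of the reparameterized sampling and per-sample log-density such a variable corresponds to; the shapes and the sample count of 4 are invented for the example:

import numpy as np

# Hypothetical posterior parameters: batch of 8 inputs, 16-dim latent z.
z_mean = np.random.randn(8, 16)
z_logstd = 0.1 * np.random.randn(8, 16)

# Reparameterized sampling, z = mean + std * eps with eps ~ N(0, I);
# the leading axis of 4 plays the role of n_z.
eps = np.random.randn(4, 8, 16)
z = z_mean + np.exp(z_logstd) * eps

# Diagonal-Gaussian log-density with event_ndims=1: sum the per-dimension
# normal log-pdf over the last (event) axis.
log_qz = np.sum(
    -0.5 * ((z - z_mean) / np.exp(z_logstd)) ** 2
    - z_logstd - 0.5 * np.log(2.0 * np.pi),
    axis=-1,
)
print(z.shape, log_qz.shape)  # (4, 8, 16) (4, 8)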
Example #2
def plot_samples(epoch=None):
    epoch = epoch or loop.epoch
    # Sample 100 images from the generative net in evaluation mode,
    # without tracking gradients.
    with tk.layers.scoped_eval_mode(vae), T.no_grad():
        logits = vae.p(n_z=100)['x'].distribution.logits
        # Turn the Bernoulli logits into uint8 pixel values in [0, 255].
        images = T.reshape(
            T.cast(T.clip(T.nn.sigmoid(logits) * 255., 0., 255.), dtype=T.uint8),
            [-1, 28, 28],
        )
    # Save the samples as a 10 x 10 image grid for this epoch.
    utils.save_images_collection(
        images=T.to_numpy(images),
        filename=exp.abspath(f'plotting/{epoch}.png'),
        grid_size=(10, 10),
    )
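Inside `plot_samples`, the sampled Bernoulli logits are turned into grayscale pixels by a sigmoid followed by a rescale to [0, 255]. The same transformation in plain NumPy, with a made-up 100 x 784 logits array standing in for the decoder output:

import numpy as np

# Pretend Bernoulli logits for 100 sampled 28x28 images.
logits = np.random.randn(100, 784)

# sigmoid(logits) gives per-pixel means in (0, 1); scale to [0, 255],
# clip, cast to uint8, and reshape into image form.
probs = 1.0 / (1.0 + np.exp(-logits))
images = np.clip(probs * 255.0, 0.0, 255.0).astype(np.uint8).reshape(-1, 28, 28)
print(images.shape, images.dtype)  # (100, 28, 28) uint8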
Example #3
def q(self,
      x: T.Tensor,
      observed: Optional[Mapping[str, TensorOrData]] = None,
      n_z: Optional[int] = None) -> tk.BayesianNet:
    # Build the variational posterior q(z|x) as a BayesianNet.
    net = tk.BayesianNet(observed=observed)
    hx = self.hx_for_qz(T.cast(x, dtype=T.float32))
    z_mean = self.qz_mean(hx)
    z_logstd = self.qz_logstd(hx)
    # Clip the log-std from below for numerical stability.
    z_logstd = T.maybe_clip(z_logstd, min_val=self.config.qz_logstd_min)
    # Enrich the diagonal Gaussian posterior with a normalizing flow.
    qz = tk.FlowDistribution(
        tk.Normal(mean=z_mean, logstd=z_logstd, event_ndims=1),
        self.posterior_flow,
    )
    z = net.add('z', qz, n_samples=n_z)
    return net
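`tk.FlowDistribution` pushes the base Gaussian through an invertible flow and corrects the log-density with the log-determinant of the Jacobian. The NumPy sketch below illustrates that change-of-variables rule for a toy scalar affine transform; the affine map is only a stand-in for `self.posterior_flow`, not the flow used here:

import numpy as np

# Sample z0 from the base Gaussian N(mean, exp(logstd)^2) and record its log-density.
mean, logstd = 0.5, -0.2
z0 = mean + np.exp(logstd) * np.random.randn()
log_q0 = (-0.5 * ((z0 - mean) / np.exp(logstd)) ** 2
          - logstd - 0.5 * np.log(2.0 * np.pi))

# Toy invertible flow: z = a * z0 + b  (stand-in for a learned flow).
a, b = 2.0, 1.0
z = a * z0 + b

# Change of variables: log q(z) = log q0(z0) - log|dz/dz0|.
log_qz = log_q0 - np.log(abs(a))
print(z, log_qz)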
Example #4
    def test_cross_entropy(self):
        def softmax(x, axis):
            x_max = np.max(x, axis=axis, keepdims=True)
            x_exp = np.exp(x - x_max)
            return x_exp / np.sum(x_exp, axis=axis, keepdims=True)

        def sparse_cross_entropy(logits, labels, reduction, negative):
            logits_max = np.max(logits, axis=-1, keepdims=True)
            logits_max_reduced = np.squeeze(logits_max, axis=-1)
            out = np.sum(logits * labels, axis=-1) - logits_max_reduced
            out -= np.log(np.sum(np.exp(logits - logits_max), axis=-1))

            if reduction == 'sum':
                out = np.sum(out)
            elif reduction == 'mean':
                out = np.mean(out)
            else:
                assert (reduction == 'none')

            if not negative:
                out = -out
            return out

        def cross_entropy(logits, labels, reduction, negative):
            k = logits.shape[-1]
            sparse_labels = np.eye(k, dtype=logits.dtype)[labels]
            return sparse_cross_entropy(logits, sparse_labels, reduction,
                                        negative)

        logits = np.random.randn(2, 3, 4, 5, 6)
        sparse_labels = softmax(np.random.randn(3, 4, 5, 6), axis=-1)
        labels = np.argmax(sparse_labels, axis=-1)

        self.assertEqual(sparse_labels.shape, (3, 4, 5, 6))
        self.assertEqual(labels.shape, (3, 4, 5))
        self.assertEqual(set(labels.flatten().tolist()), {0, 1, 2, 3, 4, 5})

        _f = T.as_tensor

        for reduction in ['none', 'mean', 'sum']:
            for negative in [False, True]:
                # test cross_entropy
                ans = cross_entropy(logits, labels, reduction, negative)
                out = T.nn.cross_entropy_with_logits(_f(logits), _f(labels),
                                                     reduction, negative)
                assert_allclose(ans, out)

                # test cross_entropy with int32 labels
                ans = cross_entropy(logits, labels, reduction, negative)
                out = T.nn.cross_entropy_with_logits(
                    _f(logits), T.cast(_f(labels), dtype=T.int32), reduction,
                    negative)
                assert_allclose(ans, out)

                # test cross_entropy on 2d
                ans = cross_entropy(logits[0, 0, 0], labels[0, 0], reduction,
                                    negative)
                out = T.nn.cross_entropy_with_logits(_f(logits[0, 0, 0]),
                                                     _f(labels[0, 0]),
                                                     reduction, negative)
                assert_allclose(ans, out)

                # test sparse_cross_entropy
                ans = sparse_cross_entropy(logits, sparse_labels, reduction,
                                           negative)
                out = T.nn.sparse_cross_entropy_with_logits(
                    _f(logits), _f(sparse_labels), reduction, negative)
                assert_allclose(ans, out)

                # test sparse_cross_entropy on 2d
                ans = sparse_cross_entropy(logits[0, 0, 0], sparse_labels[0, 0],
                                           reduction, negative)
                out = T.nn.sparse_cross_entropy_with_logits(
                    _f(logits[0, 0, 0]), _f(sparse_labels[0, 0]), reduction,
                    negative)
                assert_allclose(ans, out)

        # invalid `reduction` argument should raise error
        with pytest.raises(Exception):
            _ = T.nn.cross_entropy_with_logits(_f(logits), _f(labels),
                                               'invalid')

        with pytest.raises(Exception):
            _ = T.nn.sparse_cross_entropy_with_logits(_f(logits), _f(labels),
                                                      'invalid')

        # validation for the shape of logits and labels
        with pytest.raises(Exception, match='(cannot broadcast|shape|size)'):
            # logits and labels shape mismatch
            _ = T.nn.cross_entropy_with_logits(_f(logits),
                                               _f(labels[..., :-1]))

        with pytest.raises(Exception, match='must be at least 2d'):
            # logits and labels rank too low
            _ = T.nn.cross_entropy_with_logits(_f(logits[0, 0, 0, 0]),
                                               _f(labels[0, 0, 0]))

        with pytest.raises(Exception, match='(cannot broadcast|shape|size)'):
            # logits and labels shape mismatch
            _ = T.nn.sparse_cross_entropy_with_logits(_f(logits[..., :-1]),
                                                      _f(sparse_labels))

        with pytest.raises(Exception, match='must be at least 2d'):
            # logits and labels rank too low
            _ = T.nn.sparse_cross_entropy_with_logits(
                _f(logits[0, 0, 0, 0]), _f(sparse_labels[0, 0, 0, 0]))
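Both reference implementations above subtract the per-row maximum of the logits before exponentiating; this is the standard log-sum-exp trick that keeps `np.exp` from overflowing without changing the result. A tiny NumPy demonstration with made-up logits:

import numpy as np

logits = np.array([1000.0, 1001.0, 1002.0])   # naive exp() overflows here

# Naive log-sum-exp overflows to inf (with a RuntimeWarning).
naive = np.log(np.sum(np.exp(logits)))

# Stabilized version: factor out the maximum first.
m = np.max(logits)
stable = m + np.log(np.sum(np.exp(logits - m)))

print(naive, stable)   # inf vs. ~1002.41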
Example #5
    def test_construct(self):
        mean = np.random.randn(3, 4)
        logstd = np.random.randn(2, 3, 4)
        std = np.exp(logstd)

        for dtype in float_dtypes:
            mean_t = T.as_tensor(mean, dtype=dtype)
            std_t = T.as_tensor(std, dtype=dtype)
            logstd_t = T.as_tensor(logstd, dtype=dtype)
            mutual_params = {'std': std_t, 'logstd': logstd_t}

            # construct from mean & std/logstd
            for key, val in mutual_params.items():
                other_key = [k for k in mutual_params if k != key][0]
                normal = _MyBaseNormal(mean=mean_t,
                                       event_ndims=1,
                                       **{key: val})
                self.assertEqual(normal.continuous, True)
                self.assertEqual(normal.reparameterized, True)
                self.assertEqual(normal.min_event_ndims, 0)
                self.assertEqual(normal.event_ndims, 1)
                self.assertIs(normal.mean, mean_t)
                self.assertIs(getattr(normal, key), val)
                assert_allclose(getattr(normal, other_key),
                                mutual_params[other_key],
                                rtol=1e-4)
                self.assertEqual(normal._mutual_params, {key: val})

                # mean and std/logstd must have the same dtype
                for other_dtype in float_dtypes:
                    if other_dtype != dtype:
                        other_val = T.cast(val, other_dtype)
                        with pytest.raises(ValueError,
                                           match=f'`{key}.dtype` != `mean.'
                                           f'dtype`: {other_dtype} vs '
                                           f'{dtype}'):
                            _ = _MyBaseNormal(mean=mean_t, **{key: other_val})

            # must specify either std or logstd, but not both
            with pytest.raises(ValueError,
                               match='Either `std` or `logstd` must be '
                               'specified, but not both.'):
                _ = _MyBaseNormal(mean=mean_t, std=std_t, logstd=logstd_t)

            with pytest.raises(ValueError,
                               match='Either `std` or `logstd` must be '
                               'specified, but not both.'):
                _ = _MyBaseNormal(mean=mean_t, std=None, logstd=None)

            # nan test
            with pytest.raises(Exception,
                               match='Infinity or NaN value encountered'):
                _ = _MyBaseNormal(mean=T.as_tensor(np.nan, dtype=dtype),
                                  logstd=logstd_t,
                                  validate_tensors=True)

            for key, val in mutual_params.items():
                with pytest.raises(Exception,
                                   match='Infinity or NaN value encountered'):
                    _ = _MyBaseNormal(
                        mean=mean_t,
                        validate_tensors=True,
                        **{key: T.as_tensor(np.nan, dtype=dtype)})

            normal = _MyBaseNormal(mean=mean_t,
                                   std=T.zeros_like(std_t),
                                   validate_tensors=True)
            with pytest.raises(Exception,
                               match='Infinity or NaN value encountered'):
                _ = normal.logstd
Example #6
    def test_normal(self):
        mean = np.random.randn(2, 3, 4)
        logstd = np.random.randn(3, 4)
        std = np.exp(logstd)

        def log_prob(given):
            # np.log(np.exp(-(given - mean) ** 2 / (2. * std ** 2)) /
            #        (np.sqrt(2 * np.pi) * std))
            return (-(given - mean)**2 * (0.5 * np.exp(-2. * logstd)) -
                    np.log(np.sqrt(2 * np.pi)) - logstd)

        # test n_samples by manually expanding the param shape
        for dtype in float_dtypes:
            # test sample dtype and shape
            mean_t = T.cast(T.expand(T.as_tensor(mean), [n_samples, 2, 3, 4]),
                            dtype)
            std_t = T.cast(T.expand(T.as_tensor(std), [n_samples, 1, 3, 4]),
                           dtype)
            logstd_t = T.cast(
                T.expand(T.as_tensor(logstd), [n_samples, 1, 3, 4]), dtype)
            t = T.random.normal(mean_t, std_t)
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())
            self.assertEqual(T.shape(t), [n_samples, 2, 3, 4])

            # test sample mean
            x = T.to_numpy(t)
            x_mean = np.mean(x, axis=0)
            np.testing.assert_array_less(
                np.abs(x_mean - mean),
                np.tile(np.expand_dims(5 * std / np.sqrt(n_samples), axis=0),
                        [2, 1, 1]))

            # test log_prob
            do_check_log_prob(given=t,
                              batch_ndims=len(x.shape),
                              Z_log_prob_fn=partial(T.random.normal_log_pdf,
                                                    mean=mean_t,
                                                    logstd=logstd_t),
                              np_log_prob=log_prob(x))

        # test with n_samples
        for dtype in float_dtypes:
            # test sample dtype and shape
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            t = T.random.normal(mean_t, std_t, n_samples=n_samples)
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())
            self.assertEqual(T.shape(t), [n_samples, 2, 3, 4])

            # test sample mean
            x = T.to_numpy(t)
            x_mean = np.mean(x, axis=0)
            np.testing.assert_array_less(
                np.abs(x_mean - mean),
                np.tile(np.expand_dims(5 * std / np.sqrt(n_samples), axis=0),
                        [2, 1, 1]))

            # test log_prob
            do_check_log_prob(given=t,
                              batch_ndims=len(x.shape),
                              Z_log_prob_fn=partial(T.random.normal_log_pdf,
                                                    mean=mean_t,
                                                    logstd=logstd_t),
                              np_log_prob=log_prob(x))

        # test without n_samples
        for dtype in float_dtypes:
            mean_t = T.as_tensor(mean, dtype)
            std_t = T.as_tensor(std, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            t = T.random.normal(mean_t, std_t)
            self.assertEqual(T.get_dtype(t), dtype)
            self.assertEqual(T.get_device(t), T.current_device())

            # test log_prob
            x = T.to_numpy(t)
            do_check_log_prob(given=t,
                              batch_ndims=len(x.shape),
                              Z_log_prob_fn=partial(T.random.normal_log_pdf,
                                                    mean=mean_t,
                                                    logstd=logstd_t),
                              np_log_prob=log_prob(x))

        # test reparameterized
        w = np.random.randn(2, 3, 4)

        for dtype in float_dtypes:
            w_t = T.requires_grad(T.as_tensor(w))
            mean_t = T.requires_grad(T.as_tensor(mean, dtype))
            std_t = T.requires_grad(T.as_tensor(std, dtype))
            t = w_t * T.random.normal(mean_t, std_t)
            [mean_grad, std_grad] = T.grad([t], [mean_t, std_t],
                                           [T.ones_like(t)])
            assert_allclose(mean_grad, w, rtol=1e-4)
            assert_allclose(std_grad,
                            np.sum(T.to_numpy((t - w_t * mean_t) / std_t),
                                   axis=0),
                            rtol=1e-4)

        # test not reparameterized
        for dtype in float_dtypes:
            w_t = T.requires_grad(T.as_tensor(w))
            mean_t = T.requires_grad(T.as_tensor(mean, dtype))
            std_t = T.requires_grad(T.as_tensor(std, dtype))
            t = w_t * T.random.normal(mean_t, std_t, reparameterized=False)
            [mean_grad, std_grad] = T.grad([t], [mean_t, std_t],
                                           [T.ones_like(t)],
                                           allow_unused=True)
            self.assertTrue(T.is_null_grad(mean_t, mean_grad))
            self.assertTrue(T.is_null_grad(std_t, std_grad))

        # `given` has lower rank than the params; it is broadcast to match them
        for dtype in float_dtypes:
            mean_t = T.as_tensor(mean, dtype)
            logstd_t = T.as_tensor(logstd, dtype)
            for val in (0., 1., -1.):
                assert_allclose(T.random.normal_log_pdf(
                    T.float_scalar(val), mean_t, logstd_t),
                                log_prob(val),
                                rtol=1e-4)

        # dtype mismatch
        with pytest.raises(Exception, match='`mean.dtype` != `std.dtype`'):
            _ = T.random.normal(T.as_tensor(mean, T.float32),
                                T.as_tensor(std, T.float64))

        # check numerics
        mean_t = T.as_tensor(mean)
        std_t = T.zeros_like(mean_t)
        logstd_t = T.as_tensor(T.log(std_t))
        t = T.random.normal(mean_t, std_t)
        with pytest.raises(Exception,
                           match='Infinity or NaN value encountered'):
            _ = T.random.normal_log_pdf(t,
                                        mean_t,
                                        logstd_t,
                                        validate_tensors=True)
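The `log_prob` helper in this test is the closed-form Gaussian log-density written in terms of `logstd`. As a quick, tensorkit-independent sanity check, the same formula can be compared against `scipy.stats.norm.logpdf` (assuming SciPy is available):

import numpy as np
from scipy.stats import norm

mean = np.random.randn(2, 3, 4)
logstd = np.random.randn(3, 4)
x = np.random.randn(2, 3, 4)

# Same expression as the test's `log_prob` helper.
ours = (-(x - mean) ** 2 * (0.5 * np.exp(-2. * logstd)) -
        np.log(np.sqrt(2 * np.pi)) - logstd)

np.testing.assert_allclose(ours, norm.logpdf(x, loc=mean, scale=np.exp(logstd)),
                           rtol=1e-6)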
Example #7
def calculate_acc(logits: T.Tensor, y: T.Tensor) -> T.Tensor:
    # Classification accuracy: fraction of samples whose arg-max class
    # in `logits` matches the target label `y`.
    with T.no_grad():
        out_y = T.argmax(logits, axis=-1)
        return T.reduce_mean(T.cast(T.equal(out_y, y), dtype=T.float32))
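For reference, the same accuracy computation in plain NumPy; the batch of 5 samples over 3 classes below is invented for the example:

import numpy as np

# Hypothetical logits for 5 samples over 3 classes, plus true labels.
logits = np.array([[2.0, 0.1, -1.0],
                   [0.0, 1.5, 0.3],
                   [1.0, 0.9, 2.2],
                   [0.2, 0.1, 0.0],
                   [3.0, -1.0, 0.5]])
y = np.array([0, 1, 2, 1, 0])

# Accuracy = mean of (arg-max class == true label).
acc = np.mean(np.argmax(logits, axis=-1) == y)
print(acc)  # 0.8 -- the fourth row's arg-max is class 0, not 1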