Example #1
0
    def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
        """
        Jointly sample every real-valued input of this funsor.

        :param sampled_vars: Names of the variables to sample. Names that are
            not inputs of this funsor are ignored.
        :param sample_inputs: Mapping from name to domain of extra sample
            dimensions. Entries whose name is already an input are dropped.
        :param rng_key: Optional random key. When given, it is passed as the
            first positional argument to the backend ``Normal`` sampler; it is
            not used on the ``"numpy"`` backend.
        :return: ``self`` when nothing needs sampling, otherwise the
            ``ops.add``-sum of one ``Delta`` per real input plus
            ``self.log_normalizer``.
        :raises ValueError: If any requested variable is not real-valued.
        :raises NotImplementedError: If only a strict subset of the real
            inputs is requested.
        """
        sampled_vars = sampled_vars.intersection(self.inputs)
        if not sampled_vars:
            return self
        for name in sampled_vars:
            if self.inputs[name].dtype != 'real':
                raise ValueError(
                    'Sampling from non-normalized Gaussian mixtures is intentionally '
                    'not implemented. You probably want to normalize. To work around, '
                    'add a zero Tensor/Array with given inputs.')

        # Keep only the sample dimensions that do not clash with our inputs.
        fresh_sample_inputs = OrderedDict()
        for name, domain in sample_inputs.items():
            if name not in self.inputs:
                fresh_sample_inputs[name] = domain
        sample_shape = tuple(
            int(domain.dtype) for domain in fresh_sample_inputs.values())

        # Split our own inputs into non-real and real ones, preserving order.
        int_inputs = OrderedDict()
        real_inputs = OrderedDict()
        for name, domain in self.inputs.items():
            if domain.dtype == 'real':
                real_inputs[name] = domain
            else:
                int_inputs[name] = domain

        # Inputs of every sampled point: sample dims first, then non-real dims.
        point_inputs = OrderedDict(fresh_sample_inputs)
        point_inputs.update(int_inputs)

        if sampled_vars != frozenset(real_inputs):
            raise NotImplementedError(
                'TODO implement partial sampling of real variables')

        shape = sample_shape + self.info_vec.shape
        backend = get_backend()
        if backend == "numpy":
            white_noise = np.random.randn(*shape)
        else:
            from importlib import import_module
            module_name = funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[
                backend]
            dist = import_module(module_name)
            if rng_key is None:
                sample_args = (shape, )
            else:
                sample_args = (rng_key, shape)
            white_noise = dist.Normal.dist_class(0, 1).sample(*sample_args)
        # Append a trailing axis so the noise lines up with the column vector
        # built from info_vec below.
        white_noise = ops.unsqueeze(white_noise, -1)

        white_vec = ops.triangular_solve(self.info_vec[..., None],
                                         self._precision_chol)
        solved = ops.triangular_solve(white_noise + white_vec,
                                      self._precision_chol,
                                      transpose=True)
        sample = solved[..., 0]

        # Cut the flat sample into one slice per real input.
        offsets, _ = _compute_offsets(real_inputs)
        leading_shape = shape[:-1]
        terms = []
        for name, domain in real_inputs.items():
            start = offsets[name]
            stop = offsets[name] + domain.num_elements
            chunk = sample[..., start:stop]
            chunk = chunk.reshape(leading_shape + domain.shape)
            point = Tensor(chunk, point_inputs)
            assert point.output == domain
            terms.append(Delta(name, point))
        terms.append(self.log_normalizer)
        return reduce(ops.add, terms)
Example #2
0
    def unscaled_sample(self, sampled_vars, sample_inputs):
        """
        Jointly sample every real-valued input of this funsor.

        :param sampled_vars: Names of the variables to sample. Names that are
            not inputs of this funsor are ignored.
        :param sample_inputs: Mapping from name to domain of extra sample
            dimensions. Entries whose name is already an input are dropped.
        :return: ``self`` when nothing needs sampling, otherwise the
            ``ops.add``-sum of one ``Delta`` per real input plus
            ``self.log_normalizer``.
        :raises ValueError: If any requested variable is not real-valued.
        :raises NotImplementedError: If only a strict subset of the real
            inputs is requested.
        """
        sampled_vars = sampled_vars.intersection(self.inputs)
        if not sampled_vars:
            return self
        for name in sampled_vars:
            if self.inputs[name].dtype != 'real':
                raise ValueError(
                    'Sampling from non-normalized Gaussian mixtures is intentionally '
                    'not implemented. You probably want to normalize. To work around, '
                    'add a zero Tensor/Array with given inputs.')

        # Keep only the sample dimensions that do not clash with our inputs.
        fresh_sample_inputs = OrderedDict()
        for name, domain in sample_inputs.items():
            if name not in self.inputs:
                fresh_sample_inputs[name] = domain
        sample_shape = tuple(
            int(domain.dtype) for domain in fresh_sample_inputs.values())

        # Split our own inputs into non-real and real ones, preserving order.
        int_inputs = OrderedDict()
        real_inputs = OrderedDict()
        for name, domain in self.inputs.items():
            if domain.dtype == 'real':
                real_inputs[name] = domain
            else:
                int_inputs[name] = domain

        # Inputs of every sampled point: sample dims first, then non-real dims.
        point_inputs = OrderedDict(fresh_sample_inputs)
        point_inputs.update(int_inputs)

        if sampled_vars != frozenset(real_inputs):
            raise NotImplementedError(
                'TODO implement partial sampling of real variables')

        shape = sample_shape + self.info_vec.shape
        # TODO: revise the logic here; `key` is required for JAX normal sampler
        noise_shape = shape + (1, )
        white_noise = funsor.testing.randn(noise_shape)

        white_vec = ops.triangular_solve(self.info_vec[..., None],
                                         self._precision_chol)
        solved = ops.triangular_solve(white_noise + white_vec,
                                      self._precision_chol,
                                      transpose=True)
        sample = solved[..., 0]

        # Cut the flat sample into one slice per real input.
        offsets, _ = _compute_offsets(real_inputs)
        leading_shape = shape[:-1]
        terms = []
        for name, domain in real_inputs.items():
            start = offsets[name]
            stop = offsets[name] + domain.num_elements
            chunk = sample[..., start:stop]
            chunk = chunk.reshape(leading_shape + domain.shape)
            point = Tensor(chunk, point_inputs)
            assert point.output == domain
            terms.append(Delta(name, point))
        terms.append(self.log_normalizer)
        return reduce(ops.add, terms)
Example #3
0
 def log_normalizer(self):
     """
     Build the log normalizer as a ``Tensor`` over the non-real inputs.

     The data is ``0.5 * dim * log(2 * pi) - log_det + quad`` where ``dim``
     is the last dimension of ``self.precision``, ``log_det`` is
     ``_log_det_tri(self._precision_chol)``, and ``quad`` is half the sum of
     squares (over the last axis) of ``info_vec`` after a triangular solve
     against ``self._precision_chol``.

     :return: A ``Tensor`` whose inputs are this funsor's inputs with a
         dtype other than ``'real'``.
     """
     event_size = self.precision.shape[-1]
     log_det = _log_det_tri(self._precision_chol)

     # Solve against the stored factor, then drop the trailing column axis.
     whitened = ops.triangular_solve(self.info_vec[..., None],
                                     self._precision_chol)[..., 0]
     quadratic = 0.5 * (whitened**2).sum(-1)

     constant = 0.5 * event_size * math.log(2 * math.pi)
     data = constant - log_det + quadratic

     batch_inputs = OrderedDict()
     for name, domain in self.inputs.items():
         if domain.dtype != 'real':
             batch_inputs[name] = domain
     return Tensor(data, batch_inputs)
Example #4
0
    def eager_reduce(self, op, reduced_vars):
        """
        Eagerly reduce this funsor along ``reduced_vars`` for selected ops.

        * ``ops.logaddexp``: the real variables in ``reduced_vars`` are
          marginalized out here; any remaining non-real variables are then
          reduced by calling ``.reduce(ops.logaddexp, ...)`` on the result.
          Returns ``None`` when no real variable is being reduced.
        * ``ops.add``: ``info_vec`` and ``precision`` are summed along the
          reduced non-real variables and a new ``Gaussian`` is returned.
        * any other op: returns ``None``.

        :param op: The reduction operation.
        :param reduced_vars: Set of input names to reduce over.
        :return: The reduced funsor, or ``None`` to defer to the default
            implementation.
        :raises ValueError: If ``op`` is ``ops.add`` and a reduced variable
            is real-valued.
        """
        if op is ops.logaddexp:
            # Marginalize out real variables, but keep mixtures lazy.
            assert all(v in self.inputs for v in reduced_vars)
            real_vars = frozenset(name
                                  for name, domain in self.inputs.items()
                                  if domain.dtype == "real")
            reduced_reals = reduced_vars & real_vars
            reduced_ints = reduced_vars - real_vars
            if not reduced_reals:
                return None  # defer to default implementation

            if reduced_reals == real_vars:
                # All real variables are removed: only the normalizer is left.
                return self.log_normalizer.reduce(ops.logaddexp, reduced_ints)

            kept_inputs = OrderedDict()
            for name, domain in self.inputs.items():
                if name not in reduced_reals:
                    kept_inputs[name] = domain
            int_inputs = OrderedDict()
            for name, domain in kept_inputs.items():
                if domain.dtype != 'real':
                    int_inputs[name] = domain

            # Collect flat index ranges of the kept (a) and reduced (b) reals.
            offsets, _ = _compute_offsets(self.inputs)
            kept_blocks = []
            dropped_blocks = []
            for name, domain in self.inputs.items():
                if domain.dtype != 'real':
                    continue
                block = ops.new_arange(self.info_vec, offsets[name],
                                       offsets[name] + domain.num_elements, 1)
                if name in reduced_vars:
                    dropped_blocks.append(block)
                else:
                    kept_blocks.append(block)
            a = ops.cat(-1, *kept_blocks)
            b = ops.cat(-1, *dropped_blocks)

            # Pick out the aa, ba and bb blocks of the precision matrix.
            a_rows = a[..., None]
            b_rows = b[..., None]
            prec_aa = self.precision[..., a_rows, a]
            prec_ba = self.precision[..., b_rows, a]
            prec_bb = self.precision[..., b_rows, b]

            prec_b = ops.cholesky(prec_bb)
            prec_a = ops.triangular_solve(prec_ba, prec_b)
            prec_at = ops.transpose(prec_a, -1, -2)
            precision = prec_aa - ops.matmul(prec_at, prec_a)

            info_a = self.info_vec[..., a]
            info_b = self.info_vec[..., b]
            b_tmp = ops.triangular_solve(info_b[..., None], prec_b)
            info_vec = info_a - ops.matmul(prec_at, b_tmp)[..., 0]

            # Log-normalizer contribution of the reduced block b.
            log_prob_data = (0.5 * len(b) * math.log(2 * math.pi) -
                             _log_det_tri(prec_b) +
                             0.5 * (b_tmp[..., 0]**2).sum(-1))
            log_prob = Tensor(log_prob_data, int_inputs)
            result = log_prob + Gaussian(info_vec, precision, kept_inputs)
            return result.reduce(ops.logaddexp, reduced_ints)

        if op is not ops.add:
            return None  # defer to default implementation

        for v in reduced_vars:
            if self.inputs[v].dtype == 'real':
                raise ValueError(
                    "Cannot sum along a real dimension: {}".format(repr(v)))

        # Fuse Gaussians along a plate. Compare to eager_add_gaussian_gaussian().
        old_ints = OrderedDict()
        for name, domain in self.inputs.items():
            if domain.dtype != 'real':
                old_ints[name] = domain
        new_ints = OrderedDict()
        for name, domain in old_ints.items():
            if name not in reduced_vars:
                new_ints[name] = domain
        remaining_inputs = OrderedDict()
        for name, domain in self.inputs.items():
            if name not in reduced_vars:
                remaining_inputs[name] = domain

        summed_info_vec = Tensor(self.info_vec, old_ints)
        summed_info_vec = summed_info_vec.reduce(ops.add, reduced_vars)
        summed_precision = Tensor(self.precision, old_ints)
        summed_precision = summed_precision.reduce(ops.add, reduced_vars)
        assert summed_info_vec.inputs == new_ints
        assert summed_precision.inputs == new_ints
        return Gaussian(summed_info_vec.data, summed_precision.data,
                        remaining_inputs)