def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
    sampled_vars = sampled_vars.intersection(self.inputs)
    if not sampled_vars:
        return self
    if any(self.inputs[k].dtype != 'real' for k in sampled_vars):
        raise ValueError(
            'Sampling from non-normalized Gaussian mixtures is intentionally '
            'not implemented. You probably want to normalize. To work around, '
            'add a zero Tensor/Array with given inputs.')

    # Partition inputs into sample_inputs + int_inputs + real_inputs.
    sample_inputs = OrderedDict((k, d) for k, d in sample_inputs.items()
                                if k not in self.inputs)
    sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
    int_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                             if d.dtype != 'real')
    real_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                              if d.dtype == 'real')
    inputs = sample_inputs.copy()
    inputs.update(int_inputs)

    if sampled_vars == frozenset(real_inputs):
        shape = sample_shape + self.info_vec.shape

        # Draw standard normal noise with the active backend's sampler;
        # JAX requires an explicit rng_key.
        backend = get_backend()
        if backend != "numpy":
            from importlib import import_module
            dist = import_module(
                funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[backend])
            sample_args = (shape,) if rng_key is None else (rng_key, shape)
            white_noise = dist.Normal.dist_class(0, 1).sample(*sample_args)
        else:
            white_noise = np.random.randn(*shape)
        white_noise = ops.unsqueeze(white_noise, -1)

        # Transform white noise into a sample from N(P^-1 b, P^-1) via two
        # triangular solves against the Cholesky factor of the precision P.
        white_vec = ops.triangular_solve(self.info_vec[..., None],
                                         self._precision_chol)
        sample = ops.triangular_solve(white_noise + white_vec,
                                      self._precision_chol,
                                      transpose=True)[..., 0]

        # Split the flat sample into one Delta funsor per real input.
        offsets, _ = _compute_offsets(real_inputs)
        results = []
        for key, domain in real_inputs.items():
            data = sample[..., offsets[key]:offsets[key] + domain.num_elements]
            data = data.reshape(shape[:-1] + domain.shape)
            point = Tensor(data, inputs)
            assert point.output == domain
            results.append(Delta(key, point))
        results.append(self.log_normalizer)
        return reduce(ops.add, results)

    raise NotImplementedError('TODO implement partial sampling of real variables')
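# --- Illustrative check (not part of the library): why the two triangular
# solves above yield a correctly distributed sample. In information form,
# log p(x) = -0.5 x'Px + b'x + const with P = L L', so the mean is P^-1 b
# and the covariance is P^-1; x = L^-T (eps + L^-1 b) has exactly those
# moments. A minimal numpy sketch under those assumptions (P, b, L, eps are
# made-up example names; scipy is assumed available):

import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
P = np.array([[2.0, 0.5], [0.5, 1.0]])   # precision
b = np.array([1.0, -1.0])                # info_vec
L = np.linalg.cholesky(P)                # P = L @ L.T, L lower triangular

eps = rng.standard_normal((100000, 2))   # white noise, one row per sample
white_vec = solve_triangular(L, b, lower=True)                 # L^-1 b
x = solve_triangular(L.T, (eps + white_vec).T, lower=False).T  # L^-T (eps + L^-1 b)

print(np.allclose(x.mean(0), np.linalg.solve(P, b), atol=0.02))  # mean = P^-1 b
print(np.allclose(np.cov(x.T), np.linalg.inv(P), atol=0.02))     # cov  = P^-1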
def log_normalizer(self):
    dim = self.precision.shape[-1]
    log_det_term = _log_det_tri(self._precision_chol)
    loc_info_vec_term = 0.5 * (ops.triangular_solve(
        self.info_vec[..., None], self._precision_chol)[..., 0] ** 2).sum(-1)
    data = 0.5 * dim * math.log(2 * math.pi) - log_det_term + loc_info_vec_term
    inputs = OrderedDict((k, v) for k, v in self.inputs.items()
                         if v.dtype != 'real')
    return Tensor(data, inputs)
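# --- Illustrative check (not part of the library): log_normalizer computes
# log integral exp(-0.5 x'Px + b'x) dx = (d/2) log(2*pi) - 0.5 log|P|
# + 0.5 b'P^-1 b, using that log|P| = 2 * sum(log diag(L)) and
# b'P^-1 b = ||L^-1 b||^2 for P = L L'. A small numpy sketch (P and b are
# made-up examples):

import numpy as np

P = np.array([[2.0, 0.5], [0.5, 1.0]])
b = np.array([1.0, -1.0])
L = np.linalg.cholesky(P)
d = P.shape[-1]

# Cholesky-based evaluation, mirroring the code above.
log_det_term = np.log(np.diag(L)).sum()              # = 0.5 * log|P|
loc_term = 0.5 * (np.linalg.solve(L, b) ** 2).sum()  # = 0.5 * b'P^-1 b
log_Z = 0.5 * d * np.log(2 * np.pi) - log_det_term + loc_term

# Direct closed form for comparison.
direct = (0.5 * d * np.log(2 * np.pi)
          - 0.5 * np.linalg.slogdet(P)[1]
          + 0.5 * b @ np.linalg.solve(P, b))
print(np.allclose(log_Z, direct))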
def eager_reduce(self, op, reduced_vars):
    if op is ops.logaddexp:
        # Marginalize out real variables, but keep mixtures lazy.
        assert all(v in self.inputs for v in reduced_vars)
        real_vars = frozenset(k for k, d in self.inputs.items()
                              if d.dtype == "real")
        reduced_reals = reduced_vars & real_vars
        reduced_ints = reduced_vars - real_vars
        if not reduced_reals:
            return None  # defer to default implementation

        inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                             if k not in reduced_reals)
        if reduced_reals == real_vars:
            result = self.log_normalizer
        else:
            int_inputs = OrderedDict((k, v) for k, v in inputs.items()
                                     if v.dtype != 'real')
            offsets, _ = _compute_offsets(self.inputs)

            # Collect flat indices of the kept (a) and marginalized (b) blocks.
            a = []
            b = []
            for key, domain in self.inputs.items():
                if domain.dtype == 'real':
                    block = ops.new_arange(
                        self.info_vec, offsets[key],
                        offsets[key] + domain.num_elements, 1)
                    (b if key in reduced_vars else a).append(block)
            a = ops.cat(-1, *a)
            b = ops.cat(-1, *b)

            # Marginalize via the Schur complement of the b-block:
            # precision -> P_aa - P_ab P_bb^-1 P_ba,
            # info_vec  -> b_a - P_ab P_bb^-1 b_b.
            prec_aa = self.precision[..., a[..., None], a]
            prec_ba = self.precision[..., b[..., None], a]
            prec_bb = self.precision[..., b[..., None], b]
            prec_b = ops.cholesky(prec_bb)
            prec_a = ops.triangular_solve(prec_ba, prec_b)
            prec_at = ops.transpose(prec_a, -1, -2)
            precision = prec_aa - ops.matmul(prec_at, prec_a)

            info_a = self.info_vec[..., a]
            info_b = self.info_vec[..., b]
            b_tmp = ops.triangular_solve(info_b[..., None], prec_b)
            info_vec = info_a - ops.matmul(prec_at, b_tmp)[..., 0]

            log_prob = Tensor(
                0.5 * len(b) * math.log(2 * math.pi) - _log_det_tri(prec_b)
                + 0.5 * (b_tmp[..., 0] ** 2).sum(-1),
                int_inputs)
            result = log_prob + Gaussian(info_vec, precision, inputs)

        return result.reduce(ops.logaddexp, reduced_ints)

    elif op is ops.add:
        for v in reduced_vars:
            if self.inputs[v].dtype == 'real':
                raise ValueError("Cannot sum along a real dimension: {}"
                                 .format(repr(v)))

        # Fuse Gaussians along a plate. Compare to eager_add_gaussian_gaussian().
        old_ints = OrderedDict((k, v) for k, v in self.inputs.items()
                               if v.dtype != 'real')
        new_ints = OrderedDict((k, v) for k, v in old_ints.items()
                               if k not in reduced_vars)
        inputs = OrderedDict((k, v) for k, v in self.inputs.items()
                             if k not in reduced_vars)

        info_vec = Tensor(self.info_vec, old_ints).reduce(ops.add, reduced_vars)
        precision = Tensor(self.precision, old_ints).reduce(ops.add, reduced_vars)
        assert info_vec.inputs == new_ints
        assert precision.inputs == new_ints
        return Gaussian(info_vec.data, precision.data, inputs)

    return None  # defer to default implementation
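# --- Illustrative checks (not part of the library). Both identities used by
# eager_reduce can be verified numerically; all names below (P, b, A, x, ...)
# are made-up examples.
#
# (1) logaddexp branch: marginalizing a block in information form is the
# Schur complement. The marginal over the kept block a must agree with the
# covariance-form marginal, i.e. inv(P_marg) == Sigma_aa and
# solve(P_marg, b_marg) == mu_a.

import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((4, 4))
P = A @ A.T + 4 * np.eye(4)           # positive definite joint precision
b = rng.standard_normal(4)            # joint info_vec
a, bb = [0, 1], [2, 3]                # keep block a, marginalize block b

P_aa, P_ab, P_bb = P[np.ix_(a, a)], P[np.ix_(a, bb)], P[np.ix_(bb, bb)]
P_marg = P_aa - P_ab @ np.linalg.solve(P_bb, P_ab.T)   # Schur complement
b_marg = b[a] - P_ab @ np.linalg.solve(P_bb, b[bb])

Sigma = np.linalg.inv(P)
mu = Sigma @ b
print(np.allclose(np.linalg.inv(P_marg), Sigma[np.ix_(a, a)]))
print(np.allclose(np.linalg.solve(P_marg, b_marg), mu[a]))

# (2) add branch: summing information-form log densities over a plate is the
# same as summing info_vecs and precisions elementwise, which is why the
# reduction above simply adds the parameter tensors.

def log_density(x, b, P):
    # Unnormalized information-form log density: -0.5 x'Px + b'x.
    return -0.5 * x @ P @ x + b @ x

x = rng.standard_normal(3)
P1, P2 = 2.0 * np.eye(3), 3.0 * np.eye(3)
b1, b2 = rng.standard_normal(3), rng.standard_normal(3)
print(np.allclose(log_density(x, b1, P1) + log_density(x, b2, P2),
                  log_density(x, b1 + b2, P1 + P2)))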