def align_gaussian(new_inputs, old):
    """
    Align data of a Gaussian distribution to a new ``inputs`` shape.
    """
    assert isinstance(new_inputs, OrderedDict)
    assert isinstance(old, Gaussian)
    info_vec = old.info_vec
    precision = old.precision

    # Align int inputs.
    # Since these are managed as in Tensor, we can defer to align_tensor().
    new_ints = OrderedDict((k, d) for k, d in new_inputs.items() if d.dtype != 'real')
    old_ints = OrderedDict((k, d) for k, d in old.inputs.items() if d.dtype != 'real')
    if new_ints != old_ints:
        info_vec = align_tensor(new_ints, Tensor(info_vec, old_ints))
        precision = align_tensor(new_ints, Tensor(precision, old_ints))

    # Align real inputs, which are all concatenated in the rightmost dims.
    new_offsets, new_dim = _compute_offsets(new_inputs)
    old_offsets, old_dim = _compute_offsets(old.inputs)
    assert info_vec.shape[-1:] == (old_dim,)
    assert precision.shape[-2:] == (old_dim, old_dim)
    if new_offsets != old_offsets:
        old_info_vec = info_vec
        old_precision = precision
        info_vec = BlockVector(old_info_vec.shape[:-1] + (new_dim,))
        precision = BlockMatrix(old_info_vec.shape[:-1] + (new_dim, new_dim))
        for k1, new_offset1 in new_offsets.items():
            if k1 not in old_offsets:
                continue
            offset1 = old_offsets[k1]
            num_elements1 = old.inputs[k1].num_elements
            old_slice1 = slice(offset1, offset1 + num_elements1)
            new_slice1 = slice(new_offset1, new_offset1 + num_elements1)
            info_vec[..., new_slice1] = old_info_vec[..., old_slice1]
            for k2, new_offset2 in new_offsets.items():
                if k2 not in old_offsets:
                    continue
                offset2 = old_offsets[k2]
                num_elements2 = old.inputs[k2].num_elements
                old_slice2 = slice(offset2, offset2 + num_elements2)
                new_slice2 = slice(new_offset2, new_offset2 + num_elements2)
                precision[..., new_slice1, new_slice2] = \
                    old_precision[..., old_slice1, old_slice2]
        info_vec = info_vec.as_tensor()
        precision = precision.as_tensor()

    return info_vec, precision
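
# A minimal usage sketch for align_gaussian, not part of the original module.
# It assumes funsor-style domain constructors ``bint`` and ``reals`` are in
# scope (``reals`` does not appear above, so treat it as an assumption) and
# that Gaussian carries (info_vec, precision) data. A Gaussian over real input
# "x" is aligned to a wider inputs shape that also names "y"; the new "y"
# blocks are zero-filled via BlockVector/BlockMatrix.
def _example_align_gaussian():
    old = Gaussian(torch.randn(2, 3),                 # info_vec, batched over i
                   torch.eye(3).repeat(2, 1, 1),      # precision, batched over i
                   OrderedDict([("i", bint(2)), ("x", reals(3))]))
    new_inputs = OrderedDict([("i", bint(2)), ("x", reals(3)), ("y", reals(2))])
    info_vec, precision = align_gaussian(new_inputs, old)
    assert info_vec.shape == (2, 5)      # the 2 trailing "y" entries are zero
    assert precision.shape == (2, 5, 5)  # the "y" rows and columns are zero
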
def eager_cat_homogeneous(name, part_name, *parts):
    """
    Concatenate Gaussian or GaussianMixture ``parts`` along their ``part_name``
    input, returning a result indexed by ``name``.
    """
    assert parts
    output = parts[0].output
    inputs = OrderedDict([(part_name, None)])
    for part in parts:
        assert part.output == output
        assert part_name in part.inputs
        inputs.update(part.inputs)

    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != "real")
    real_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype == "real")
    inputs = int_inputs.copy()
    inputs.update(real_inputs)

    discretes = []
    info_vecs = []
    precisions = []
    for part in parts:
        inputs[part_name] = part.inputs[part_name]
        int_inputs[part_name] = inputs[part_name]
        shape = tuple(d.size for d in int_inputs.values())
        if isinstance(part, Gaussian):
            discrete = None
            gaussian = part
        elif issubclass(type(part), GaussianMixture):  # TODO figure out why isinstance isn't working
            discrete, gaussian = part.terms[0], part.terms[1]
            discrete = align_tensor(int_inputs, discrete).expand(shape)
        else:
            raise NotImplementedError("TODO")
        discretes.append(discrete)
        info_vec, precision = align_gaussian(inputs, gaussian)
        info_vecs.append(info_vec.expand(shape + (-1,)))
        precisions.append(precision.expand(shape + (-1, -1)))
    if part_name != name:
        del inputs[part_name]
        del int_inputs[part_name]

    dim = 0
    info_vec = torch.cat(info_vecs, dim=dim)
    precision = torch.cat(precisions, dim=dim)
    inputs[name] = bint(info_vec.size(dim))
    int_inputs[name] = inputs[name]
    result = Gaussian(info_vec, precision, inputs)
    if any(d is not None for d in discretes):
        # Promote pure-Gaussian parts to mixtures with a zero discrete part.
        for i, d in enumerate(discretes):
            if d is None:
                discretes[i] = info_vecs[i].new_zeros(info_vecs[i].shape[:-1])
        discrete = torch.cat(discretes, dim=dim)
        result += Tensor(discrete, int_inputs)
    return result
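
# A hedged sketch of eager_cat_homogeneous, again assuming ``bint``/``reals``
# are available. Two pure-Gaussian parts sharing a real input "x" are
# concatenated along their "i" input, so the result's "i" has size 2 + 3 = 5;
# since no part is a GaussianMixture, no discrete term is appended.
def _example_cat_homogeneous():
    g1 = Gaussian(torch.randn(2, 3), torch.eye(3).repeat(2, 1, 1),
                  OrderedDict([("i", bint(2)), ("x", reals(3))]))
    g2 = Gaussian(torch.randn(3, 3), torch.eye(3).repeat(3, 1, 1),
                  OrderedDict([("i", bint(3)), ("x", reals(3))]))
    result = eager_cat_homogeneous("i", "i", g1, g2)
    assert result.inputs["i"] == bint(5)
    assert isinstance(result, Gaussian)
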
def eager_reduce(self, op, reduced_vars):
    if op is ops.logaddexp:
        # Marginalize out real variables, but keep mixtures lazy.
        assert all(v in self.inputs for v in reduced_vars)
        real_vars = frozenset(k for k, d in self.inputs.items() if d.dtype == "real")
        reduced_reals = reduced_vars & real_vars
        reduced_ints = reduced_vars - real_vars
        if not reduced_reals:
            return None  # defer to default implementation

        inputs = OrderedDict((k, d) for k, d in self.inputs.items() if k not in reduced_reals)
        if reduced_reals == real_vars:
            result = self._log_normalizer
        else:
            int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
            offsets, _ = _compute_offsets(self.inputs)
            index = []
            for key, domain in inputs.items():
                if domain.dtype == 'real':
                    index.extend(range(offsets[key], offsets[key] + domain.num_elements))
            index = torch.tensor(index)

            loc = self.loc[..., index]
            self_scale_tri = torch.inverse(torch.cholesky(self.precision)).transpose(-1, -2)
            self_covariance = torch.matmul(self_scale_tri, self_scale_tri.transpose(-1, -2))
            covariance = self_covariance[..., index.unsqueeze(-1), index]
            scale_tri = torch.cholesky(covariance)
            inv_scale_tri = torch.inverse(scale_tri)
            precision = torch.matmul(inv_scale_tri.transpose(-1, -2), inv_scale_tri)
            reduced_dim = sum(self.inputs[k].num_elements for k in reduced_reals)
            log_det_term = _log_det_tri(self_scale_tri) - _log_det_tri(scale_tri)
            log_prob = Tensor(log_det_term + 0.5 * math.log(2 * math.pi) * reduced_dim,
                              int_inputs)
            result = log_prob + Gaussian(loc, precision, inputs)
        return result.reduce(ops.logaddexp, reduced_ints)

    elif op is ops.add:
        for v in reduced_vars:
            if self.inputs[v].dtype == 'real':
                raise ValueError("Cannot sum along a real dimension: {}".format(repr(v)))

        # Fuse Gaussians along a plate. Compare to eager_add_gaussian_gaussian().
        old_ints = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype != 'real')
        new_ints = OrderedDict((k, v) for k, v in old_ints.items() if k not in reduced_vars)
        inputs = OrderedDict((k, v) for k, v in self.inputs.items() if k not in reduced_vars)

        precision = Tensor(self.precision, old_ints).reduce(ops.add, reduced_vars)
        precision_loc = Tensor(_mv(self.precision, self.loc),
                               old_ints).reduce(ops.add, reduced_vars)
        assert precision.inputs == new_ints
        assert precision_loc.inputs == new_ints
        loc = Tensor(sym_solve_mv(precision.data, precision_loc.data), new_ints)
        expanded_loc = align_tensor(old_ints, loc)
        quadratic_term = Tensor(_vmv(self.precision, expanded_loc - self.loc),
                                old_ints).reduce(ops.add, reduced_vars)
        assert quadratic_term.inputs == new_ints
        likelihood = -0.5 * quadratic_term
        return likelihood + Gaussian(loc.data, precision.data, inputs)

    return None  # defer to default implementation
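
# A hedged sketch of both eager_reduce branches. It assumes the usual funsor
# front end dispatches Gaussian.reduce to this method, and note that this
# method reads ``self.loc``, i.e. a (loc, precision) parameterization rather
# than the (info_vec, precision) form used by align_gaussian above.
def _example_eager_reduce():
    g = Gaussian(torch.randn(4, 5),                   # loc, batched over plate i
                 torch.eye(5).repeat(4, 1, 1),        # precision
                 OrderedDict([("i", bint(4)), ("x", reals(2)), ("y", reals(3))]))
    marginal = g.reduce(ops.logaddexp, frozenset(["x"]))  # integrate out x
    assert "x" not in marginal.inputs and "y" in marginal.inputs
    fused = g.reduce(ops.add, frozenset(["i"]))           # fuse along the plate
    assert "i" not in fused.inputs
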