def eager_subs(self, subs):
    assert isinstance(subs, tuple)
    subs = tuple((k, v if isinstance(v, (Variable, Slice)) else materialize(v))
                 for k, v in subs if k in self.inputs)
    if not subs:
        return self

    # Constants and Affine funsors are eagerly substituted;
    # everything else is lazily substituted.
    lazy_subs = tuple((k, v) for k, v in subs
                      if not isinstance(v, (Number, Tensor, Variable, Slice))
                      and not (is_affine(v) and affine_inputs(v)))
    var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
    int_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor, Slice))
                     if v.dtype != 'real')
    real_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor))
                      if v.dtype == 'real')
    affine_subs = tuple((k, v) for k, v in subs
                        if is_affine(v) and affine_inputs(v) and not isinstance(v, Variable))
    if var_subs:
        return self._eager_subs_var(var_subs, int_subs + real_subs + affine_subs + lazy_subs)
    if int_subs:
        return self._eager_subs_int(int_subs, real_subs + affine_subs + lazy_subs)
    if real_subs:
        return self._eager_subs_real(real_subs, affine_subs + lazy_subs)
    if affine_subs:
        return self._eager_subs_affine(affine_subs, lazy_subs)
    return reflect(Subs, self, lazy_subs)
def normalize_with_subs(cls, *args):
    """
    This interpretation is like normalize, except it also evaluates Subs eagerly.
    This is necessary because we want to convert distribution expressions to
    normal form in some tests, but do not want to trigger eager patterns that
    rewrite some distributions (e.g. Normal to Gaussian), since these tests are
    specifically intended to exercise funsor.distribution.Distribution.
    """
    result = normalize.dispatch(cls, *args)(*args)
    if result is None:
        result = lazy.dispatch(cls, *args)(*args)
    if result is None:
        result = reflect(cls, *args)
    return result
def _eager_subs_affine(self, subs, remaining_subs):
    # Extract an affine representation.
    affine = OrderedDict()
    for k, v in subs:
        const, coeffs = extract_affine(v)
        if isinstance(const, Tensor) and all(isinstance(coeff, Tensor)
                                             for coeff, _ in coeffs.values()):
            affine[k] = const, coeffs
        else:
            remaining_subs += (k, v),
    if not affine:
        return reflect(Subs, self, remaining_subs)

    # Align integer dimensions.
    old_int_inputs = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype != 'real')
    tensors = [Tensor(self.info_vec, old_int_inputs),
               Tensor(self.precision, old_int_inputs)]
    for const, coeffs in affine.values():
        tensors.append(const)
        tensors.extend(coeff for coeff, _ in coeffs.values())
    new_int_inputs, tensors = align_tensors(*tensors, expand=True)
    tensors = (Tensor(x, new_int_inputs) for x in tensors)
    old_info_vec = next(tensors).data
    old_precision = next(tensors).data
    for old_k, (const, coeffs) in affine.items():
        const = next(tensors)
        for new_k, (coeff, eqn) in coeffs.items():
            coeff = next(tensors)
            coeffs[new_k] = coeff, eqn
        affine[old_k] = const, coeffs
    batch_shape = old_info_vec.shape[:-1]

    # Align real dimensions.
    old_real_inputs = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype == 'real')
    new_real_inputs = old_real_inputs.copy()
    for old_k, (const, coeffs) in affine.items():
        del new_real_inputs[old_k]
        for new_k, (coeff, eqn) in coeffs.items():
            new_shape = coeff.shape[:len(eqn.split('->')[0].split(',')[1])]
            new_real_inputs[new_k] = Reals[new_shape]
    old_offsets, old_dim = _compute_offsets(old_real_inputs)
    new_offsets, new_dim = _compute_offsets(new_real_inputs)
    new_inputs = new_int_inputs.copy()
    new_inputs.update(new_real_inputs)

    # Construct a blockwise affine representation of the substitution.
    subs_vector = BlockVector(batch_shape + (old_dim,))
    subs_matrix = BlockMatrix(batch_shape + (new_dim, old_dim))
    for old_k, old_offset in old_offsets.items():
        old_size = old_real_inputs[old_k].num_elements
        old_slice = slice(old_offset, old_offset + old_size)
        if old_k in new_real_inputs:
            new_offset = new_offsets[old_k]
            new_slice = slice(new_offset, new_offset + old_size)
            subs_matrix[..., new_slice, old_slice] = \
                ops.new_eye(self.info_vec, batch_shape + (old_size,))
            continue
        const, coeffs = affine[old_k]
        old_shape = old_real_inputs[old_k].shape
        assert const.data.shape == batch_shape + old_shape
        subs_vector[..., old_slice] = const.data.reshape(batch_shape + (old_size,))
        for new_k, new_offset in new_offsets.items():
            if new_k in coeffs:
                coeff, eqn = coeffs[new_k]
                new_size = new_real_inputs[new_k].num_elements
                new_slice = slice(new_offset, new_offset + new_size)
                assert coeff.shape == new_real_inputs[new_k].shape + old_shape
                subs_matrix[..., new_slice, old_slice] = \
                    coeff.data.reshape(batch_shape + (new_size, old_size))
    subs_vector = subs_vector.as_tensor()
    subs_matrix = subs_matrix.as_tensor()
    subs_matrix_t = ops.transpose(subs_matrix, -1, -2)

    # Construct the new funsor. Suppose the old Gaussian funsor g has density
    #   g(x) = < x | i - 1/2 P x >
    # Now define a new funsor f by substituting x = A y + B:
    #   f(y) = g(A y + B)
    #        = < A y + B | i - 1/2 P (A y + B) >
    #        = < y | At (i - P B) - 1/2 At P A y > + < B | i - 1/2 P B >
    #       =: < y | i' - 1/2 P' y > + C
    # where P' = At P A and i' = At (i - P B) parametrize a new Gaussian,
    # and C = < B | i - 1/2 P B > parametrizes a new Tensor.
    precision = subs_matrix @ old_precision @ subs_matrix_t
    info_vec = _mv(subs_matrix, old_info_vec - _mv(old_precision, subs_vector))
    const = _vv(subs_vector, old_info_vec - 0.5 * _mv(old_precision, subs_vector))
    result = Gaussian(info_vec, precision, new_inputs) + Tensor(const, new_int_inputs)
    return Subs(result, remaining_subs) if remaining_subs else result
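# The affine-substitution identity used above can be checked numerically with
# plain torch, independently of funsor. This is only an illustrative sketch;
# the helper name and the symbols (i, P, A, B, y) are hypothetical and do not
# appear in this module. It verifies that substituting x = A y + B into
# g(x) = <x, i> - 1/2 x'Px yields i' = At (i - P B), P' = At P A and the
# constant C = <B, i - 1/2 P B>.
def _check_affine_subs_identity():
    import torch
    n, m = 4, 3                        # old and new real dimensions
    i = torch.randn(n)                 # info vector
    L = torch.randn(n, n)
    P = L @ L.T + torch.eye(n)         # a positive definite precision
    A = torch.randn(n, m)              # affine coefficient
    B = torch.randn(n)                 # affine constant
    y = torch.randn(m)

    def g(x):
        return x @ i - 0.5 * x @ P @ x

    i_new = A.T @ (i - P @ B)
    P_new = A.T @ P @ A
    C = B @ (i - 0.5 * P @ B)
    f_y = y @ i_new - 0.5 * y @ P_new @ y + C
    assert torch.allclose(g(A @ y + B), f_y, atol=1e-5)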
def eager_subs(self, subs):
    assert isinstance(subs, tuple)
    subs = tuple((k, materialize(to_funsor(v, self.inputs[k])))
                 for k, v in subs if k in self.inputs)
    if not subs:
        return self

    # Constants and Variables are eagerly substituted;
    # everything else is lazily substituted.
    lazy_subs = tuple((k, v) for k, v in subs
                      if not isinstance(v, (Number, Tensor, Variable)))
    var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
    int_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor))
                     if v.dtype != 'real')
    real_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor))
                      if v.dtype == 'real')
    if not (var_subs or int_subs or real_subs):
        return reflect(Subs, self, lazy_subs)

    # First perform any variable substitutions.
    if var_subs:
        rename = {k: v.name for k, v in var_subs}
        inputs = OrderedDict((rename.get(k, k), d) for k, d in self.inputs.items())
        if len(inputs) != len(self.inputs):
            raise ValueError("Variable substitution name conflict")
        var_result = Gaussian(self.loc, self.precision, inputs)
        return Subs(var_result, int_subs + real_subs + lazy_subs)

    # Next perform any integer substitution, i.e. slicing into a batch.
    if int_subs:
        int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real')
        tensors = [self.loc, self.precision]
        funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors]
        inputs = funsors[0].inputs.copy()
        inputs.update(real_inputs)
        int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
        return Subs(int_result, real_subs + lazy_subs)

    # Try to perform a complete substitution of all real variables, resulting in a Tensor.
    real_subs = OrderedDict(subs)
    assert real_subs and not int_subs
    if all(k in real_subs for k, d in self.inputs.items() if d.dtype == 'real'):
        # Broadcast all component tensors.
        int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        tensors = [Tensor(self.loc, int_inputs), Tensor(self.precision, int_inputs)]
        tensors.extend(real_subs.values())
        inputs, tensors = align_tensors(*tensors)
        batch_dim = tensors[0].dim() - 1
        batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
        (loc, precision), values = tensors[:2], tensors[2:]

        # Form the concatenated value.
        offsets, event_size = _compute_offsets(self.inputs)
        value = BlockVector(batch_shape + (event_size,))
        for k, value_k in zip(real_subs, values):
            offset = offsets[k]
            value_k = value_k.reshape(value_k.shape[:batch_dim] + (-1,))
            if not torch._C._get_tracing_state():
                assert value_k.size(-1) == self.inputs[k].num_elements
            value_k = value_k.expand(batch_shape + value_k.shape[-1:])
            value[..., offset:offset + self.inputs[k].num_elements] = value_k
        value = value.as_tensor()

        # Evaluate the non-normalized log density.
        result = -0.5 * _vmv(precision, value - loc)
        result = Tensor(result, inputs)
        assert result.output == reals()
        return Subs(result, lazy_subs)

    # Perform a partial substitution of a subset of real variables, resulting in a Joint.
    # See "The Matrix Cookbook" (November 15, 2012) ss. 8.1.3 eq. 353.
    # http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
    raise NotImplementedError('TODO implement partial substitution of real variables')
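# A rough sketch of the offset bookkeeping used in the complete-substitution
# branch above, assuming plain torch tensors and a hypothetical pair of real
# inputs x: Reals[2], y: Reals[3]. Each real input owns a contiguous slice of
# the event dimension, so a complete substitution concatenates the flattened
# values and evaluates the quadratic form -1/2 (v - loc)' P (v - loc). The
# helper and variable names here are illustrative, not part of this module.
def _check_complete_real_subs():
    import torch
    sizes = {'x': 2, 'y': 3}                      # num_elements per real input
    offsets, event_size = {}, 0
    for k, size in sizes.items():                 # mimics _compute_offsets
        offsets[k] = event_size
        event_size += size

    loc = torch.randn(event_size)
    L = torch.randn(event_size, event_size)
    precision = L @ L.T + torch.eye(event_size)   # positive definite
    values = {'x': torch.randn(2), 'y': torch.randn(3)}

    v = torch.zeros(event_size)                   # mimics BlockVector
    for k, offset in offsets.items():
        v[offset:offset + sizes[k]] = values[k]

    diff = v - loc
    return -0.5 * diff @ precision @ diff         # unnormalized log density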
def eager_subs(self, subs):
    assert isinstance(subs, tuple)
    subs = tuple((k, v if isinstance(v, (Variable, Slice)) else materialize(v))
                 for k, v in subs if k in self.inputs)
    if not subs:
        return self

    # Constants and Variables are eagerly substituted;
    # everything else is lazily substituted.
    lazy_subs = tuple((k, v) for k, v in subs
                      if not isinstance(v, (Number, Tensor, Variable, Slice)))
    var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
    int_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor, Slice))
                     if v.dtype != 'real')
    real_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor))
                      if v.dtype == 'real')
    if not (var_subs or int_subs or real_subs):
        return reflect(Subs, self, lazy_subs)

    # First perform any variable substitutions.
    if var_subs:
        rename = {k: v.name for k, v in var_subs}
        inputs = OrderedDict((rename.get(k, k), d) for k, d in self.inputs.items())
        if len(inputs) != len(self.inputs):
            raise ValueError("Variable substitution name conflict")
        var_result = Gaussian(self.info_vec, self.precision, inputs)
        return Subs(var_result, int_subs + real_subs + lazy_subs)

    # Next perform any integer substitution, i.e. slicing into a batch.
    if int_subs:
        int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real')
        tensors = [self.info_vec, self.precision]
        funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors]
        inputs = funsors[0].inputs.copy()
        inputs.update(real_inputs)
        int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
        return Subs(int_result, real_subs + lazy_subs)

    # Broadcast all component tensors.
    real_subs = OrderedDict(subs)
    assert real_subs and not int_subs
    int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
    tensors = [Tensor(self.info_vec, int_inputs), Tensor(self.precision, int_inputs)]
    tensors.extend(real_subs.values())
    int_inputs, tensors = align_tensors(*tensors)
    batch_dim = tensors[0].dim() - 1
    batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
    (info_vec, precision), values = tensors[:2], tensors[2:]
    offsets, event_size = _compute_offsets(self.inputs)
    slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
              for k, offset in offsets.items()]

    # Expand all substituted values.
    values = OrderedDict(zip(real_subs, values))
    for k, value in values.items():
        value = value.reshape(value.shape[:batch_dim] + (-1,))
        if not torch._C._get_tracing_state():
            assert value.size(-1) == self.inputs[k].num_elements
        values[k] = value.expand(batch_shape + value.shape[-1:])

    # Try to perform a complete substitution of all real variables, resulting in a Tensor.
    if all(k in real_subs for k, d in self.inputs.items() if d.dtype == 'real'):
        # Form the concatenated value.
        value = BlockVector(batch_shape + (event_size,))
        for k, i in slices:
            if k in values:
                value[..., i] = values[k]
        value = value.as_tensor()

        # Evaluate the non-normalized log density.
        result = _vv(value, info_vec - 0.5 * _mv(precision, value))
        result = Tensor(result, int_inputs)
        assert result.output == reals()
        return Subs(result, lazy_subs)

    # Perform a partial substitution of a subset of real variables, resulting in a Joint.
    # We split real inputs into two sets: a for the preserved and b for the substituted.
    b = frozenset(k for k, v in real_subs.items())
    a = frozenset(k for k, d in self.inputs.items() if d.dtype == 'real' and k not in b)
    prec_aa = torch.cat([torch.cat([precision[..., i1, i2]
                                    for k2, i2 in slices if k2 in a], dim=-1)
                         for k1, i1 in slices if k1 in a], dim=-2)
    prec_ab = torch.cat([torch.cat([precision[..., i1, i2]
                                    for k2, i2 in slices if k2 in b], dim=-1)
                         for k1, i1 in slices if k1 in a], dim=-2)
    prec_bb = torch.cat([torch.cat([precision[..., i1, i2]
                                    for k2, i2 in slices if k2 in b], dim=-1)
                         for k1, i1 in slices if k1 in b], dim=-2)
    info_a = torch.cat([info_vec[..., i] for k, i in slices if k in a], dim=-1)
    info_b = torch.cat([info_vec[..., i] for k, i in slices if k in b], dim=-1)
    value_b = torch.cat([values[k] for k, i in slices if k in b], dim=-1)
    info_vec = info_a - _mv(prec_ab, value_b)
    log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
    precision = prec_aa.expand(info_vec.shape + (-1,))
    inputs = int_inputs.copy()
    for k, d in self.inputs.items():
        if k not in real_subs:
            inputs[k] = d
    return Gaussian(info_vec, precision, inputs) + Tensor(log_scale, int_inputs)
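# The partial substitution above relies on a block decomposition of the
# quadratic form. The sketch below is a standalone numerical check assuming
# plain torch tensors; the helper name and symbols (i_a, P_ab, etc.) are
# hypothetical. Writing x = [x_a, x_b] and substituting x_b = v into
# <x, i> - 1/2 x'Px gives
#   <x_a, i_a - P_ab v> - 1/2 x_a' P_aa x_a + <v, i_b - 1/2 P_bb v>,
# i.e. a smaller Gaussian over x_a plus a constant log_scale term.
def _check_partial_real_subs():
    import torch
    na, nb = 3, 2
    i = torch.randn(na + nb)
    L = torch.randn(na + nb, na + nb)
    P = L @ L.T + torch.eye(na + nb)   # positive definite precision
    x_a, v = torch.randn(na), torch.randn(nb)
    x = torch.cat([x_a, v])

    full = x @ i - 0.5 * x @ P @ x

    i_a, i_b = i[:na], i[na:]
    P_aa, P_ab, P_bb = P[:na, :na], P[:na, na:], P[na:, na:]
    info_vec = i_a - P_ab @ v
    log_scale = v @ (i_b - 0.5 * P_bb @ v)
    reduced = x_a @ info_vec - 0.5 * x_a @ P_aa @ x_a + log_scale
    assert torch.allclose(full, reduced, atol=1e-5)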