def eager_subs(self, subs):
    assert isinstance(subs, tuple)
    subs = tuple((k, v if isinstance(v, (Variable, Slice)) else materialize(v))
                 for k, v in subs if k in self.inputs)
    if not subs:
        return self

    # Constants and Affine funsors are eagerly substituted;
    # everything else is lazily substituted.
    lazy_subs = tuple((k, v) for k, v in subs
                      if not isinstance(v, (Number, Tensor, Variable, Slice))
                      and not (is_affine(v) and affine_inputs(v)))
    var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
    int_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor, Slice))
                     if v.dtype != 'real')
    real_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor))
                      if v.dtype == 'real')
    affine_subs = tuple((k, v) for k, v in subs
                        if is_affine(v) and affine_inputs(v) and not isinstance(v, Variable))
    if var_subs:
        return self._eager_subs_var(var_subs, int_subs + real_subs + affine_subs + lazy_subs)
    if int_subs:
        return self._eager_subs_int(int_subs, real_subs + affine_subs + lazy_subs)
    if real_subs:
        return self._eager_subs_real(real_subs, affine_subs + lazy_subs)
    if affine_subs:
        return self._eager_subs_affine(affine_subs, lazy_subs)
    return reflect(Subs, self, lazy_subs)
def test_gaussian_mixture_distribution(batch_inputs, event_inputs):
    num_samples = 100000
    sample_inputs = OrderedDict(particle=bint(num_samples))
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    int_inputs = OrderedDict((k, d) for k, d in be_inputs.items() if d.dtype != 'real')
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    sampled_vars = frozenset(['f'])

    p = random_gaussian(be_inputs) + 0.5 * random_tensor(int_inputs)
    p_marginal = p.reduce(ops.logaddexp, 'e')
    assert isinstance(p_marginal, Tensor)

    q = p.sample(sampled_vars, sample_inputs)
    q_marginal = q.reduce(ops.logaddexp, 'e')
    q_marginal = materialize(q_marginal).reduce(ops.logaddexp, 'particle')
    assert isinstance(q_marginal, Tensor)

    q_marginal = q_marginal.align(tuple(p_marginal.inputs))
    assert_close(q_marginal, p_marginal, atol=0.1, rtol=None)
def test_tensor_distribution(event_inputs, batch_inputs, test_grad):
    num_samples = 50000
    sample_inputs = OrderedDict(n=bint(num_samples))
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    sampled_vars = frozenset(event_inputs)
    p = random_tensor(be_inputs)
    p.data.requires_grad_(test_grad)

    q = p.sample(sampled_vars, sample_inputs)
    mq = materialize(q).reduce(ops.logaddexp, 'n')
    mq = mq.align(tuple(p.inputs))
    assert_close(mq, p, atol=0.1, rtol=None)

    if test_grad:
        _, (p_data, mq_data) = align_tensors(p, mq)
        assert p_data.shape == mq_data.shape
        probe = torch.randn(p_data.shape)
        expected = grad((p_data.exp() * probe).sum(), [p.data])[0]
        actual = grad((mq_data.exp() * probe).sum(), [p.data])[0]
        assert_close(actual, expected, atol=0.1, rtol=None)
def eager_categorical(probs, value):
    value = materialize(value)
    return Categorical.eager_log_prob(probs=probs, value=value)
def _scatter(src, res, subs):
    # inverse of advanced indexing
    # TODO check types of subs, in case some logic from eager_subs was accidentally left out?

    # use advanced indexing logic copied from Tensor.eager_subs:

    # materialize after checking for renaming case
    subs = OrderedDict((k, materialize(v)) for k, v in subs)

    # Compute result shapes.
    inputs = OrderedDict()
    for k, domain in res.inputs.items():
        inputs[k] = domain

    # Construct a dict with each input's positional dim,
    # counting from the right so as to support broadcasting.
    total_size = len(inputs) + len(res.output.shape)  # Assumes only scalar indices.
    new_dims = {}
    for k, domain in inputs.items():
        assert not domain.shape
        new_dims[k] = len(new_dims) - total_size

    # Use advanced indexing to construct a simultaneous substitution.
    index = []
    for k, domain in res.inputs.items():
        if k in subs:
            v = subs.get(k)
            if isinstance(v, Number):
                index.append(int(v.data))
            else:
                # Permute and expand v.data to end up at new_dims.
                assert isinstance(v, Tensor)
                v = v.align(tuple(k2 for k2 in inputs if k2 in v.inputs))
                assert isinstance(v, Tensor)
                v_shape = [1] * total_size
                for k2, size in zip(v.inputs, v.data.shape):
                    v_shape[new_dims[k2]] = size
                index.append(v.data.reshape(tuple(v_shape)))
        else:
            # Construct a [:] slice for this preserved input.
            offset_from_right = -1 - new_dims[k]
            index.append(torch.arange(domain.dtype).reshape(
                (-1,) + (1,) * offset_from_right))

    # Construct a [:] slice for the output.
    for i, size in enumerate(res.output.shape):
        offset_from_right = len(res.output.shape) - i - 1
        index.append(torch.arange(size).reshape(
            (-1,) + (1,) * offset_from_right))

    # the only difference from Tensor.eager_subs is here:
    # instead of indexing the rhs (lhs = rhs[index]), we index the lhs (lhs[index] = rhs)

    # unsqueeze to make broadcasting work
    src_inputs, src_data = src.inputs.copy(), src.data
    for k, v in res.inputs.items():
        if k not in src.inputs and isinstance(subs[k], Number):
            src_inputs[k] = bint(1)
            src_data = src_data.unsqueeze(-1 - len(src.output.shape))
    src = Tensor(src_data, src_inputs, src.output.dtype).align(tuple(res.inputs.keys()))

    data = res.data
    data[tuple(index)] = src.data
    return Tensor(data, inputs, res.dtype)
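# A minimal standalone sketch (plain torch only; the helper name and values below
# are illustrative, not library code) of the "inverse of advanced indexing" idea
# used by _scatter above: Tensor.eager_subs gathers by indexing the right-hand
# side (dst = src[index]), while _scatter writes through the same kind of
# broadcastable index tuple on the left-hand side (dst[index] = src).
def _check_scatter_gather_roundtrip():
    import torch

    src = torch.tensor([10., 20., 30.])
    dst = torch.zeros(5)
    index = (torch.tensor([0, 2, 4]),)

    dst[index] = src                      # scatter: advanced indexing on the lhs
    assert torch.equal(dst[index], src)   # gather: the same index recovers src
    assert torch.equal(dst, torch.tensor([10., 0., 20., 0., 30.]))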
def eager_subs(self, subs):
    assert isinstance(subs, tuple)
    subs = tuple((k, materialize(to_funsor(v, self.inputs[k])))
                 for k, v in subs if k in self.inputs)
    if not subs:
        return self

    # Constants and Variables are eagerly substituted;
    # everything else is lazily substituted.
    lazy_subs = tuple((k, v) for k, v in subs
                      if not isinstance(v, (Number, Tensor, Variable)))
    var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
    int_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor))
                     if v.dtype != 'real')
    real_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor))
                      if v.dtype == 'real')
    if not (var_subs or int_subs or real_subs):
        return reflect(Subs, self, lazy_subs)

    # First perform any variable substitutions.
    if var_subs:
        rename = {k: v.name for k, v in var_subs}
        inputs = OrderedDict((rename.get(k, k), d) for k, d in self.inputs.items())
        if len(inputs) != len(self.inputs):
            raise ValueError("Variable substitution name conflict")
        var_result = Gaussian(self.loc, self.precision, inputs)
        return Subs(var_result, int_subs + real_subs + lazy_subs)

    # Next perform any integer substitution, i.e. slicing into a batch.
    if int_subs:
        int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real')
        tensors = [self.loc, self.precision]
        funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors]
        inputs = funsors[0].inputs.copy()
        inputs.update(real_inputs)
        int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
        return Subs(int_result, real_subs + lazy_subs)

    # Try to perform a complete substitution of all real variables,
    # resulting in a Tensor.
    real_subs = OrderedDict(subs)
    assert real_subs and not int_subs
    if all(k in real_subs for k, d in self.inputs.items() if d.dtype == 'real'):
        # Broadcast all component tensors.
        int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        tensors = [Tensor(self.loc, int_inputs),
                   Tensor(self.precision, int_inputs)]
        tensors.extend(real_subs.values())
        inputs, tensors = align_tensors(*tensors)
        batch_dim = tensors[0].dim() - 1
        batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
        (loc, precision), values = tensors[:2], tensors[2:]

        # Form the concatenated value.
        offsets, event_size = _compute_offsets(self.inputs)
        value = BlockVector(batch_shape + (event_size,))
        for k, value_k in zip(real_subs, values):
            offset = offsets[k]
            value_k = value_k.reshape(value_k.shape[:batch_dim] + (-1,))
            if not torch._C._get_tracing_state():
                assert value_k.size(-1) == self.inputs[k].num_elements
            value_k = value_k.expand(batch_shape + value_k.shape[-1:])
            value[..., offset:offset + self.inputs[k].num_elements] = value_k
        value = value.as_tensor()

        # Evaluate the non-normalized log density.
        result = -0.5 * _vmv(precision, value - loc)
        result = Tensor(result, inputs)
        assert result.output == reals()
        return Subs(result, lazy_subs)

    # Perform a partial substitution of a subset of real variables,
    # resulting in a Joint.
    # See "The Matrix Cookbook" (November 15, 2012) ss. 8.1.3 eq. 353.
    # http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
    raise NotImplementedError('TODO implement partial substitution of real variables')
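# A minimal standalone sanity check (plain torch; the helper name and values are
# assumptions for illustration, not library code) that the expression
# -0.5 * _vmv(precision, value - loc) used above is the multivariate normal log
# density up to a value-independent normalizer: it equals
# MultivariateNormal(loc, covariance_matrix=Sigma).log_prob(value) plus
# 0.5 * logdet(2 * pi * Sigma), where Sigma = precision.inverse().
def _check_unnormalized_gaussian_log_density():
    import math
    import torch

    torch.manual_seed(0)
    loc = torch.randn(3)
    root = torch.randn(3, 3)
    precision = root @ root.t() + 3 * torch.eye(3)   # random positive definite
    value = torch.randn(3)

    diff = value - loc
    unnormalized = -0.5 * diff @ precision @ diff     # what _vmv computes batchwise

    cov = precision.inverse()
    mvn = torch.distributions.MultivariateNormal(loc, covariance_matrix=cov)
    log_normalizer = 0.5 * torch.logdet(2 * math.pi * cov)
    assert torch.allclose(unnormalized, mvn.log_prob(value) + log_normalizer, atol=1e-4)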
def eager_subs(self, subs):
    assert isinstance(subs, tuple)
    subs = tuple((k, v if isinstance(v, (Variable, Slice)) else materialize(v))
                 for k, v in subs if k in self.inputs)
    if not subs:
        return self

    # Constants and Variables are eagerly substituted;
    # everything else is lazily substituted.
    lazy_subs = tuple((k, v) for k, v in subs
                      if not isinstance(v, (Number, Tensor, Variable, Slice)))
    var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
    int_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor, Slice))
                     if v.dtype != 'real')
    real_subs = tuple((k, v) for k, v in subs if isinstance(v, (Number, Tensor))
                      if v.dtype == 'real')
    if not (var_subs or int_subs or real_subs):
        return reflect(Subs, self, lazy_subs)

    # First perform any variable substitutions.
    if var_subs:
        rename = {k: v.name for k, v in var_subs}
        inputs = OrderedDict((rename.get(k, k), d) for k, d in self.inputs.items())
        if len(inputs) != len(self.inputs):
            raise ValueError("Variable substitution name conflict")
        var_result = Gaussian(self.info_vec, self.precision, inputs)
        return Subs(var_result, int_subs + real_subs + lazy_subs)

    # Next perform any integer substitution, i.e. slicing into a batch.
    if int_subs:
        int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real')
        tensors = [self.info_vec, self.precision]
        funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors]
        inputs = funsors[0].inputs.copy()
        inputs.update(real_inputs)
        int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
        return Subs(int_result, real_subs + lazy_subs)

    # Broadcast all component tensors.
    real_subs = OrderedDict(subs)
    assert real_subs and not int_subs
    int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
    tensors = [Tensor(self.info_vec, int_inputs),
               Tensor(self.precision, int_inputs)]
    tensors.extend(real_subs.values())
    int_inputs, tensors = align_tensors(*tensors)
    batch_dim = tensors[0].dim() - 1
    batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
    (info_vec, precision), values = tensors[:2], tensors[2:]
    offsets, event_size = _compute_offsets(self.inputs)
    slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
              for k, offset in offsets.items()]

    # Expand all substituted values.
    values = OrderedDict(zip(real_subs, values))
    for k, value in values.items():
        value = value.reshape(value.shape[:batch_dim] + (-1,))
        if not torch._C._get_tracing_state():
            assert value.size(-1) == self.inputs[k].num_elements
        values[k] = value.expand(batch_shape + value.shape[-1:])

    # Try to perform a complete substitution of all real variables,
    # resulting in a Tensor.
    if all(k in real_subs for k, d in self.inputs.items() if d.dtype == 'real'):
        # Form the concatenated value.
        value = BlockVector(batch_shape + (event_size,))
        for k, i in slices:
            if k in values:
                value[..., i] = values[k]
        value = value.as_tensor()

        # Evaluate the non-normalized log density.
        result = _vv(value, info_vec - 0.5 * _mv(precision, value))
        result = Tensor(result, int_inputs)
        assert result.output == reals()
        return Subs(result, lazy_subs)

    # Perform a partial substitution of a subset of real variables,
    # resulting in a Joint. We split the real inputs into two sets:
    # "a" for the preserved variables and "b" for the substituted ones.
    b = frozenset(k for k, v in real_subs.items())
    a = frozenset(k for k, d in self.inputs.items() if d.dtype == 'real' and k not in b)
    prec_aa = torch.cat([torch.cat([precision[..., i1, i2]
                                    for k2, i2 in slices if k2 in a], dim=-1)
                         for k1, i1 in slices if k1 in a], dim=-2)
    prec_ab = torch.cat([torch.cat([precision[..., i1, i2]
                                    for k2, i2 in slices if k2 in b], dim=-1)
                         for k1, i1 in slices if k1 in a], dim=-2)
    prec_bb = torch.cat([torch.cat([precision[..., i1, i2]
                                    for k2, i2 in slices if k2 in b], dim=-1)
                         for k1, i1 in slices if k1 in b], dim=-2)
    info_a = torch.cat([info_vec[..., i] for k, i in slices if k in a], dim=-1)
    info_b = torch.cat([info_vec[..., i] for k, i in slices if k in b], dim=-1)
    value_b = torch.cat([values[k] for k, i in slices if k in b], dim=-1)
    info_vec = info_a - _mv(prec_ab, value_b)
    log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
    precision = prec_aa.expand(info_vec.shape + (-1,))
    inputs = int_inputs.copy()
    for k, d in self.inputs.items():
        if k not in real_subs:
            inputs[k] = d
    return Gaussian(info_vec, precision, inputs) + Tensor(log_scale, int_inputs)
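# A minimal standalone check (plain torch; all names here are illustrative, not
# library code) of the partial-substitution identity used above: for an
# information-form log density x . info - 0.5 * x . P . x, fixing the "b" block
# to a value v leaves a Gaussian in the "a" block with info vector
# info_a - P_ab @ v and precision P_aa, plus the constant
# log_scale = v . (info_b - 0.5 * P_bb @ v).
def _check_partial_substitution_identity():
    import torch

    torch.manual_seed(0)
    root = torch.randn(4, 4)
    precision = root @ root.t() + 4 * torch.eye(4)   # random positive definite
    info_vec = torch.randn(4)
    x_a, value_b = torch.randn(2), torch.randn(2)
    x = torch.cat([x_a, value_b])

    full = x @ info_vec - 0.5 * x @ precision @ x

    prec_aa, prec_ab, prec_bb = precision[:2, :2], precision[:2, 2:], precision[2:, 2:]
    info_a, info_b = info_vec[:2], info_vec[2:]
    conditioned = x_a @ (info_a - prec_ab @ value_b) - 0.5 * x_a @ prec_aa @ x_a
    log_scale = value_b @ (info_b - 0.5 * prec_bb @ value_b)
    assert torch.allclose(full, conditioned + log_scale, atol=1e-4)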