def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
    # note this should handle transforms correctly via distribution_to_data
    raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
    for d, name in zip(range(len(sample_inputs), 0, -1), sample_inputs.keys()):
        dim_to_name[-d - len(raw_dist.batch_shape)] = name

    if value_name not in sampled_vars:
        return self

    sample_shape = tuple(v.size for v in sample_inputs.values())
    sample_args = (sample_shape,) if get_backend() == "torch" else (rng_key, sample_shape)
    if self.has_rsample:
        raw_value = raw_dist.rsample(*sample_args)
    else:
        raw_value = ops.detach(raw_dist.sample(*sample_args))

    funsor_value = to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)
    funsor_value = funsor_value.align(
        tuple(sample_inputs) + tuple(inp for inp in self.inputs if inp in funsor_value.inputs))
    result = funsor.delta.Delta(value_name, funsor_value)
    if not self.has_rsample:
        # scaling of dice_factor by num samples should already be handled by Funsor.sample
        raw_log_prob = raw_dist.log_prob(raw_value)
        dice_factor = to_funsor(raw_log_prob - ops.detach(raw_log_prob),
                                output=self.output, dim_to_name=dim_to_name)
        result = result + dice_factor
    return result
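
# Illustrative sketch (not library code; assumes the PyTorch backend): the
# dice factor added above, log_prob - detach(log_prob), is identically zero in
# value, so it leaves the Delta's weight unchanged, but it carries the
# score-function gradient that makes expectations of non-reparameterized
# samples differentiable w.r.t. the distribution's parameters.
import torch

logits = torch.randn(3, requires_grad=True)
d = torch.distributions.Categorical(logits=logits)
x = d.sample()                              # non-reparameterized draw
log_prob = d.log_prob(x)
dice_factor = log_prob - log_prob.detach()  # == 0.0 in value
assert dice_factor.item() == 0.0
dice_factor.backward()                      # ...but it still propagates gradients
assert logits.grad is not None
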
def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
    params = OrderedDict(self.params)
    value = params.pop("value")
    assert all(isinstance(v, (Number, Tensor)) for v in params.values())
    assert isinstance(value, Variable) and value.name in sampled_vars
    inputs_, tensors = align_tensors(*params.values())
    inputs = OrderedDict(sample_inputs.items())
    inputs.update(inputs_)
    sample_shape = tuple(v.size for v in sample_inputs.values())

    raw_dist = self.dist_class(**dict(zip(self._ast_fields[:-1], tensors)))
    sample_args = (sample_shape,) if rng_key is None else (rng_key, sample_shape)
    if getattr(raw_dist, "has_rsample", False):
        raw_sample = raw_dist.rsample(*sample_args)
    else:
        raw_sample = ops.detach(raw_dist.sample(*sample_args))

    result = funsor.delta.Delta(value.name, Tensor(raw_sample, inputs, value.output.dtype))
    if not getattr(raw_dist, "has_rsample", False):
        # scaling of dice_factor by num samples should already be handled by Funsor.sample
        raw_log_prob = raw_dist.log_prob(raw_sample)
        dice_factor = Tensor(raw_log_prob - ops.detach(raw_log_prob), inputs)
        result = result + dice_factor
    return result
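
# Illustrative sketch (an assumption, not part of the library): the
# backend-dependent `sample_args` convention above mirrors the raw
# distributions' call signatures. With NumPyro (the distribution library used
# on the JAX side) a PRNG key is passed first; with PyTorch only the sample
# shape is passed, and `rng_key` stays None.
import numpyro.distributions as ndist
from jax import random

rng_key = random.PRNGKey(0)
raw_dist = ndist.Normal(0.0, 1.0)
draw = raw_dist.sample(rng_key, (5,))   # the (rng_key, sample_shape) branch
assert draw.shape == (5,)
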
def einsum(equation, *operands):
    """
    Log-sum-exp implementation of einsum.
    """
    if get_backend() != "jax":
        # NB: rename symbols to support NumPy, which allows only symbols a-z.
        symbols = sorted(set(equation) - set(',->'))
        rename = dict(zip(symbols, 'abcdefghijklmnopqrstuvwxyz'))
        equation = ''.join(rename.get(s, s) for s in equation)

    inputs, output = equation.split('->')
    if inputs == output:
        return operands[0][...]  # create a new object
    inputs = inputs.split(',')

    shifts = []
    exp_operands = []
    for dims, operand in zip(inputs, operands):
        shift = ops.detach(operand)
        for i, dim in enumerate(dims):
            if dim not in output:
                shift = ops.amax(shift, i, keepdims=True)
        # avoid nan due to -inf - -inf
        shift = ops.clamp(shift, ops.finfo(shift).min, None)
        exp_operands.append(ops.exp(operand - shift))

        # permute shift to match output
        shift = shift.reshape(
            [size for size, dim in zip(operand.shape, dims) if dim in output])
        if len(shift.shape) > 0:
            shift = shift.reshape((1,) * (len(output) - shift.ndim) + shift.shape)
            dims = [dim for dim in dims if dim in output]
            dims = [dim for dim in output if dim not in dims] + dims
            shift = ops.permute(shift, [dims.index(dim) for dim in output])
        shifts.append(shift)

    result = ops.log(ops.einsum(equation, *exp_operands))
    return sum(shifts + [result])
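
# Illustrative sketch (plain NumPy, not library code) of the shift trick used
# above: subtracting a detached per-operand max before exponentiating keeps
# exp() finite, and the shifts are added back after taking the log, so the
# contraction is computed stably in log space.
import numpy as np

x = np.array([1000.0, 1001.0])                    # exp(x) would overflow
shift = x.max()
stable = np.log(np.exp(x - shift).sum()) + shift  # == logsumexp(x)
assert np.isfinite(stable)                        # ~1001.313
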
def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
    assert self.output == Real
    sampled_vars = sampled_vars.intersection(self.inputs)
    if not sampled_vars:
        return self

    # Partition inputs into sample_inputs + batch_inputs + event_inputs.
    sample_inputs = OrderedDict((k, d) for k, d in sample_inputs.items()
                                if k not in self.inputs)
    sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
    batch_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if k not in sampled_vars)
    event_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if k in sampled_vars)
    be_inputs = batch_inputs.copy()
    be_inputs.update(event_inputs)
    sb_inputs = sample_inputs.copy()
    sb_inputs.update(batch_inputs)

    # Sample all variables in a single Categorical call.
    logits = align_tensor(be_inputs, self)
    batch_shape = logits.shape[:len(batch_inputs)]
    flat_logits = logits.reshape(batch_shape + (-1,))
    sample_shape = tuple(d.dtype for d in sample_inputs.values())
    backend = get_backend()
    if backend != "numpy":
        from importlib import import_module
        dist = import_module(funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[backend])
        sample_args = (sample_shape,) if rng_key is None else (rng_key, sample_shape)
        flat_sample = dist.CategoricalLogits.dist_class(logits=flat_logits).sample(*sample_args)
    else:  # default numpy backend
        assert backend == "numpy"
        shape = sample_shape + flat_logits.shape[:-1]
        logit_max = np.amax(flat_logits, -1, keepdims=True)
        probs = np.exp(flat_logits - logit_max)
        probs = probs / np.sum(probs, -1, keepdims=True)
        s = np.cumsum(probs, -1)
        r = np.random.rand(*shape)
        flat_sample = np.sum(s < np.expand_dims(r, -1), axis=-1)
    assert flat_sample.shape == sample_shape + batch_shape

    results = []
    mod_sample = flat_sample
    for name, domain in reversed(list(event_inputs.items())):
        size = domain.dtype
        point = Tensor(mod_sample % size, sb_inputs, size)
        mod_sample = mod_sample // size
        results.append(Delta(name, point))

    # Account for the log normalizer factor.
    # Derivation: Let f be a nonnormalized distribution (a funsor), and
    #   consider operations in linear space (source code is in log space).
    #   Let x0 ~ f/|f| be a monte carlo sample from a normalized f/|f|.
    #                              f(x0) / |f|       # dice numerator
    #   Let g = delta(x=x0) |f| -------------------
    #                           detach(f(x0)/|f|)    # dice denominator
    #                       |detach(f)| f(x0)
    #         = delta(x=x0) -----------------  be a dice approximation of f.
    #                         detach(f(x0))
    #   Then g is an unbiased estimator of f in value and all derivatives.
    #   In the special case f = detach(f), we can simplify to
    #       g = delta(x=x0) |f|.
    if (backend == "torch" and flat_logits.requires_grad) or backend == "jax":
        # Apply a dice factor to preserve differentiability.
        index = [ops.new_arange(self.data, n).reshape((n,) + (1,) * (len(flat_logits.shape) - i - 2))
                 for i, n in enumerate(flat_logits.shape[:-1])]
        index.append(flat_sample)
        log_prob = flat_logits[tuple(index)]
        assert log_prob.shape == flat_sample.shape
        results.append(Tensor(ops.logsumexp(ops.detach(flat_logits), -1)
                              + (log_prob - ops.detach(log_prob)), sb_inputs))
    else:
        # This is the special case f = detach(f).
        results.append(Tensor(ops.logsumexp(flat_logits, -1), batch_inputs))

    return reduce(ops.add, results)
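
# Illustrative sketch (plain Python, not library code) of the decoding loop
# above: a single categorical draw over the flattened event space is unraveled
# back into one coordinate per event input by repeated modulo / integer
# division, starting from the innermost (last) event dimension. The sizes and
# the sampled index below are made up for the example.
event_sizes = {"i": 3, "j": 4, "k": 5}     # hypothetical event inputs
flat_sample = 37                           # index into the flattened 3*4*5 space
coords = {}
mod_sample = flat_sample
for name, size in reversed(list(event_sizes.items())):
    coords[name] = mod_sample % size
    mod_sample = mod_sample // size
assert coords == {"k": 2, "j": 3, "i": 1}  # since (1 * 4 + 3) * 5 + 2 == 37
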