def eager_getitem_tensor_tensor(op, lhs, rhs):
    assert op.offset < len(lhs.output.shape)
    assert rhs.output == Bint[lhs.output.shape[op.offset]]

    # Compute inputs and outputs.
    if lhs.inputs == rhs.inputs:
        inputs, lhs_data, rhs_data = lhs.inputs, lhs.data, rhs.data
    else:
        inputs, (lhs_data, rhs_data) = align_tensors(lhs, rhs)
    if len(lhs.output.shape) > 1:
        rhs_data = rhs_data.reshape(rhs_data.shape + (1,) * (len(lhs.output.shape) - 1))

    # Perform advanced indexing.
    lhs_data_dim = len(lhs_data.shape)
    target_dim = lhs_data_dim - len(lhs.output.shape) + op.offset
    index = [None] * lhs_data_dim
    for i in range(target_dim):
        index[i] = ops.new_arange(lhs_data, lhs_data.shape[i]).reshape(
            (-1,) + (1,) * (lhs_data_dim - i - 2))
    index[target_dim] = rhs_data
    for i in range(1 + target_dim, lhs_data_dim):
        index[i] = ops.new_arange(lhs_data, lhs_data.shape[i]).reshape(
            (-1,) + (1,) * (lhs_data_dim - i - 1))
    data = lhs_data[tuple(index)]
    return Tensor(data, inputs, lhs.dtype)
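
# Illustration (plain NumPy, not funsor API): the advanced-indexing pattern
# above selects lhs[..., rhs, ...] along `target_dim` while broadcasting over
# the batch dims on the left and the remaining output dims on the right.
# All shapes and names here are hypothetical.
import numpy as np

lhs_data = np.random.randn(2, 3, 4)            # batch dim 2, output shape (3, 4), offset 0
rhs_data = np.random.randint(0, 3, size=(2,))  # one index per batch element
index = [
    np.arange(2).reshape(-1, 1),               # batch arange, broadcast over trailing dims
    rhs_data.reshape(-1, 1),                   # chosen positions along the target dim
    np.arange(4).reshape(-1),                  # [:] slice for the remaining output dim
]
out = lhs_data[tuple(index)]
assert out.shape == (2, 4)
assert (out[1] == lhs_data[1, rhs_data[1]]).all()
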
def new_arange(self, name, *args, **kwargs):
    """
    Helper to create a named :func:`torch.arange` or :func:`np.arange` funsor.
    In some cases this can be replaced by a symbolic
    :class:`~funsor.terms.Slice` .

    :param str name: A variable name.
    :param int start:
    :param int stop:
    :param int step: Three args following :py:class:`slice` semantics.
    :param int dtype: An optional bounded integer type of this slice.
    :rtype: Tensor
    """
    start = 0
    step = 1
    dtype = None
    if len(args) == 1:
        stop = args[0]
        dtype = kwargs.pop("dtype", stop)
    elif len(args) == 2:
        start, stop = args
        dtype = kwargs.pop("dtype", stop)
    elif len(args) == 3:
        start, stop, step = args
        dtype = kwargs.pop("dtype", stop)
    elif len(args) == 4:
        start, stop, step, dtype = args
    else:
        raise ValueError
    if step <= 0:
        raise ValueError
    stop = min(dtype, max(start, stop))
    data = ops.new_arange(self.data, start, stop, step)
    inputs = OrderedDict([(name, Bint[len(data)])])
    return Tensor(data, inputs, dtype=dtype)
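
# Illustration (plain NumPy, not funsor API): the clamping rule above,
# stop = min(dtype, max(start, stop)), keeps every generated value inside the
# bounded integer domain Bint[dtype]. The numbers here are hypothetical.
import numpy as np

start, stop, step, dtype = 1, 10, 2, 6
stop = min(dtype, max(start, stop))   # 6: the dtype bound wins over the requested stop
data = np.arange(start, stop, step)   # array([1, 3, 5])
assert data.max() < dtype             # all values are valid Bint[dtype] elements
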
def test_sequential_sum_product_bias_2(num_steps, num_sensors, dim):
    time = Variable("time", bint(num_steps))
    bias = Variable("bias", reals(num_sensors, dim))
    bias_dist = random_gaussian(OrderedDict([
        ("bias", reals(num_sensors, dim)),
    ]))
    trans = random_gaussian(OrderedDict([
        ("time", bint(num_steps)),
        ("x_prev", reals(dim)),
        ("x_curr", reals(dim)),
    ]))
    obs = random_gaussian(OrderedDict([
        ("time", bint(num_steps)),
        ("x_curr", reals(dim)),
        ("bias", reals(dim)),
    ]))

    # Each time step only a single sensor observes x,
    # and each sensor has a different bias.
    sensor_id = Tensor(ops.new_arange(get_default_prototype(), num_steps) % 2,
                       OrderedDict(time=bint(num_steps)),
                       dtype=2)
    with interpretation(eager_or_die):
        factor = trans + obs(bias=bias[sensor_id]) + bias_dist
        assert set(factor.inputs) == {"time", "bias", "x_prev", "x_curr"}

        result = sequential_sum_product(ops.logaddexp, ops.add,
                                        factor, time, {"x_prev": "x_curr"})
        assert set(result.inputs) == {"bias", "x_prev", "x_curr"}
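
# Illustration (plain NumPy + SciPy, not funsor API): with (logaddexp, add),
# sequential_sum_product reduces the time dimension like a log-space matrix
# chain, leaving the global "bias" and the endpoint states free. Shapes and
# names here are hypothetical.
import numpy as np
from scipy.special import logsumexp

num_steps, dim = 4, 3
factor = np.random.randn(num_steps, dim, dim)   # factor[t, x_prev, x_curr]
result = factor[0]
for t in range(1, num_steps):
    # log-matmul: eliminate the state shared between steps t-1 and t
    result = logsumexp(result[:, :, None] + factor[t][None, :, :], axis=1)
# result[x_prev, x_curr] now marginalizes all intermediate states
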
def eager_reduce(self, op, reduced_vars):
    if op is ops.logaddexp:
        # Marginalize out real variables, but keep mixtures lazy.
        assert all(v in self.inputs for v in reduced_vars)
        real_vars = frozenset(k for k, d in self.inputs.items() if d.dtype == "real")
        reduced_reals = reduced_vars & real_vars
        reduced_ints = reduced_vars - real_vars
        if not reduced_reals:
            return None  # defer to default implementation

        inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                             if k not in reduced_reals)
        if reduced_reals == real_vars:
            result = self.log_normalizer
        else:
            int_inputs = OrderedDict((k, v) for k, v in inputs.items()
                                     if v.dtype != 'real')
            offsets, _ = _compute_offsets(self.inputs)
            a = []
            b = []
            for key, domain in self.inputs.items():
                if domain.dtype == 'real':
                    block = ops.new_arange(self.info_vec, offsets[key],
                                           offsets[key] + domain.num_elements, 1)
                    (b if key in reduced_vars else a).append(block)
            a = ops.cat(-1, *a)
            b = ops.cat(-1, *b)
            prec_aa = self.precision[..., a[..., None], a]
            prec_ba = self.precision[..., b[..., None], a]
            prec_bb = self.precision[..., b[..., None], b]
            prec_b = ops.cholesky(prec_bb)
            prec_a = ops.triangular_solve(prec_ba, prec_b)
            prec_at = ops.transpose(prec_a, -1, -2)
            precision = prec_aa - ops.matmul(prec_at, prec_a)

            info_a = self.info_vec[..., a]
            info_b = self.info_vec[..., b]
            b_tmp = ops.triangular_solve(info_b[..., None], prec_b)
            info_vec = info_a - ops.matmul(prec_at, b_tmp)[..., 0]

            log_prob = Tensor(0.5 * len(b) * math.log(2 * math.pi)
                              - _log_det_tri(prec_b)
                              + 0.5 * (b_tmp[..., 0] ** 2).sum(-1),
                              int_inputs)
            result = log_prob + Gaussian(info_vec, precision, inputs)

        return result.reduce(ops.logaddexp, reduced_ints)

    elif op is ops.add:
        for v in reduced_vars:
            if self.inputs[v].dtype == 'real':
                raise ValueError("Cannot sum along a real dimension: {}".format(repr(v)))

        # Fuse Gaussians along a plate. Compare to eager_add_gaussian_gaussian().
        old_ints = OrderedDict((k, v) for k, v in self.inputs.items()
                               if v.dtype != 'real')
        new_ints = OrderedDict((k, v) for k, v in old_ints.items()
                               if k not in reduced_vars)
        inputs = OrderedDict((k, v) for k, v in self.inputs.items()
                             if k not in reduced_vars)

        info_vec = Tensor(self.info_vec, old_ints).reduce(ops.add, reduced_vars)
        precision = Tensor(self.precision, old_ints).reduce(ops.add, reduced_vars)
        assert info_vec.inputs == new_ints
        assert precision.inputs == new_ints
        return Gaussian(info_vec.data, precision.data, inputs)

    return None  # defer to default implementation
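
# Illustration (plain NumPy, not funsor API): the block marginalization above
# in information form. Keeping block "a" and integrating out block "b" gives
# the Schur-complement precision, a corrected information vector, and a
# Gaussian-integral log normalizer. Sizes here are hypothetical.
import math
import numpy as np

d_a, d_b = 2, 3
A = np.random.randn(d_a + d_b, d_a + d_b)
P = A @ A.T + np.eye(d_a + d_b)           # joint precision, partitioned as [a; b]
info = np.random.randn(d_a + d_b)         # joint information vector
P_aa, P_ab = P[:d_a, :d_a], P[:d_a, d_a:]
P_ba, P_bb = P[d_a:, :d_a], P[d_a:, d_a:]
i_a, i_b = info[:d_a], info[d_a:]

P_bb_inv = np.linalg.inv(P_bb)
precision = P_aa - P_ab @ P_bb_inv @ P_ba                 # Schur complement
info_vec = i_a - P_ab @ P_bb_inv @ i_b
log_prob = (0.5 * d_b * math.log(2 * math.pi)
            - 0.5 * np.linalg.slogdet(P_bb)[1]
            + 0.5 * i_b @ P_bb_inv @ i_b)                 # constant picked up by the reduction
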
def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
    assert self.output == Real
    sampled_vars = sampled_vars.intersection(self.inputs)
    if not sampled_vars:
        return self

    # Partition inputs into sample_inputs + batch_inputs + event_inputs.
    sample_inputs = OrderedDict((k, d) for k, d in sample_inputs.items()
                                if k not in self.inputs)
    sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
    batch_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                               if k not in sampled_vars)
    event_inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                               if k in sampled_vars)
    be_inputs = batch_inputs.copy()
    be_inputs.update(event_inputs)
    sb_inputs = sample_inputs.copy()
    sb_inputs.update(batch_inputs)

    # Sample all variables in a single Categorical call.
    logits = align_tensor(be_inputs, self)
    batch_shape = logits.shape[:len(batch_inputs)]
    flat_logits = logits.reshape(batch_shape + (-1,))
    sample_shape = tuple(d.dtype for d in sample_inputs.values())
    backend = get_backend()
    if backend != "numpy":
        from importlib import import_module
        dist = import_module(funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[backend])
        sample_args = (sample_shape,) if rng_key is None else (rng_key, sample_shape)
        flat_sample = dist.CategoricalLogits.dist_class(logits=flat_logits).sample(*sample_args)
    else:  # default numpy backend
        assert backend == "numpy"
        shape = sample_shape + flat_logits.shape[:-1]
        logit_max = np.amax(flat_logits, -1, keepdims=True)
        probs = np.exp(flat_logits - logit_max)
        probs = probs / np.sum(probs, -1, keepdims=True)
        s = np.cumsum(probs, -1)
        r = np.random.rand(*shape)
        flat_sample = np.sum(s < np.expand_dims(r, -1), axis=-1)

    assert flat_sample.shape == sample_shape + batch_shape
    results = []
    mod_sample = flat_sample
    for name, domain in reversed(list(event_inputs.items())):
        size = domain.dtype
        point = Tensor(mod_sample % size, sb_inputs, size)
        mod_sample = mod_sample // size
        results.append(Delta(name, point))

    # Account for the log normalizer factor.
    # Derivation: Let f be a nonnormalized distribution (a funsor), and
    #   consider operations in linear space (source code is in log space).
    #   Let x0 ~ f/|f| be a monte carlo sample from a normalized f/|f|.
    #                              f(x0) / |f|       # dice numerator
    #   Let g = delta(x=x0) |f| -----------------
    #                           detach(f(x0)/|f|)    # dice denominator
    #                       |detach(f)| f(x0)
    #         = delta(x=x0) ----------------- be a dice approximation of f.
    #                         detach(f(x0))
    #   Then g is an unbiased estimator of f in value and all derivatives.
    #   In the special case f = detach(f), we can simplify to
    #       g = delta(x=x0) |f|.
    if (backend == "torch" and flat_logits.requires_grad) or backend == "jax":
        # Apply a dice factor to preserve differentiability.
        index = [ops.new_arange(self.data, n).reshape((n,) + (1,) * (len(flat_logits.shape) - i - 2))
                 for i, n in enumerate(flat_logits.shape[:-1])]
        index.append(flat_sample)
        log_prob = flat_logits[tuple(index)]
        assert log_prob.shape == flat_sample.shape
        results.append(Tensor(ops.logsumexp(ops.detach(flat_logits), -1)
                              + (log_prob - ops.detach(log_prob)),
                              sb_inputs))
    else:
        # This is the special case f = detach(f).
        results.append(Tensor(ops.logsumexp(flat_logits, -1), batch_inputs))

    return reduce(ops.add, results)
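
# Illustration (PyTorch, not funsor API): the "dice factor" used above. Its
# value is the log normalizer logsumexp(logits), but its gradient is that of
# the sampled logit log_prob(x0), so downstream expectations stay
# differentiable. Sizes here are hypothetical.
import torch

logits = torch.randn(6, requires_grad=True)
x0 = torch.distributions.Categorical(logits=logits).sample()
log_prob = logits[x0]                                              # unnormalized, as above
dice = logits.detach().logsumexp(-1) + (log_prob - log_prob.detach())
assert torch.allclose(dice, logits.logsumexp(-1))                  # value: log normalizer
dice.backward()
assert torch.allclose(logits.grad, torch.nn.functional.one_hot(x0, 6).float())
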
def eager_subs(self, subs):
    assert isinstance(subs, tuple)
    subs = OrderedDict((k, to_funsor(v, self.inputs[k]))
                       for k, v in subs if k in self.inputs)
    if not subs:
        return self

    # Handle diagonal variable substitution
    var_counts = Counter(v for v in subs.values() if isinstance(v, Variable))
    subs = OrderedDict((k, self.materialize(v) if var_counts[v] > 1 else v)
                       for k, v in subs.items())

    # Handle renaming to enable cons hashing, and
    # handle slicing to avoid copying data.
    if any(isinstance(v, (Variable, Slice)) for v in subs.values()):
        slices = None
        inputs = OrderedDict()
        for i, (k, d) in enumerate(self.inputs.items()):
            if k in subs:
                v = subs[k]
                if isinstance(v, Variable):
                    del subs[k]
                    k = v.name
                elif isinstance(v, Slice):
                    del subs[k]
                    k = v.name
                    d = v.inputs[v.name]
                    if slices is None:
                        slices = [slice(None)] * len(self.data.shape)
                    slices[i] = v.slice
            inputs[k] = d
        data = self.data[tuple(slices)] if slices else self.data
        result = Tensor(data, inputs, self.dtype)
        return result.eager_subs(tuple(subs.items()))

    # materialize after checking for renaming case
    subs = OrderedDict((k, self.materialize(v)) for k, v in subs.items())

    # Compute result shapes.
    inputs = OrderedDict()
    for k, domain in self.inputs.items():
        if k in subs:
            inputs.update(subs[k].inputs)
        else:
            inputs[k] = domain

    # Construct a dict with each input's positional dim,
    # counting from the right so as to support broadcasting.
    total_size = len(inputs) + len(self.output.shape)  # Assumes only scalar indices.
    new_dims = {}
    for k, domain in inputs.items():
        assert not domain.shape
        new_dims[k] = len(new_dims) - total_size

    # Use advanced indexing to construct a simultaneous substitution.
    index = []
    for k, domain in self.inputs.items():
        if k in subs:
            v = subs.get(k)
            if isinstance(v, Number):
                index.append(int(v.data))
            else:
                # Permute and expand v.data to end up at new_dims.
                assert isinstance(v, Tensor)
                v = v.align(tuple(k2 for k2 in inputs if k2 in v.inputs))
                assert isinstance(v, Tensor)
                v_shape = [1] * total_size
                for k2, size in zip(v.inputs, v.data.shape):
                    v_shape[new_dims[k2]] = size
                index.append(v.data.reshape(tuple(v_shape)))
        else:
            # Construct a [:] slice for this preserved input.
            offset_from_right = -1 - new_dims[k]
            index.append(ops.new_arange(self.data, domain.dtype).reshape(
                (-1,) + (1,) * offset_from_right))

    # Construct a [:] slice for the output.
    for i, size in enumerate(self.output.shape):
        offset_from_right = len(self.output.shape) - i - 1
        index.append(ops.new_arange(self.data, size).reshape(
            (-1,) + (1,) * offset_from_right))

    data = self.data[tuple(index)]
    return Tensor(data, inputs, self.dtype)
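
# Illustration (plain NumPy, not funsor API): the simultaneous-substitution
# indexing above. data has inputs (i: 3, j: 4); we substitute i with an
# integer tensor indexed by a new input k: 2, and keep j. Names and shapes
# here are hypothetical.
import numpy as np

data = np.random.randn(3, 4)
i_of_k = np.array([2, 0])                  # substitution i(k), k in {0, 1}
index_i = i_of_k.reshape(2, 1)             # placed at the new positional dim for k
index_j = np.arange(4).reshape(-1)         # [:] slice for the preserved input j
out = data[index_i, index_j]
assert out.shape == (2, 4)                 # new inputs are (k, j)
assert (out[1] == data[0]).all()
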
def _scatter(src, res, subs):
    # inverse of advanced indexing
    # TODO check types of subs, in case some logic from eager_subs was accidentally left out?

    # use advanced indexing logic copied from Tensor.eager_subs:

    # materialize after checking for renaming case
    subs = OrderedDict((k, res.materialize(v)) for k, v in subs)

    # Compute result shapes.
    inputs = OrderedDict()
    for k, domain in res.inputs.items():
        inputs[k] = domain

    # Construct a dict with each input's positional dim,
    # counting from the right so as to support broadcasting.
    total_size = len(inputs) + len(res.output.shape)  # Assumes only scalar indices.
    new_dims = {}
    for k, domain in inputs.items():
        assert not domain.shape
        new_dims[k] = len(new_dims) - total_size

    # Use advanced indexing to construct a simultaneous substitution.
    index = []
    for k, domain in res.inputs.items():
        if k in subs:
            v = subs.get(k)
            if isinstance(v, Number):
                index.append(int(v.data))
            else:
                # Permute and expand v.data to end up at new_dims.
                assert isinstance(v, Tensor)
                v = v.align(tuple(k2 for k2 in inputs if k2 in v.inputs))
                assert isinstance(v, Tensor)
                v_shape = [1] * total_size
                for k2, size in zip(v.inputs, v.data.shape):
                    v_shape[new_dims[k2]] = size
                index.append(v.data.reshape(tuple(v_shape)))
        else:
            # Construct a [:] slice for this preserved input.
            offset_from_right = -1 - new_dims[k]
            index.append(ops.new_arange(res.data, domain.dtype).reshape(
                (-1,) + (1,) * offset_from_right))

    # Construct a [:] slice for the output.
    for i, size in enumerate(res.output.shape):
        offset_from_right = len(res.output.shape) - i - 1
        index.append(ops.new_arange(res.data, size).reshape(
            (-1,) + (1,) * offset_from_right))

    # the only difference from Tensor.eager_subs is here:
    # instead of indexing the rhs (lhs = rhs[index]), we index the lhs (lhs[index] = rhs)

    # unsqueeze to make broadcasting work
    src_inputs, src_data = src.inputs.copy(), src.data
    for k, v in res.inputs.items():
        if k not in src.inputs and isinstance(subs[k], Number):
            src_inputs[k] = bint(1)
            src_data = src_data.unsqueeze(-1 - len(src.output.shape))
    src = Tensor(src_data, src_inputs, src.output.dtype).align(tuple(res.inputs.keys()))

    data = res.data
    data[tuple(index)] = src.data
    return Tensor(data, inputs, res.dtype)
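
# Illustration (plain NumPy, not funsor API): _scatter writes through the same
# index that eager_subs would read through, i.e. res[index] = src instead of
# src = res[index]. Names and shapes here are hypothetical.
import numpy as np

res = np.zeros((3, 4))
src = np.random.randn(2, 4)                # inputs (k: 2, j: 4)
i_of_k = np.array([2, 0])
index_i = i_of_k.reshape(2, 1)
index_j = np.arange(4).reshape(-1)
res[index_i, index_j] = src                # inverse of the eager_subs read above
assert (res[2] == src[0]).all() and (res[0] == src[1]).all()
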