def broadcast_all(*values, **kwargs):
    """
    Broadcast packed tensors against one another by dim name.

    Keyword arguments:

    - ``inputs``: one entry per value, naming that value's dims in order
      (one name per dim, e.g. a string of single characters). It must be
      provided; it is zipped against ``values``.
    - ``dims``: optional target order of dim names. Defaults to the sorted
      union of all names joined into a string; if given, it must contain
      exactly the same set of names.

    Each value whose names differ from ``dims`` is permuted into ``dims``
    order, reshaped with size-1 placeholders for names it lacks, and
    expanded to the common shape. Values whose names already equal ``dims``
    are returned untouched.

    When a name occurs in several values, its target size is taken from the
    last value that has it (sizes are presumably consistent across values).

    :return: a tuple of the (possibly) broadcast values.
    """
    inputs = kwargs.get('inputs')
    dims = kwargs.get('dims')

    # Record the size of every named dim across all values.
    sizes = {}
    for value, value_dims in zip(values, inputs):
        for name, size in zip(value_dims, value.shape):
            sizes[name] = size

    if dims is not None:
        assert set(dims) == set(sizes)
    else:
        dims = ''.join(sorted(sizes))
    target_shape = tuple(sizes[name] for name in dims)

    result = []
    for x, value_dims in zip(values, inputs):
        if value_dims != dims:
            present = [name for name in dims if name in value_dims]
            x = ops.permute(
                x, tuple(value_dims.index(name) for name in present))
            x = x.reshape(
                tuple(sizes[name] if name in value_dims else 1
                      for name in dims))
            x = ops.expand(x, target_shape)
            assert len(x.shape) == len(dims)
        result.append(x)
    # zip() stops at the shorter sequence; values beyond the length of
    # ``inputs`` are returned unchanged.
    result.extend(values[len(result):])
    return tuple(result)
def tensor_to_data(x, name_to_dim=None):
    """
    Convert a Tensor funsor into raw backend data.

    :param x: a Tensor funsor (reads ``x.data``, ``x.inputs``, ``x.output``).
    :param dict name_to_dim: optional map from input name to a negative
        batch dim; every input name of ``x`` must be a key. When it is
        empty/None, or ``x`` has no inputs, ``x.data`` is returned as-is.
    :return: ``x.data``, or that data rearranged so each input dim sits at
        the batch position given by ``name_to_dim`` (unused batch positions
        get size 1), followed by the output shape.
    :raises ValueError: if ``x`` has inputs but no ``name_to_dim`` is given.
    """
    if name_to_dim and x.inputs:
        assert all(
            isinstance(name, str) and isinstance(dim, int) and dim < 0
            for name, dim in name_to_dim.items())
        event_shape = x.output.shape
        # Logic very similar to pyro.ops.packed.unpack.
        # Collapse each input domain into a single dimension.
        squashed_shape = tuple(domain.dtype for domain in x.inputs.values())
        data = x.data.reshape(squashed_shape + event_shape)
        # Order the input dims by their target batch position; the trailing
        # output dims keep their order.
        target_dims = [name_to_dim[name] for name in x.inputs]
        sorted_dims = sorted(target_dims)
        num_batch = len(sorted_dims)
        perm = [target_dims.index(dim) for dim in sorted_dims]
        perm.extend(range(num_batch, num_batch + len(event_shape)))
        data = ops.permute(data, perm)
        # Place each input dim at its negative index; other positions are 1.
        batch_shape = [1] * -min(sorted_dims)
        for dim, size in zip(sorted_dims, data.shape):
            batch_shape[dim] = size
        return data.reshape(tuple(batch_shape) + event_shape)

    if x.inputs:
        raise ValueError(
            "cannot convert Tensor to data due to lazy inputs: {}".format(
                set(x.inputs)))
    return x.data
def align(self, names):
    """
    Reorder this tensor's inputs so that ``names`` come first.

    :param tuple names: input names to move to the front, in the desired
        order; each must already be an input of ``self``. Remaining inputs
        keep their relative order after them.
    :return: ``self`` if ``names`` is empty or already equals the full input
        order, otherwise a new ``Tensor`` with correspondingly permuted data.
    """
    assert isinstance(names, tuple)
    assert all(name in self.inputs for name in names)
    if not names or names == tuple(self.inputs):
        return self

    # OrderedDict.update keeps already-present keys in place, so ``names``
    # stay in front and the remaining inputs follow in their old order.
    new_inputs = OrderedDict((name, self.inputs[name]) for name in names)
    new_inputs.update(self.inputs)

    old_position = {name: i for i, name in enumerate(self.inputs)}
    batch_perm = tuple(old_position[name] for name in new_inputs)
    first_event = len(batch_perm)
    event_perm = tuple(
        range(first_event, first_event + len(self.output.shape)))
    return Tensor(ops.permute(self.data, batch_perm + event_perm),
                  new_inputs, self.dtype)
def eager_getitem_tensor_variable(op, lhs, rhs):
    """
    Index an output dim of the Tensor ``lhs`` by ``rhs``, making it an input.

    The output dim at position ``op.offset`` of ``lhs`` is moved to sit just
    after the existing input dims, and becomes a new input named
    ``rhs.name`` with domain ``rhs.output`` (which must be the matching
    ``Bint`` size). ``rhs.name`` must not already be an input of ``lhs``.
    """
    assert op.offset < len(lhs.output.shape)
    assert rhs.output == Bint[lhs.output.shape[op.offset]]
    assert rhs.name not in lhs.inputs

    # Convert a positional event dimension to a named batch dimension.
    new_inputs = lhs.inputs.copy()
    new_inputs[rhs.name] = rhs.output

    data = lhs.data
    target_dim = len(lhs.inputs)
    source_dim = target_dim + op.offset
    if source_dim != target_dim:
        # Move axis ``source_dim`` to position ``target_dim``; all other
        # axes keep their relative order.
        perm = list(range(len(data.shape)))
        perm.pop(source_dim)
        perm.insert(target_dim, source_dim)
        data = ops.permute(data, perm)
    return Tensor(data, new_inputs, lhs.dtype)
def align_tensor(new_inputs, x, expand=False):
    r"""
    Permute and add dims to a tensor to match desired ``new_inputs``.

    :param OrderedDict new_inputs: A target set of inputs. Every input of
        ``x`` must also appear in ``new_inputs``.
    :param funsor.terms.Funsor x: A :class:`Tensor` or
        :class:`~funsor.terms.Number` .
    :param bool expand: If False (default), set result size to 1 for any
        input in ``new_inputs`` that is not an input of ``x``; if True,
        expand those dims to their ``new_inputs`` size.
    :return: a number or :class:`torch.Tensor` or :class:`np.ndarray` that
        can be broadcast to other tensors with inputs ``new_inputs``.
    :rtype: int or float or torch.Tensor or np.ndarray
    :raises AssertionError: if an input of ``x`` is missing from
        ``new_inputs``.
    """
    assert isinstance(new_inputs, OrderedDict)
    assert isinstance(x, (Number, Tensor))
    assert all(isinstance(d.dtype, int) for d in x.inputs.values())

    data = x.data
    if isinstance(x, Number):
        return data

    old_inputs = x.inputs
    if old_inputs == new_inputs:
        return data

    # Every input of x needs a destination position; otherwise the
    # permutation below would be shorter than data's rank.
    missing = [k for k in old_inputs if k not in new_inputs]
    assert not missing, "inputs of x not in new_inputs: {}".format(missing)

    # Permute squashed input dims; trailing output dims keep their order.
    position = {k: i for i, k in enumerate(old_inputs)}
    data = ops.permute(
        data,
        tuple(position[k] for k in new_inputs if k in old_inputs)
        + tuple(range(len(old_inputs), len(data.shape))))

    # Unsquash multivariate input dims by filling in ones.
    data = data.reshape(
        tuple(old_inputs[k].dtype if k in old_inputs else 1
              for k in new_inputs)
        + x.output.shape)

    # Optionally expand new dims.
    if expand:
        data = ops.expand(
            data,
            tuple(d.dtype for d in new_inputs.values()) + x.output.shape)
    return data
def _rename_einsum_symbols(equation):
    """
    Rename the subscripts of an einsum ``equation`` onto ``a-z`` then ``A-Z``.

    NumPy's and PyTorch's einsum accept only ASCII letters as subscripts,
    whereas callers may use arbitrary (e.g. unicode) symbols. Symbols are
    mapped in sorted order, so an equation with at most 26 distinct symbols
    is mapped onto lowercase letters only. Every symbol receives a distinct
    letter, so no two symbols are ever merged.

    :param str equation: an einsum equation such as ``"ab,bc->ac"``.
    :return: the equation with every non-``,->`` character renamed.
    :raises ValueError: if ``equation`` uses more than 52 distinct symbols.
    """
    letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
    symbols = sorted(set(equation) - set(',->'))
    if len(symbols) > len(letters):
        raise ValueError(
            "einsum equation uses {} distinct symbols, but at most {} are "
            "supported".format(len(symbols), len(letters)))
    rename = dict(zip(symbols, letters))
    return ''.join(rename.get(s, s) for s in equation)


def einsum(equation, *operands):
    """
    Log-sum-exp implementation of einsum.

    Operands are treated as log-space values. Each operand is shifted by its
    max over the dims that are contracted away. The shift is clamped to the
    dtype's finite minimum to avoid ``-inf - -inf``. The shifted operands are
    then exponentiated, contracted with the backend ``einsum``, and logged.
    Finally the per-operand shifts, broadcast to the output layout, are
    added back.

    :param str equation: an einsum equation of the form ``"ab,bc->ac"``.
    :param operands: backend arrays, one per comma-separated input.
    :return: the log-space contraction; if the single input subscript equals
        the output subscript, a fresh view ``operands[0][...]``.
    :raises ValueError: on non-JAX backends, if ``equation`` uses more than
        52 distinct symbols.
    """
    if get_backend() != "jax":
        # NB: rename symbols because the NumPy and PyTorch einsum only accept
        # the letters a-z and A-Z as subscripts.
        equation = _rename_einsum_symbols(equation)

    inputs, output = equation.split('->')
    if inputs == output:
        return operands[0][...]  # create a new object
    inputs = inputs.split(',')

    shifts = []
    exp_operands = []
    for dims, operand in zip(inputs, operands):
        shift = operand
        for i, dim in enumerate(dims):
            if dim not in output:
                shift = ops.amax(shift, i, keepdims=True)
        # avoid nan due to -inf - -inf
        shift = ops.clamp(shift, ops.finfo(shift).min, None)
        exp_operands.append(ops.exp(operand - shift))

        # permute shift to match output
        shift = shift.reshape(
            [size for size, dim in zip(operand.shape, dims) if dim in output])
        if len(shift.shape) > 0:
            shift = shift.reshape(
                (1, ) * (len(output) - shift.ndim) + shift.shape)
            dims = [dim for dim in dims if dim in output]
            dims = [dim for dim in output if dim not in dims] + dims
            shift = ops.permute(shift, [dims.index(dim) for dim in output])
        shifts.append(shift)

    result = ops.log(ops.einsum(equation, *exp_operands))
    return sum(shifts + [result])