def eager_multinomial(total_count, probs, value):
    """
    Eagerly compute a Multinomial log-probability for aligned funsor tensors.
    """
    # Multinomial.log_prob() supports inhomogeneous total_count only by
    # avoiding passing total_count to the constructor.
    inputs, (total_count, probs, value) = align_tensors(total_count, probs, value)
    full_shape = broadcast_shape(total_count.shape + (1,), probs.shape, value.shape)
    # Broadcast value and probs up to the common shape (independent steps).
    value = Tensor(ops.expand(value, full_shape), inputs)
    probs = Tensor(ops.expand(probs, full_shape), inputs)
    if get_backend() == "torch":
        # Collapse to a single scalar count.
        # Used by distributions validation code.
        total_count = Number(ops.amax(total_count, None).item())
    else:
        total_count = Tensor(ops.expand(total_count, full_shape[:-1]), inputs)
    dist_module = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    return dist_module.Multinomial.eager_log_prob(total_count, probs, value)  # noqa: F821
def einsum(equation, *operands):
    """
    Forward-max-sum backward-argmax implementation of einsum.
    This assumes all operands have a ``._pyro_dims`` attribute set.
    """
    lhs, output = equation.split('->')
    input_dims = lhs.split(',')
    # Dims appearing in some input but not in the output get contracted away.
    contract_dims = ''.join(sorted(set().union(*input_dims) - set(output)))
    all_dims = output + contract_dims

    # Sum all operands after broadcasting them to a common dim layout.
    terms = iter(broadcast_all(*operands, inputs=input_dims, dims=all_dims))
    result = next(terms)
    for term in terms:
        result = result + term

    if contract_dims:
        # Max-reduce every contracted dim at once by flattening them into one axis.
        keep_shape = result.shape[:len(output)]
        result = ops.amax(result.reshape(keep_shape + (-1,)), -1)
    elif result is operands[0]:
        result = result[...]  # create a new object
    assert len(result.shape) == len(output)
    return result
def einsum(equation, *operands):
    """
    Log-sum-exp implementation of einsum.
    """
    if get_backend() != "jax":
        # NB: rename symbols to support NumPy, which allow only symbols a-z.
        used_symbols = sorted(set(equation) - set(',->'))
        table = dict(zip(used_symbols, 'abcdefghijklmnopqrstuvwxyz'))
        equation = ''.join(table.get(ch, ch) for ch in equation)

    lhs, output = equation.split('->')
    if lhs == output:
        return operands[0][...]  # create a new object
    operand_dims = lhs.split(',')

    shifts = []
    exp_operands = []
    for dims, operand in zip(operand_dims, operands):
        # Max over each contracted axis, keeping dims so the shift broadcasts.
        shift = operand
        for axis, dim in enumerate(dims):
            if dim not in output:
                shift = ops.amax(shift, axis, keepdims=True)
        # avoid nan due to -inf - -inf
        shift = ops.clamp(shift, ops.finfo(shift).min, None)
        exp_operands.append(ops.exp(operand - shift))

        # permute shift to match output
        kept_sizes = [size for size, dim in zip(operand.shape, dims) if dim in output]
        shift = shift.reshape(kept_sizes)
        if len(shift.shape) > 0:
            # Left-pad with singleton axes, then permute into output dim order.
            shift = shift.reshape((1,) * (len(output) - shift.ndim) + shift.shape)
            kept = [dim for dim in dims if dim in output]
            padded = [dim for dim in output if dim not in kept] + kept
            shift = ops.permute(shift, [padded.index(dim) for dim in output])
        shifts.append(shift)

    result = ops.log(ops.einsum(equation, *exp_operands))
    return sum(shifts + [result])