def test_smoke(expr, expected_type):
    g1 = Gaussian(
        loc=torch.tensor([[0.0, 0.1, 0.2],
                          [2.0, 3.0, 4.0]]),
        precision=torch.tensor([[[1.0, 0.1, 0.2],
                                 [0.1, 1.0, 0.3],
                                 [0.2, 0.3, 1.0]],
                                [[1.0, 0.1, 0.2],
                                 [0.1, 1.0, 0.3],
                                 [0.2, 0.3, 1.0]]]),
        inputs=OrderedDict([('i', bint(2)), ('x', reals(3))]))
    assert isinstance(g1, Gaussian)

    g2 = Gaussian(
        loc=torch.tensor([[0.0, 0.1],
                          [2.0, 3.0]]),
        precision=torch.tensor([[[1.0, 0.2],
                                 [0.2, 1.0]],
                                [[1.0, 0.2],
                                 [0.2, 1.0]]]),
        inputs=OrderedDict([('i', bint(2)), ('y', reals(2))]))
    assert isinstance(g2, Gaussian)

    shift = Tensor(torch.tensor([-1., 1.]), OrderedDict([('i', bint(2))]))
    assert isinstance(shift, Tensor)

    i0 = Number(1, 2)
    assert isinstance(i0, Number)

    x0 = Tensor(torch.tensor([0.5, 0.6, 0.7]))
    assert isinstance(x0, Tensor)

    y0 = Tensor(torch.tensor([[0.2, 0.3], [0.8, 0.9]]),
                inputs=OrderedDict([('i', bint(2))]))
    assert isinstance(y0, Tensor)

    result = eval(expr)
    assert isinstance(result, expected_type)
def test_tensordot(x_shape, xy_shape, y_shape):
    x = torch.randn(x_shape + xy_shape)
    y = torch.randn(xy_shape + y_shape)
    dim = len(xy_shape)
    actual = torch_tensordot(Tensor(x), Tensor(y), dim)
    expected = Tensor(torch.tensordot(x, y, dim))
    assert_close(actual, expected, atol=1e-5, rtol=None)
def eager_integrate(log_measure, integrand, reduced_vars):
    real_vars = frozenset(k for k in reduced_vars
                          if log_measure.inputs[k].dtype == 'real')
    if real_vars:
        lhs_reals = frozenset(k for k, d in log_measure.inputs.items()
                              if d.dtype == 'real')
        rhs_reals = frozenset(k for k, d in integrand.inputs.items()
                              if d.dtype == 'real')
        if lhs_reals == real_vars and rhs_reals <= real_vars:
            inputs = OrderedDict((k, d) for t in (log_measure, integrand)
                                 for k, d in t.inputs.items())
            lhs_loc, lhs_precision = align_gaussian(inputs, log_measure)
            rhs_loc, rhs_precision = align_gaussian(inputs, integrand)

            # Compute the expectation of a non-normalized quadratic form.
            # See "The Matrix Cookbook" (November 15, 2012) ss. 8.2.2 eq. 380.
            # http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
            lhs_scale_tri = torch.inverse(
                torch.cholesky(lhs_precision)).transpose(-1, -2)
            lhs_covariance = torch.matmul(lhs_scale_tri,
                                          lhs_scale_tri.transpose(-1, -2))
            dim = lhs_loc.size(-1)
            norm = _det_tri(lhs_scale_tri) * (2 * math.pi)**(0.5 * dim)
            data = -0.5 * norm * (_vmv(rhs_precision, lhs_loc - rhs_loc) +
                                  _trace_mm(rhs_precision, lhs_covariance))
            inputs = OrderedDict((k, d) for k, d in inputs.items()
                                 if k not in reduced_vars)
            result = Tensor(data, inputs)
            return result.reduce(ops.add, reduced_vars - real_vars)

        raise NotImplementedError('TODO implement partial integration')

    return None  # defer to default implementation
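# For reference, the identity used in eager_integrate above (a sketch of the
# math, following the Matrix Cookbook citation in the code): with lhs
# parameters (mu, Lambda) and rhs parameters (m, P),
#
#   integral exp(-1/2 (x - mu)' Lambda (x - mu)) * (-1/2 (x - m)' P (x - m)) dx
#       = -1/2 * |2 pi Lambda^-1|^(1/2)
#              * [ (mu - m)' P (mu - m) + tr(P Lambda^-1) ].
#
# In the code, the |2 pi Lambda^-1|^(1/2) prefactor is `norm`, the first
# bracketed term is `_vmv(rhs_precision, lhs_loc - rhs_loc)`, and the trace
# term is `_trace_mm(rhs_precision, lhs_covariance)`.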
def test_multinomial_density(batch_shape, event_shape):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
    max_count = 10

    @funsor.torch.function(reals(), reals(*event_shape), reals(*event_shape),
                           reals())
    def multinomial(total_count, probs, value):
        total_count = total_count.max().item()
        return torch.distributions.Multinomial(total_count,
                                               probs).log_prob(value)

    check_funsor(multinomial, {
        'total_count': reals(),
        'probs': reals(*event_shape),
        'value': reals(*event_shape)
    }, reals())

    probs_data = torch.rand(batch_shape + event_shape)
    probs_data = probs_data / probs_data.sum(-1, keepdim=True)
    probs = Tensor(probs_data, inputs)
    value_data = torch.randint(0, max_count,
                               size=batch_shape + event_shape).float()
    total_count_data = value_data.sum(-1) + torch.randint(
        0, max_count, size=batch_shape).float()
    value = Tensor(value_data, inputs)
    total_count = Tensor(total_count_data, inputs)

    expected = multinomial(total_count, probs, value)
    check_funsor(expected, inputs, reals())

    actual = dist.Multinomial(total_count, probs, value)
    check_funsor(actual, inputs, reals())
    assert_close(actual, expected)
def eager_integrate(log_measure, integrand, reduced_vars):
    real_vars = frozenset(k for k in reduced_vars
                          if log_measure.inputs[k].dtype == 'real')
    if real_vars:
        lhs_reals = frozenset(k for k, d in log_measure.inputs.items()
                              if d.dtype == 'real')
        rhs_reals = frozenset(k for k, d in integrand.inputs.items()
                              if d.dtype == 'real')
        if lhs_reals == real_vars and rhs_reals <= real_vars:
            inputs = OrderedDict((k, d) for t in (log_measure, integrand)
                                 for k, d in t.inputs.items())
            lhs_info_vec, lhs_precision = align_gaussian(inputs, log_measure)
            rhs_info_vec, rhs_precision = align_gaussian(inputs, integrand)
            lhs = Gaussian(lhs_info_vec, lhs_precision, inputs)

            # Compute the expectation of a non-normalized quadratic form.
            # See "The Matrix Cookbook" (November 15, 2012) ss. 8.2.2 eq. 380.
            # http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
            norm = lhs.log_normalizer.data.exp()
            lhs_cov = cholesky_inverse(lhs._precision_chol)
            lhs_loc = lhs.info_vec.unsqueeze(-1).cholesky_solve(
                lhs._precision_chol).squeeze(-1)
            vmv_term = _vv(lhs_loc,
                           rhs_info_vec - 0.5 * _mv(rhs_precision, lhs_loc))
            data = norm * (vmv_term - 0.5 * _trace_mm(rhs_precision, lhs_cov))
            inputs = OrderedDict((k, d) for k, d in inputs.items()
                                 if k not in reduced_vars)
            result = Tensor(data, inputs)
            return result.reduce(ops.add, reduced_vars - real_vars)

        raise NotImplementedError('TODO implement partial integration')

    return None  # defer to default implementation
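# The same identity in information form (a sketch): here the integrand
# contributes the log-quadratic x' eta - 1/2 x' P x with rhs parameters
# (eta, P), whose expectation under the lhs Gaussian with mean mu and
# covariance Sigma is
#
#   E[x' eta - 1/2 x' P x] = mu' eta - 1/2 mu' P mu - 1/2 tr(P Sigma).
#
# The code assembles this as `vmv_term - 0.5 * _trace_mm(rhs_precision,
# lhs_cov)` and scales it by the lhs normalizer `norm`.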
def test_dirichlet_multinomial_density(batch_shape, event_shape):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
    max_count = 10

    @funsor.torch.function(reals(*event_shape), reals(), reals(*event_shape),
                           reals())
    def dirichlet_multinomial(concentration, total_count, value):
        return pyro.distributions.DirichletMultinomial(
            concentration, total_count).log_prob(value)

    check_funsor(dirichlet_multinomial, {
        'concentration': reals(*event_shape),
        'total_count': reals(),
        'value': reals(*event_shape)
    }, reals())

    concentration = Tensor(torch.randn(batch_shape + event_shape).exp(),
                           inputs)
    value_data = torch.randint(0, max_count,
                               size=batch_shape + event_shape).float()
    total_count_data = value_data.sum(-1) + torch.randint(
        0, max_count, size=batch_shape).float()
    value = Tensor(value_data, inputs)
    total_count = Tensor(total_count_data, inputs)

    expected = dirichlet_multinomial(concentration, total_count, value)
    check_funsor(expected, inputs, reals())

    actual = dist.DirichletMultinomial(concentration, total_count, value)
    check_funsor(actual, inputs, reals())
    assert_close(actual, expected)
def funsor_to_mvn(gaussian, ndims, event_inputs=()):
    """
    Convert a :class:`~funsor.terms.Funsor` to a
    :class:`pyro.distributions.MultivariateNormal`, dropping the normalization
    constant.

    :param gaussian: A Gaussian funsor.
    :type gaussian: funsor.gaussian.Gaussian or funsor.joint.Joint
    :param int ndims: The number of batch dimensions in the result.
    :param tuple event_inputs: A tuple of names to assign to rightmost
        dimensions.
    :return: a multivariate normal distribution.
    :rtype: pyro.distributions.MultivariateNormal
    """
    assert sum(1 for d in gaussian.inputs.values() if d.dtype == "real") == 1
    if isinstance(gaussian, Contraction):
        gaussian = [v for v in gaussian.terms if isinstance(v, Gaussian)][0]
    assert isinstance(gaussian, Gaussian)

    precision = gaussian.precision
    loc = cholesky_solve(gaussian.info_vec.unsqueeze(-1),
                         cholesky(precision)).squeeze(-1)

    int_inputs = OrderedDict(
        (k, d) for k, d in gaussian.inputs.items() if d.dtype != "real")
    loc = Tensor(loc, int_inputs)
    precision = Tensor(precision, int_inputs)
    assert len(loc.output.shape) == 1
    assert precision.output.shape == loc.output.shape * 2

    loc = funsor_to_tensor(loc, ndims + 1, event_inputs)
    precision = funsor_to_tensor(precision, ndims + 2, event_inputs)
    return pyro.distributions.MultivariateNormal(loc,
                                                 precision_matrix=precision)
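# A minimal numeric sketch (plain torch, illustrative names) of the
# information-to-moment conversion used above: loc solves
# precision @ loc = info_vec, computed via a Cholesky factor rather than an
# explicit matrix inverse.
import torch

A = torch.randn(3, 3)
precision = A @ A.t() + 3 * torch.eye(3)  # a positive definite precision
info_vec = torch.randn(3)
chol = torch.linalg.cholesky(precision)   # lower-triangular factor
loc = torch.cholesky_solve(info_vec.unsqueeze(-1), chol).squeeze(-1)
assert torch.allclose(precision @ loc, info_vec, atol=1e-5)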
def test_beta_density(batch_shape, eager):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

    @funsor.torch.function(reals(), reals(), reals(), reals())
    def beta(concentration1, concentration0, value):
        return torch.distributions.Beta(concentration1,
                                        concentration0).log_prob(value)

    check_funsor(beta, {
        'concentration1': reals(),
        'concentration0': reals(),
        'value': reals()
    }, reals())

    concentration1 = Tensor(torch.randn(batch_shape).exp(), inputs)
    concentration0 = Tensor(torch.randn(batch_shape).exp(), inputs)
    value = Tensor(torch.rand(batch_shape), inputs)
    expected = beta(concentration1, concentration0, value)
    check_funsor(expected, inputs, reals())

    d = Variable('value', reals())
    actual = dist.Beta(concentration1, concentration0, value) if eager else \
        dist.Beta(concentration1, concentration0, d)(value=value)
    check_funsor(actual, inputs, reals())
    assert_close(actual, expected)
def test_reduce_moment_matching_multivariate():
    int_inputs = [('i', bint(4))]
    real_inputs = [('x', reals(2))]
    inputs = OrderedDict(int_inputs + real_inputs)
    int_inputs = OrderedDict(int_inputs)
    real_inputs = OrderedDict(real_inputs)

    loc = torch.tensor([[-10., -1.],
                        [+10., -1.],
                        [+10., +1.],
                        [-10., +1.]])
    precision = torch.zeros(4, 1, 1) + torch.eye(2, 2)
    discrete = Tensor(torch.zeros(4), int_inputs)
    gaussian = Gaussian(loc, precision, inputs)
    gaussian -= gaussian.log_normalizer
    joint = discrete + gaussian
    with interpretation(moment_matching):
        actual = joint.reduce(ops.logaddexp, 'i')
    assert_close(actual.reduce(ops.logaddexp), joint.reduce(ops.logaddexp))

    expected_loc = torch.zeros(2)
    expected_covariance = torch.tensor([[101., 0.], [0., 2.]])
    expected_precision = expected_covariance.inverse()
    expected_gaussian = Gaussian(expected_loc, expected_precision, real_inputs)
    expected_gaussian -= expected_gaussian.log_normalizer
    expected_discrete = Tensor(torch.tensor(4.).log())
    expected = expected_discrete + expected_gaussian
    assert_close(actual, expected, atol=1e-5, rtol=None)
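# Where the expected covariance above comes from (a sketch): the four equally
# weighted components have means (+-10, +-1) and unit covariance, so the
# mixture mean is zero and each diagonal entry of the matched covariance is
# the within-component variance plus the spread of the component means:
#
#   var(x_0) = 1 + 10^2 = 101,    var(x_1) = 1 + 1^2 = 2.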
def test_advanced_indexing_shape():
    I, J, M, N = 4, 4, 2, 3
    x = Tensor(torch.randn(I, J), OrderedDict([
        ('i', bint(I)),
        ('j', bint(J)),
    ]))
    m = Tensor(torch.tensor([2, 3]), OrderedDict([('m', bint(M))]), I)
    n = Tensor(torch.tensor([0, 1, 1]), OrderedDict([('n', bint(N))]), J)
    assert x.data.shape == (I, J)

    check_funsor(x(i=m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(i=m, j=n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(i=m, j=n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(i=m, k=m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(i=n), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(i=n, k=m), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(j=m), {'i': bint(I), 'm': bint(M)}, reals())
    check_funsor(x(j=m, i=n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(j=m, i=n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(j=m, k=m), {'i': bint(I), 'm': bint(M)}, reals())
    check_funsor(x(j=n), {'i': bint(I), 'n': bint(N)}, reals())
    check_funsor(x(j=n, k=m), {'i': bint(I), 'n': bint(N)}, reals())
    check_funsor(x(m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(m, j=n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(m, j=n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(m, k=m), {'j': bint(J), 'm': bint(M)}, reals())
    check_funsor(x(m, n), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(m, n, k=m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(n), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(n, k=m), {'j': bint(J), 'n': bint(N)}, reals())
    check_funsor(x(n, m), {'m': bint(M), 'n': bint(N)}, reals())
    check_funsor(x(n, m, k=m), {'m': bint(M), 'n': bint(N)}, reals())
def test_binomial_density(batch_shape, eager):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
    max_count = 10

    @funsor.torch.function(reals(), reals(), reals(), reals())
    def binomial(total_count, probs, value):
        return torch.distributions.Binomial(total_count, probs).log_prob(value)

    check_funsor(binomial, {
        'total_count': reals(),
        'probs': reals(),
        'value': reals()
    }, reals())

    value_data = random_tensor(inputs, bint(max_count)).data.float()
    total_count_data = value_data + random_tensor(
        inputs, bint(max_count)).data.float()
    value = Tensor(value_data, inputs)
    total_count = Tensor(total_count_data, inputs)
    probs = Tensor(torch.rand(batch_shape), inputs)
    expected = binomial(total_count, probs, value)
    check_funsor(expected, inputs, reals())

    m = Variable('value', reals())
    actual = dist.Binomial(total_count, probs, value) if eager else \
        dist.Binomial(total_count, probs, m)(value=value)
    check_funsor(actual, inputs, reals())
    assert_close(actual, expected)
def test_reduce_moment_matching_univariate():
    int_inputs = [('i', bint(2))]
    real_inputs = [('x', reals())]
    inputs = OrderedDict(int_inputs + real_inputs)
    int_inputs = OrderedDict(int_inputs)
    real_inputs = OrderedDict(real_inputs)

    p = 0.8
    t = 1.234
    s1, s2, s3 = 2.0, 3.0, 4.0
    loc = torch.tensor([[-s1], [s1]])
    precision = torch.tensor([[[s2**-2]], [[s3**-2]]])
    info_vec = precision.matmul(loc.unsqueeze(-1)).squeeze(-1)
    discrete = Tensor(torch.tensor([1 - p, p]).log() + t, int_inputs)
    gaussian = Gaussian(info_vec, precision, inputs)
    gaussian -= gaussian.log_normalizer
    joint = discrete + gaussian
    with interpretation(moment_matching):
        actual = joint.reduce(ops.logaddexp, 'i')
    assert_close(actual.reduce(ops.logaddexp), joint.reduce(ops.logaddexp))

    expected_loc = torch.tensor([(2 * p - 1) * s1])
    expected_variance = (4 * p * (1 - p) * s1**2
                         + (1 - p) * s2**2 + p * s3**2)
    expected_precision = torch.tensor([[1 / expected_variance]])
    expected_info_vec = expected_precision.matmul(
        expected_loc.unsqueeze(-1)).squeeze(-1)
    expected_gaussian = Gaussian(expected_info_vec, expected_precision,
                                 real_inputs)
    expected_gaussian -= expected_gaussian.log_normalizer
    expected_discrete = Tensor(torch.tensor(t))
    expected = expected_discrete + expected_gaussian
    assert_close(actual, expected, atol=1e-5, rtol=None)
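# The closed form used above (a sketch): for a two-component mixture with
# weights (1 - p, p), means (-s1, +s1) and standard deviations (s2, s3), the
# matched moments are
#
#   mean     = (2p - 1) * s1
#   variance = (1 - p) * s2^2 + p * s3^2 + 4p(1 - p) * s1^2,
#
# i.e. the average within-component variance plus the variance of the
# component means about the mixture mean.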
def __call__(self):
    # calls pyro.param so that params are exposed and constraints applied
    # should not create any new torch.Tensors after __init__
    self.initialize_params()

    N_c = self.config["sizes"]["group"]
    N_s = self.config["sizes"]["individual"]

    log_prob = Tensor(torch.tensor(0.), OrderedDict())

    plate_g = Tensor(torch.zeros(N_c), OrderedDict([("g", bint(N_c))]))
    plate_i = Tensor(torch.zeros(N_s), OrderedDict([("i", bint(N_s))]))

    # group-level random effects
    if self.config["group"]["random"] == "continuous":
        eps_g_dist = plate_g + dist.Normal(**self.params["eps_g"])(
            value="eps_g")
        log_prob += eps_g_dist

    # individual-level random effects
    if self.config["individual"]["random"] == "continuous":
        eps_i_dist = plate_g + plate_i + dist.Normal(
            **self.params["eps_i"])(value="eps_i")
        log_prob += eps_i_dist

    return log_prob
def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete,
                                   gaussian):
    approx_vars = frozenset(k for k in reduced_vars
                            if k in gaussian.inputs
                            and gaussian.inputs[k].dtype != 'real')
    exact_vars = reduced_vars - approx_vars

    if exact_vars and approx_vars:
        return Contraction(red_op, bin_op, exact_vars, discrete,
                           gaussian).reduce(red_op, approx_vars)

    if approx_vars and not exact_vars:
        discrete += gaussian.log_normalizer
        new_discrete = discrete.reduce(
            ops.logaddexp, approx_vars.intersection(discrete.inputs))
        num_elements = reduce(ops.mul, [
            gaussian.inputs[k].num_elements
            for k in approx_vars.difference(discrete.inputs)
        ], 1)
        if num_elements != 1:
            new_discrete -= math.log(num_elements)

        int_inputs = OrderedDict(
            (k, d) for k, d in gaussian.inputs.items() if d.dtype != 'real')
        probs = (discrete - new_discrete.clamp_finite()).exp()

        old_loc = Tensor(
            gaussian.info_vec.unsqueeze(-1).cholesky_solve(
                gaussian._precision_chol).squeeze(-1), int_inputs)
        new_loc = (probs * old_loc).reduce(ops.add, approx_vars)
        old_cov = Tensor(cholesky_inverse(gaussian._precision_chol),
                         int_inputs)
        diff = old_loc - new_loc
        outers = Tensor(diff.data.unsqueeze(-1) * diff.data.unsqueeze(-2),
                        diff.inputs)
        new_cov = ((probs * old_cov).reduce(ops.add, approx_vars) +
                   (probs * outers).reduce(ops.add, approx_vars))

        # Numerically stabilize by adding bogus precision to empty components.
        total = probs.reduce(ops.add, approx_vars)
        mask = (total.data == 0).to(
            total.data.dtype).unsqueeze(-1).unsqueeze(-1)
        new_cov.data += mask * torch.eye(new_cov.data.size(-1))

        new_precision = Tensor(cholesky_inverse(cholesky(new_cov.data)),
                               new_cov.inputs)
        new_info_vec = new_precision.data.matmul(
            new_loc.data.unsqueeze(-1)).squeeze(-1)
        new_inputs = new_loc.inputs.copy()
        new_inputs.update((k, d) for k, d in gaussian.inputs.items()
                          if d.dtype == 'real')
        new_gaussian = Gaussian(new_info_vec, new_precision.data, new_inputs)
        new_discrete -= new_gaussian.log_normalizer
        return new_discrete + new_gaussian

    return None
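# The moment-matching step above collapses a mixture of Gaussians onto a
# single Gaussian with the same first and second moments (a sketch): with
# mixture weights p_i, component means mu_i and covariances Sigma_i,
#
#   mu_bar    = sum_i p_i mu_i
#   Sigma_bar = sum_i p_i (Sigma_i + (mu_i - mu_bar)(mu_i - mu_bar)'),
#
# which is what `new_loc` and `new_cov` compute; the `outers` term carries
# the between-component spread.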
def test_delta_density(batch_shape, event_shape):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

    @funsor.torch.function(reals(*event_shape), reals(), reals(*event_shape),
                           reals())
    def delta(v, log_density, value):
        eq = (v == value)
        for _ in range(len(event_shape)):
            eq = eq.all(dim=-1)
        return eq.type(v.dtype).log() + log_density

    check_funsor(delta, {
        'v': reals(*event_shape),
        'log_density': reals(),
        'value': reals(*event_shape)
    }, reals())

    v = Tensor(torch.randn(batch_shape + event_shape), inputs)
    log_density = Tensor(torch.randn(batch_shape).exp(), inputs)
    for value in [v, Tensor(torch.randn(batch_shape + event_shape), inputs)]:
        expected = delta(v, log_density, value)
        check_funsor(expected, inputs, reals())

        actual = dist.Delta(v, log_density, value)
        check_funsor(actual, inputs, reals())
        assert_close(actual, expected)
def test_distributions(state_dim, obs_dim):
    data = Tensor(torch.randn(2, obs_dim))["time"]

    bias = Variable("bias", reals(obs_dim))
    bias_dist = dist_to_funsor(random_mvn((), obs_dim))(value=bias)

    prev = Variable("prev", reals(state_dim))
    curr = Variable("curr", reals(state_dim))
    trans_mat = Tensor(
        torch.eye(state_dim) + 0.1 * torch.randn(state_dim, state_dim))
    trans_mvn = random_mvn((), state_dim)
    trans_dist = dist.MultivariateNormal(loc=trans_mvn.loc,
                                         scale_tril=trans_mvn.scale_tril,
                                         value=curr - prev @ trans_mat)

    state = Variable("state", reals(state_dim))
    obs = Variable("obs", reals(obs_dim))
    obs_mat = Tensor(torch.randn(state_dim, obs_dim))
    obs_mvn = random_mvn((), obs_dim)
    obs_dist = dist.MultivariateNormal(loc=obs_mvn.loc,
                                       scale_tril=obs_mvn.scale_tril,
                                       value=state @ obs_mat + bias - obs)

    log_prob = 0
    log_prob += bias_dist

    state_0 = Variable("state_0", reals(state_dim))
    log_prob += obs_dist(state=state_0, obs=data(time=0))

    state_1 = Variable("state_1", reals(state_dim))
    log_prob += trans_dist(prev=state_0, curr=state_1)
    log_prob += obs_dist(state=state_1, obs=data(time=1))

    log_prob = log_prob.reduce(ops.logaddexp)
    assert isinstance(log_prob, Tensor), log_prob.pretty()
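# The joint density assembled above factors as a two-step state space model
# with a global observation bias (a sketch of the log-density being summed):
#
#   log p = log p(bias)
#         + log p(obs_0 | state_0, bias)
#         + log p(state_1 | state_0)
#         + log p(obs_1 | state_1, bias).
#
# reduce(ops.logaddexp) then marginalizes bias, state_0 and state_1, leaving
# a plain Tensor holding the fully marginalized log-probability.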
def test_smoke(expr, expected_type):
    dx = Delta('x', Tensor(torch.randn(2, 3), OrderedDict([('i', bint(2))])))
    assert isinstance(dx, Delta)

    dy = Delta('y', Tensor(torch.randn(3, 4), OrderedDict([('j', bint(3))])))
    assert isinstance(dy, Delta)

    t = Tensor(torch.randn(2, 3),
               OrderedDict([('i', bint(2)), ('j', bint(3))]))
    assert isinstance(t, Tensor)

    g = Gaussian(
        info_vec=torch.tensor([[0.0, 0.1, 0.2],
                               [2.0, 3.0, 4.0]]),
        precision=torch.tensor([[[1.0, 0.1, 0.2],
                                 [0.1, 1.0, 0.3],
                                 [0.2, 0.3, 1.0]],
                                [[1.0, 0.1, 0.2],
                                 [0.1, 1.0, 0.3],
                                 [0.2, 0.3, 1.0]]]),
        inputs=OrderedDict([('i', bint(2)), ('x', reals(3))]))
    assert isinstance(g, Gaussian)

    i0 = Number(1, 2)
    assert isinstance(i0, Number)

    x0 = Tensor(torch.tensor([0.5, 0.6, 0.7]))
    assert isinstance(x0, Tensor)

    result = eval(expr)
    assert isinstance(result, expected_type)
def test_reduce_subset(dims, reduced_vars, op):
    reduced_vars = frozenset(reduced_vars)
    sizes = {'a': 3, 'b': 4, 'c': 5}
    shape = tuple(sizes[d] for d in dims)
    inputs = OrderedDict((d, bint(sizes[d])) for d in dims)
    data = torch.rand(shape) + 0.5
    dtype = 'real'
    if op in [ops.and_, ops.or_]:
        data = data.byte()
        dtype = 2
    x = Tensor(data, inputs, dtype)
    actual = x.reduce(op, reduced_vars)
    expected_inputs = OrderedDict(
        (d, bint(sizes[d])) for d in dims if d not in reduced_vars)

    reduced_vars &= frozenset(dims)
    if not reduced_vars:
        assert actual is x
    else:
        if reduced_vars == frozenset(dims):
            if op is ops.logaddexp:
                # work around missing torch.Tensor.logsumexp()
                data = data.reshape(-1).logsumexp(0)
            else:
                data = REDUCE_OP_TO_TORCH[op](data)
        else:
            for pos in reversed(sorted(map(dims.index, reduced_vars))):
                data = REDUCE_OP_TO_TORCH[op](data, pos)
                if op in (ops.min, ops.max):
                    data = data[0]
        check_funsor(actual, expected_inputs, Domain((), dtype))
        assert_close(actual, Tensor(data, expected_inputs, dtype),
                     atol=1e-5, rtol=1e-5)
def test_lambda_getitem():
    data = torch.randn(2)
    x = Tensor(data)
    y = Tensor(data, OrderedDict(i=bint(2)))
    i = Variable('i', bint(2))
    assert x[i] is y
    assert Lambda(i, y) is x
def test_function_of_torch_tensor():
    x = torch.randn(4, 3)
    y = torch.randn(3, 2)
    f = funsor.torch.function(reals(4, 3), reals(3, 2),
                              reals(4, 2))(torch.matmul)
    actual = f(x, y)
    expected = f(Tensor(x), Tensor(y))
    assert_close(actual, expected)
def test_getitem_variable():
    data = torch.randn((5, 4, 3, 2))
    x = Tensor(data)
    i = Variable('i', bint(5))
    j = Variable('j', bint(4))
    assert x[i] is Tensor(data, OrderedDict([('i', bint(5))]))
    assert x[i, j] is Tensor(data, OrderedDict([('i', bint(5)),
                                                ('j', bint(4))]))
def test_mvn_affine_reshape():
    x = Variable('x', reals(2, 2))
    y = Variable('y', reals(4))
    data = dict(x=Tensor(torch.randn(2, 2)), y=Tensor(torch.randn(4)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 4))
        d = d(value=x.reshape((4,)) - y)
    _check_mvn_affine(d, data)
def test_mvn_affine_two_vars():
    x = Variable('x', reals(2))
    y = Variable('y', reals(2))
    data = dict(x=Tensor(torch.randn(2)), y=Tensor(torch.randn(2)))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 2))
        d = d(value=x - y)
    _check_mvn_affine(d, data)
def test_mvn_affine_einsum():
    c = Tensor(torch.randn(3, 2, 2))
    x = Variable('x', reals(2, 2))
    y = Variable('y', reals())
    data = dict(x=Tensor(torch.randn(2, 2)), y=Tensor(torch.randn(())))
    with interpretation(lazy):
        d = dist_to_funsor(random_mvn((), 3))
        d = d(value=Einsum("abc,bc->a", c, x) + y)
    _check_mvn_affine(d, data)
def test_einsum(equation):
    sizes = dict(a=2, b=3, c=4)
    inputs, outputs = equation.split('->')
    inputs = inputs.split(',')
    tensors = [torch.randn(tuple(sizes[d] for d in dims)) for dims in inputs]
    funsors = [Tensor(x) for x in tensors]
    expected = Tensor(torch.einsum(equation, *tensors))
    actual = Einsum(equation, tuple(funsors))
    assert_close(actual, expected, atol=1e-5, rtol=None)
def test_reduce_logaddexp_deltas_lazy():
    a = Delta('a', Tensor(torch.randn(3, 2), OrderedDict(i=bint(3))))
    b = Delta('b', Tensor(torch.randn(3), OrderedDict(i=bint(3))))
    x = a + b
    assert isinstance(x, Delta)
    assert set(x.inputs) == {'a', 'b', 'i'}

    y = x.reduce(ops.logaddexp, 'i')
    # assert isinstance(y, Reduce)
    assert set(y.inputs) == {'a', 'b'}
    assert_close(x.reduce(ops.logaddexp), y.reduce(ops.logaddexp))
def moment_matching_reduce(self, op, reduced_vars):
    if not reduced_vars:
        return self
    if op is ops.logaddexp:
        if not all(reduced_vars.isdisjoint(d.inputs) for d in self.deltas):
            raise NotImplementedError(
                'TODO handle moment_matching with Deltas')
        lazy_vars = frozenset().union(
            *(d.inputs for d in self.deltas)).intersection(reduced_vars)
        approx_vars = frozenset(k for k in reduced_vars - lazy_vars
                                if self.inputs[k].dtype != 'real'
                                if k in self.gaussian.inputs)
        exact_vars = reduced_vars - lazy_vars - approx_vars
        if exact_vars:
            return self.eager_reduce(op, exact_vars).reduce(
                op, approx_vars | lazy_vars)

        # Moment-matching approximation.
        assert approx_vars and not exact_vars
        discrete = self.discrete
        new_discrete = discrete.reduce(
            ops.logaddexp, approx_vars.intersection(discrete.inputs))
        num_elements = reduce(ops.mul, [
            self.inputs[k].num_elements
            for k in approx_vars.difference(discrete.inputs)
        ], 1)
        if num_elements != 1:
            new_discrete -= math.log(num_elements)

        gaussian = self.gaussian
        int_inputs = OrderedDict((k, d) for k, d in gaussian.inputs.items()
                                 if d.dtype != 'real')
        probs = (discrete - new_discrete).exp()
        old_loc = Tensor(gaussian.loc, int_inputs)
        new_loc = (probs * old_loc).reduce(ops.add, approx_vars)
        old_cov = Tensor(sym_inverse(gaussian.precision), int_inputs)
        diff = old_loc - new_loc
        outers = Tensor(diff.data.unsqueeze(-1) * diff.data.unsqueeze(-2),
                        diff.inputs)
        new_cov = ((probs * old_cov).reduce(ops.add, approx_vars) +
                   (probs * outers).reduce(ops.add, approx_vars))
        new_precision = Tensor(sym_inverse(new_cov.data), new_cov.inputs)
        new_inputs = new_loc.inputs.copy()
        new_inputs.update((k, d) for k, d in self.gaussian.inputs.items()
                          if d.dtype == 'real')
        new_gaussian = Gaussian(new_loc.data, new_precision.data, new_inputs)
        result = Joint(self.deltas, new_discrete, new_gaussian)
        return result.reduce(ops.logaddexp, lazy_vars)

    return None  # defer to default implementation
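# A minimal numeric check of the matched moments computed above, in plain
# torch (illustrative names, not part of funsor's API):
import torch

probs = torch.tensor([0.25, 0.75])              # mixture weights p_i
locs = torch.tensor([[-1.0, 0.0], [2.0, 1.0]])  # component means mu_i
covs = torch.stack([0.5 * torch.eye(2), 2.0 * torch.eye(2)])  # Sigma_i

new_loc = (probs[:, None] * locs).sum(0)        # matched first moment
diff = locs - new_loc
outers = diff.unsqueeze(-1) * diff.unsqueeze(-2)  # between-component spread
new_cov = (probs[:, None, None] * (covs + outers)).sum(0)  # matched covariance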
def test_getitem_number_0_inputs():
    data = torch.randn((5, 4, 3, 2))
    x = Tensor(data)
    assert_close(x[2], Tensor(data[2]))
    assert_close(x[:, 1], Tensor(data[:, 1]))
    assert_close(x[2, 1], Tensor(data[2, 1]))
    assert_close(x[2, :, 1], Tensor(data[2, :, 1]))
    assert_close(x[3, ...], Tensor(data[3, ...]))
    assert_close(x[3, 2, ...], Tensor(data[3, 2, ...]))
    assert_close(x[..., 1], Tensor(data[..., 1]))
    assert_close(x[..., 2, 1], Tensor(data[..., 2, 1]))
    assert_close(x[3, ..., 1], Tensor(data[3, ..., 1]))
def test_function_matmul():
    @funsor.torch.function(reals(3, 4), reals(4, 5), reals(3, 5))
    def matmul(x, y):
        return torch.matmul(x, y)

    check_funsor(matmul, {'x': reals(3, 4), 'y': reals(4, 5)}, reals(3, 5))

    x = Tensor(torch.randn(3, 4))
    y = Tensor(torch.randn(4, 5))
    actual = matmul(x, y)
    expected_data = torch.matmul(x.data, y.data)
    check_funsor(actual, {}, reals(3, 5), expected_data)
def test_getitem_number_2_inputs():
    data = torch.randn((3, 4, 5, 4, 3, 2))
    inputs = OrderedDict([('i', bint(3)), ('j', bint(4))])
    x = Tensor(data, inputs)
    assert_close(x[2], Tensor(data[:, :, 2], inputs))
    assert_close(x[:, 1], Tensor(data[:, :, :, 1], inputs))
    assert_close(x[2, 1], Tensor(data[:, :, 2, 1], inputs))
    assert_close(x[2, :, 1], Tensor(data[:, :, 2, :, 1], inputs))
    assert_close(x[3, ...], Tensor(data[:, :, 3, ...], inputs))
    assert_close(x[3, 2, ...], Tensor(data[:, :, 3, 2, ...], inputs))
    assert_close(x[..., 1], Tensor(data[..., 1], inputs))
    assert_close(x[..., 2, 1], Tensor(data[..., 2, 1], inputs))
    assert_close(x[3, ..., 1], Tensor(data[:, :, 3, ..., 1], inputs))