Example #1
0
    def __init__(self, red_op, bin_op, reduced_vars, terms):
        """
        Build a lazy contraction of ``terms``.

        The terms are combined with ``bin_op``, and the variables in
        ``reduced_vars`` are reduced with ``red_op``. The reduced variables
        become bound variables of the resulting funsor.

        :param AssociativeOp red_op: Reduction operator. May be a ``NullOp``,
            in which case ``reduced_vars`` must be empty.
        :param AssociativeOp bin_op: Operator combining the terms. May be a
            ``NullOp``, in which case there must be exactly one term.
        :param frozenset reduced_vars: Names (str) of the input variables to
            reduce away.
        :param terms: A single :class:`Funsor`, or a non-empty tuple of them.
        """
        # Accept a bare Funsor as shorthand for a 1-tuple of terms.
        terms = (terms, ) if isinstance(terms, Funsor) else terms
        # Check the container before iterating it. Otherwise a non-tuple
        # iterable (e.g. a generator) would be consumed by the element check
        # and only rejected afterwards.
        assert isinstance(terms, tuple) and len(terms) > 0
        assert all(isinstance(v, Funsor) for v in terms)
        assert isinstance(red_op, AssociativeOp)
        assert isinstance(bin_op, AssociativeOp)
        assert isinstance(reduced_vars, frozenset)
        assert all(isinstance(v, str) for v in reduced_vars)

        red_is_null = isinstance(red_op, NullOp)
        bin_is_null = isinstance(bin_op, NullOp)
        # At most one of the two operators may be a no-op.
        assert not (red_is_null and bin_is_null)
        if red_is_null:
            assert not reduced_vars
        elif bin_is_null:
            assert len(terms) == 1
        else:
            assert reduced_vars and len(terms) > 1
            assert (red_op, bin_op) in DISTRIBUTIVE_OPS

        # Free inputs: union of all terms' inputs, minus the reduced
        # variables, in first-seen order.
        inputs = OrderedDict()
        for v in terms:
            inputs.update(
                (k, d) for k, d in v.inputs.items() if k not in reduced_vars)

        # Use the same NullOp test as the validation above, rather than an
        # identity check against one particular NullOp instance.
        if bin_is_null:
            output = terms[0].output
        else:
            # Fold the term output domains, starting from the last term.
            output = reduce(lambda lhs, rhs: find_domain(bin_op, lhs, rhs),
                            [v.output for v in reversed(terms)])
        fresh = frozenset()
        bound = reduced_vars
        super(Contraction, self).__init__(inputs, output, fresh, bound)
        self.red_op = red_op
        self.bin_op = bin_op
        self.terms = terms
        self.reduced_vars = reduced_vars
Example #2
0
 def __init__(self, op, arg):
     """
     Lazily represent the unary application ``op(arg)``.

     :param callable op: The unary operation to apply.
     :param Funsor arg: The operand. Its inputs are inherited unchanged.
     """
     assert callable(op)
     assert isinstance(arg, Funsor)
     # The result domain is derived from the operand's output domain.
     result_domain = find_domain(op, arg.output)
     super(Unary, self).__init__(arg.inputs, result_domain)
     self.op, self.arg = op, arg
Example #3
0
def eager_binary_tensor_tensor(op, lhs, rhs):
    """
    Eagerly apply a matmul-like ``op`` to two Tensors.

    A 1-d lhs output is treated as a row (a dim is inserted at -2), and a
    1-d rhs output as a column (a dim is inserted at -1). These extra dims
    are squeezed away again after ``op`` is applied. When the result has
    batch inputs, the operand with fewer output dims gets size-1 dims
    inserted between its batch dims and its output dims, so that the batch
    dims line up when ``op`` broadcasts.
    """
    out_dtype = find_domain(op, lhs.output, rhs.output).dtype

    # Bring both operands onto a common batch layout.
    if lhs.inputs != rhs.inputs:
        inputs, (x, y) = align_tensors(lhs, rhs)
    else:
        inputs = lhs.inputs
        x, y = lhs.data, rhs.data

    lhs_is_vector = len(lhs.shape) == 1
    rhs_is_vector = len(rhs.shape) == 1
    if lhs_is_vector:
        x = x.unsqueeze(-2)
    if rhs_is_vector:
        y = y.unsqueeze(-1)

    def _pad_event(data, own_rank, target_rank):
        # Insert (target_rank - own_rank) size-1 dims just before the last
        # own_rank dims of data.
        split = data.dim() - own_rank
        head, tail = data.shape[:split], data.shape[split:]
        return data.reshape(head + (1,) * (target_rank - own_rank) + tail)

    if inputs:
        x_rank = max(2, len(lhs.shape))
        y_rank = max(2, len(rhs.shape))
        if x_rank < y_rank:
            x = _pad_event(x, x_rank, y_rank)
        elif y_rank < x_rank:
            y = _pad_event(y, y_rank, x_rank)

    result = op(x, y)
    # Undo the row/column promotion of 1-d operands.
    if lhs_is_vector:
        result = result.squeeze(-2)
    if rhs_is_vector:
        result = result.squeeze(-1)
    return Tensor(result, inputs, out_dtype)
Example #4
0
 def eager_unary(self, op):
     """
     Eagerly apply the unary ``op`` to this tensor's data.

     If ``op`` is a key of ``REDUCE_OP_TO_NUMERIC``, all output dims are
     flattened into one trailing dim. The mapped function is then called
     with that array and axis ``-1``, which leaves the leading batch dims in
     place. Any other op is applied directly to ``self.data``.
     """
     dtype = find_domain(op, self.output).dtype
     if op not in REDUCE_OP_TO_NUMERIC:
         return Tensor(op(self.data), self.inputs, dtype)

     # Batch dims are the leading dims that are not part of the output shape.
     n_batch = len(self.data.shape) - len(self.output.shape)
     flat = self.data.reshape(self.data.shape[:n_batch] + (-1, ))
     reduced = REDUCE_OP_TO_NUMERIC[op](flat, -1)
     return Tensor(reduced, self.inputs, dtype)
Example #5
0
 def eager_unary(self, op):
     """
     Eagerly apply the unary ``op`` to this tensor's data.

     If ``op`` is a key of ``REDUCE_OP_TO_TORCH``, all output dims are
     flattened into one trailing dim. The mapped function is then called
     with that tensor and dim ``-1``. Any other op is applied directly to
     ``self.data``.
     """
     dtype = find_domain(op, self.output).dtype
     if op not in REDUCE_OP_TO_TORCH:
         return Tensor(op(self.data), self.inputs, dtype)

     # Batch dims are the leading dims that are not part of the output shape.
     n_batch = len(self.data.shape) - len(self.output.shape)
     flat = self.data.reshape(self.data.shape[:n_batch] + (-1,))
     reduced = REDUCE_OP_TO_TORCH[op](flat, -1)
     if op is ops.min or op is ops.max:
         # Presumably these map to torch.min / torch.max, which return a
         # (values, indices) pair when given a dim; keep only the values.
         reduced = reduced[0]
     return Tensor(reduced, self.inputs, dtype)
Example #6
0
 def __init__(self, op, lhs, rhs):
     """
     Lazily represent the binary application ``op(lhs, rhs)``.

     :param callable op: The binary operation to apply.
     :param Funsor lhs: Left operand.
     :param Funsor rhs: Right operand.
     """
     assert callable(op)
     assert isinstance(lhs, Funsor)
     assert isinstance(rhs, Funsor)
     # Merge the operands' inputs. lhs entries come first. A name present in
     # both keeps lhs's position but takes rhs's domain (dict.update).
     merged_inputs = lhs.inputs.copy()
     merged_inputs.update(rhs.inputs)
     result_domain = find_domain(op, lhs.output, rhs.output)
     super(Binary, self).__init__(merged_inputs, result_domain)
     self.op, self.lhs, self.rhs = op, lhs, rhs
Example #7
0
def test_matmul(inputs1, inputs2, output_shape1, output_shape2):
    """Compare lazy ``x1 @ x2`` against eager matmul on one fixed block."""
    sizes = {'a': 6, 'b': 7, 'c': 8}
    inputs1 = OrderedDict((k, bint(sizes[k])) for k in inputs1)
    inputs2 = OrderedDict((k, bint(sizes[k])) for k in inputs2)
    x1 = random_tensor(inputs1, reals(*output_shape1))
    # Build x2 over its own inputs2, so that operands with differing batch
    # inputs (the alignment path) are actually exercised.
    x2 = random_tensor(inputs2, reals(*output_shape2))

    actual = x1 @ x2
    assert actual.output == find_domain(ops.matmul, x1.output, x2.output)

    # Substitute every batch variable, then compare with a plain matmul of
    # the resulting raw data.
    block = {'a': 1, 'b': 2, 'c': 3}
    actual_block = actual(**block)
    expected_block = Tensor(x1(**block).data @ x2(**block).data)
    assert_close(actual_block, expected_block, atol=1e-5, rtol=1e-5)
Example #8
0
def test_binary_broadcast(inputs1, inputs2, output_shape1, output_shape2):
    """Compare lazy ``x1 + x2`` against eager broadcasting add on one block."""
    sizes = {'a': 4, 'b': 5, 'c': 6}
    inputs1 = OrderedDict((k, bint(sizes[k])) for k in inputs1)
    inputs2 = OrderedDict((k, bint(sizes[k])) for k in inputs2)
    x1 = random_tensor(inputs1, reals(*output_shape1))
    # Build x2 over its own inputs2, so that operands with differing batch
    # inputs (the alignment path) are actually exercised.
    x2 = random_tensor(inputs2, reals(*output_shape2))

    actual = x1 + x2
    assert actual.output == find_domain(ops.add, x1.output, x2.output)

    # Substitute every batch variable, then compare with a plain add of the
    # resulting raw data.
    block = {'a': 1, 'b': 2, 'c': 3}
    actual_block = actual(**block)
    expected_block = Tensor(x1(**block).data + x2(**block).data)
    assert_close(actual_block, expected_block)
Example #9
0
    def __init__(self, const, coeffs):
        """
        Build an affine expression ``const + sum(coeff * var)``.

        :param const: A Number or Tensor with a real output and no real-typed
            inputs.
        :param tuple coeffs: Pairs ``(var, coeff)``. Each ``var`` is a real
            Variable, and each ``coeff`` is a real Number or Tensor with no
            real-typed inputs.
        """
        def has_real_input(f):
            # True if any of f's inputs has dtype "real".
            return any(d.dtype == "real" for d in f.inputs.values())

        assert isinstance(const, (Number, Tensor))
        assert not has_real_input(const)
        assert isinstance(coeffs, tuple)
        inputs = const.inputs.copy()
        output = const.output
        assert output.dtype == "real"
        for var, coeff in coeffs:
            assert isinstance(var, Variable)
            assert isinstance(coeff, (Number, Tensor))
            assert not has_real_input(coeff)
            inputs.update(coeff.inputs)
            inputs.update(var.inputs)
            # Domain of this term (var * coeff), then of the running sum.
            term_domain = find_domain(ops.mul, var.output, coeff.output)
            output = find_domain(ops.add, output, term_domain)
            assert var.dtype == "real"
            assert coeff.dtype == "real"
            assert output.dtype == "real"

        super(Affine, self).__init__(inputs, output)
        self.coeffs = OrderedDict(coeffs)
        self.const = const
Example #10
0
    def __init__(self, op, operands):
        """
        Lazily represent ``op`` applied to a tuple of operands.

        :param callable op: Operation used to combine the operand domains.
        :param tuple operands: Funsor operands. The result's inputs are the
            union of theirs, in first-seen order.
        """
        assert callable(op)
        assert isinstance(operands, tuple)
        assert all(isinstance(operand, Funsor) for operand in operands)

        merged_inputs = collections.OrderedDict()
        for funsor_arg in operands:
            merged_inputs.update(funsor_arg.inputs)

        # Fold the operand output domains, starting from the last operand.
        domains = [funsor_arg.output for funsor_arg in reversed(operands)]
        output = reduce(lambda acc, nxt: find_domain(op, acc, nxt), domains)

        super(Finitary, self).__init__(merged_inputs, output)
        self.op = op
        self.operands = operands
Example #11
0
def eager_contract(sum_op, prod_op, lhs, rhs, reduced_vars):
    """
    Eagerly contract two Tensors via ``opt_einsum``.

    Supported (sum_op, prod_op) pairs are (add, mul), which uses the
    "torch" backend, and (logaddexp, add), which uses the
    "pyro.ops.einsum.torch_log" backend. For any other pair this falls back
    to ``prod_op(lhs, rhs).reduce(sum_op, reduced_vars)``.
    """
    op_pair = (sum_op, prod_op)
    if op_pair == (ops.add, ops.mul):
        backend = "torch"
    elif op_pair == (ops.logaddexp, ops.add):
        backend = "pyro.ops.einsum.torch_log"
    else:
        return prod_op(lhs, rhs).reduce(sum_op, reduced_vars)

    # Output variables: every input of lhs, then of rhs, that is not reduced,
    # in first-seen order.
    inputs = OrderedDict()
    for term in (lhs, rhs):
        for name, domain in term.inputs.items():
            if name not in reduced_vars:
                inputs[name] = domain

    # Interleaved einsum form: operand, its dim labels, ..., output labels.
    data = opt_einsum.contract(
        lhs.data, list(lhs.inputs),
        rhs.data, list(rhs.inputs),
        list(inputs),
        backend=backend,
    )
    dtype = find_domain(prod_op, lhs.output, rhs.output).dtype
    return Tensor(data, inputs, dtype)
Example #12
0
def eager_binary_array_array(op, lhs, rhs):
    """
    Eagerly apply the binary ``op`` to two Arrays.

    For ``ops.getitem``, the last dim of ``lhs`` is indexed by the values of
    ``rhs``, one lookup per batch position. Only scalar-output ``rhs`` is
    supported. Any other op is applied directly to the aligned data.
    """
    dtype = find_domain(op, lhs.output, rhs.output).dtype
    if lhs.inputs != rhs.inputs:
        inputs, (lhs_data, rhs_data) = align_arrays(lhs, rhs)
    else:
        inputs = lhs.inputs
        lhs_data, rhs_data = lhs.data, rhs.data

    if op is not ops.getitem:
        return Array(op(lhs_data, rhs_data), inputs, dtype)

    # getitem has special shape semantics.
    if rhs.output.shape:
        raise NotImplementedError('TODO support vector indexing')
    assert lhs.output.shape == (rhs.dtype, )
    # For each leading dim, build an arange reshaped with trailing size-1
    # dims so that the aranges broadcast against each other. The entry for
    # the last dim is then replaced by the rhs index values.
    ndim = lhs_data.ndim
    index = []
    for pos, size in enumerate(lhs_data.shape):
        trailing_ones = (1, ) * (ndim - pos - 2)
        index.append(np.arange(size).reshape((-1, ) + trailing_ones))
    index[-1] = rhs_data
    return Array(lhs_data[tuple(index)], inputs, dtype)
Example #13
0
def eager_binary_number_number(op, lhs, rhs):
    """Eagerly combine two Numbers with ``op`` and return a new Number."""
    value = op(lhs.data, rhs.data)
    return Number(value, find_domain(op, lhs.output, rhs.output).dtype)
Example #14
0
 def eager_unary(self, op):
     """Eagerly apply ``op`` to this Number's value and return a new Number."""
     result_dtype = find_domain(op, self.output).dtype
     return Number(op(self.data), result_dtype)