Example #1
def eager_cat_homogeneous(name, part_name, *parts):
    # Concatenate Gaussian/mixture parts along their shared ``part_name``
    # input, exposing the concatenated dim as the new input ``name``.
    assert parts
    output = parts[0].output
    inputs = OrderedDict([(part_name, None)])
    for part in parts:
        assert part.output == output
        assert part_name in part.inputs
        inputs.update(part.inputs)

    int_inputs = OrderedDict(
        (k, v) for k, v in inputs.items() if v.dtype != "real")
    real_inputs = OrderedDict(
        (k, v) for k, v in inputs.items() if v.dtype == "real")
    inputs = int_inputs.copy()
    inputs.update(real_inputs)
    discretes = []
    info_vecs = []
    precisions = []
    for part in parts:
        inputs[part_name] = part.inputs[part_name]
        int_inputs[part_name] = inputs[part_name]
        shape = tuple(d.size for d in int_inputs.values())
        if isinstance(part, Gaussian):
            discrete = None
            gaussian = part
        # TODO figure out why isinstance isn't working
        elif issubclass(type(part), GaussianMixture):
            discrete, gaussian = part.terms[0], part.terms[1]
            discrete = ops.expand(align_tensor(int_inputs, discrete), shape)
        else:
            raise NotImplementedError("TODO")
        discretes.append(discrete)
        info_vec, precision = align_gaussian(inputs, gaussian)
        info_vecs.append(ops.expand(info_vec, shape + (-1, )))
        precisions.append(ops.expand(precision, shape + (-1, -1)))
    if part_name != name:
        del inputs[part_name]
        del int_inputs[part_name]

    dim = 0
    info_vec = ops.cat(dim, *info_vecs)
    precision = ops.cat(dim, *precisions)
    inputs[name] = Bint[info_vec.shape[dim]]
    int_inputs[name] = inputs[name]
    result = Gaussian(info_vec, precision, inputs)
    if any(d is not None for d in discretes):
        for i, d in enumerate(discretes):
            if d is None:
                discretes[i] = ops.new_zeros(info_vecs[i], info_vecs[i].shape[:-1])
        discrete = ops.cat(dim, *discretes)
        result = result + Tensor(discrete, int_inputs)
    return result
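
The key pattern above is expand-then-concatenate: each part's info_vec is aligned to the union of integer inputs, expanded to the common batch shape, and concatenated along the part_name dim. A minimal NumPy sketch of that shape arithmetic (sizes invented for illustration; funsor's ops.expand and ops.cat wrap the backend equivalents):

import numpy as np

# part_name is the leading int input; one extra int input of size 3 follows,
# and the trailing dim is the info_vec event dim.
a = np.random.randn(2, 1, 5)            # part_name size 2, missing the batch input
b = np.random.randn(4, 3, 5)            # part_name size 4, batch size 3

a = np.broadcast_to(a, (2, 3, 5))       # ~ ops.expand(info_vec, shape + (-1,))
cat = np.concatenate([a, b], axis=0)    # ~ ops.cat(dim, *info_vecs) with dim=0
assert cat.shape == (6, 3, 5)           # the new ``name`` input becomes Bint[6]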
Example #2
def broadcast_all(*values, **kwargs):
    """
    Packed broadcasting of multiple tensors.
    """
    inputs = kwargs.get('inputs')
    dims = kwargs.get('dims')
    sizes = {
        dim: size
        for value, old_dims in zip(values, inputs)
        for dim, size in zip(old_dims, value.shape)
    }
    if dims is None:
        dims = ''.join(sorted(sizes))
    else:
        assert set(dims) == set(sizes)
    shape = tuple(sizes[dim] for dim in dims)
    values = list(values)
    for i, (x, old_dims) in enumerate(zip(values, inputs)):
        if old_dims != dims:
            x = ops.permute(
                x,
                tuple(old_dims.index(dim) for dim in dims if dim in old_dims))
            x = x.reshape(
                tuple(sizes[dim] if dim in old_dims else 1 for dim in dims))
            x = ops.expand(x, shape)
            assert len(x.shape) == len(dims)
            values[i] = x
    return tuple(values)
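
For intuition, here is roughly what the permute/reshape/expand sequence does for two arrays with dims "ab" and "b" (a NumPy sketch with invented sizes, not funsor's actual backend calls):

import numpy as np

sizes = {"a": 2, "b": 3}
x = np.random.randn(2, 3)    # dims "ab", already in output order
y = np.random.randn(3)       # dims "b" only

# For y: the permute is a no-op, reshape inserts a size-1 slot for the
# missing "a", and expand broadcasts it out to the full shape.
y = np.broadcast_to(y.reshape(1, 3), (2, 3))
assert x.shape == y.shape == (2, 3)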
Example #3
def eager_multinomial(total_count, probs, value):
    # Multinomial.log_prob() supports inhomogeneous total_count only by
    # avoiding passing total_count to the constructor.
    inputs, (total_count, probs, value) = align_tensors(total_count, probs, value)
    shape = broadcast_shape(total_count.shape + (1, ), probs.shape,
                            value.shape)
    probs = Tensor(ops.expand(probs, shape), inputs)
    value = Tensor(ops.expand(value, shape), inputs)
    if get_backend() == "torch":
        # Used by distributions validation code.
        total_count = Number(ops.amax(total_count, None).item())
    else:
        total_count = Tensor(ops.expand(total_count, shape[:-1]), inputs)
    backend_dist = import_module(
        BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    return backend_dist.Multinomial.eager_log_prob(total_count, probs,
                                                   value)  # noqa: F821
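
The shape handling pads total_count with a trailing singleton so it broadcasts against the event (categories) dim of probs and value. A NumPy illustration of just that step (np.broadcast_shapes needs NumPy >= 1.20; all sizes invented):

import numpy as np

total_count = np.full((5,), 10)                  # batch (5,), no event dim
probs = np.broadcast_to(np.ones(3) / 3, (5, 3))  # batch (5,), event (3,)
value = np.zeros((5, 3))

shape = np.broadcast_shapes(total_count.shape + (1,), probs.shape, value.shape)
assert shape == (5, 3)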
Example #4
def eager_lambda(var, expr):
    inputs = expr.inputs.copy()
    if var.name in inputs:
        # Move var.name to the end of inputs so that, after alignment, its dim
        # sits just before the event dims; then drop it, since Lambda binds it.
        inputs.pop(var.name)
        inputs[var.name] = var.output
        data = align_tensor(inputs, expr)
        inputs.pop(var.name)
    else:
        data = expr.data
        shape = data.shape
        dim = len(shape) - len(expr.output.shape)
        data = data.reshape(shape[:dim] + (1, ) + shape[dim:])
        data = ops.expand(data, shape[:dim] + (var.dtype, ) + shape[dim:])
    return Tensor(data, inputs, expr.dtype)
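
The else branch splices in a fresh dim for a variable the expression never mentions. In NumPy terms, assuming a batch of size 2, an event of size 3, and var.dtype == 4 (all invented for illustration):

import numpy as np

data = np.random.randn(2, 3)             # batch (2,), event (3,)
dim = 1                                  # len(shape) - len(expr.output.shape)
data = data.reshape((2, 1, 3))           # shape[:dim] + (1,) + shape[dim:]
data = np.broadcast_to(data, (2, 4, 3))  # ~ ops.expand with var.dtype == 4
assert data.shape == (2, 4, 3)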
Example #5
def eager_stack_homogeneous(name, *parts):
    assert parts
    output = parts[0].output
    part_inputs = OrderedDict()
    for part in parts:
        assert part.output == output
        assert name not in part.inputs
        part_inputs.update(part.inputs)

    shape = tuple(d.size for d in part_inputs.values()) + output.shape
    data = ops.stack(0, *[
        ops.expand(align_tensor(part_inputs, part), shape) for part in parts
    ])
    inputs = OrderedDict([(name, Bint[len(parts)])])
    inputs.update(part_inputs)
    return Tensor(data, inputs, dtype=output.dtype)
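
Unlike concatenation, stacking adds a brand-new leading dim of size len(parts) once every part has been expanded to the same shape. A quick NumPy sketch with invented sizes:

import numpy as np

shape = (3, 2)    # union of part inputs plus the event shape
parts = [np.random.randn(3, 2), np.broadcast_to(np.random.randn(3, 1), shape)]
stacked = np.stack(parts, axis=0)    # ~ ops.stack(0, *...)
assert stacked.shape == (2, 3, 2)    # new leading input Bint[len(parts)]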
Example #6
def align_tensor(new_inputs, x, expand=False):
    r"""
    Permute and add dims to a tensor to match desired ``new_inputs``.

    :param OrderedDict new_inputs: A target set of inputs.
    :param funsor.terms.Funsor x: A :class:`Tensor` or
        :class:`~funsor.terms.Number` .
    :param bool expand: If False (default), set result size to 1 for any input
        of ``new_inputs`` missing from ``x``; if True, expand to the full
        ``new_inputs`` sizes.
    :return: a number or :class:`torch.Tensor` or :class:`np.ndarray` that can
        be broadcast to other tensors with inputs ``new_inputs``.
    """
    assert isinstance(new_inputs, OrderedDict)
    assert isinstance(x, (Number, Tensor))
    assert all(isinstance(d.dtype, int) for d in x.inputs.values())

    data = x.data
    if isinstance(x, Number):
        return data

    old_inputs = x.inputs
    if old_inputs == new_inputs:
        return data

    # Permute squashed input dims.
    x_keys = tuple(old_inputs)
    data = ops.permute(
        data,
        tuple(x_keys.index(k) for k in new_inputs if k in old_inputs) +
        tuple(range(len(old_inputs), len(data.shape))))

    # Unsquash multivariate input dims by filling in ones.
    data = data.reshape(
        tuple(old_inputs[k].dtype if k in old_inputs else 1
              for k in new_inputs) + x.output.shape)

    # Optionally expand new dims.
    if expand:
        data = ops.expand(
            data,
            tuple(d.dtype for d in new_inputs.values()) + x.output.shape)
    return data
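
To make the three steps concrete, here is the same permute/reshape/expand recipe written directly in NumPy for hypothetical old inputs ("b", "a") realigned to new inputs ("a", "b", "c") with event shape (2,):

import numpy as np

x = np.random.randn(3, 4, 2)          # b=3, a=4, event (2,)
x = np.transpose(x, (1, 0, 2))        # permute into new-input order
x = x.reshape((4, 3, 1, 2))           # size-1 slot for the missing "c"
x = np.broadcast_to(x, (4, 3, 5, 2))  # only when expand=True (c has size 5)
assert x.shape == (4, 3, 5, 2)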
Example #7
def test_batched_einsum(equation, batch1, batch2):
    inputs, output = equation.split('->')
    inputs = inputs.split(',')

    sizes = dict(a=2, b=3, c=4, i=5, j=6)
    batch1 = OrderedDict([(k, bint(sizes[k])) for k in batch1])
    batch2 = OrderedDict([(k, bint(sizes[k])) for k in batch2])
    funsors = [
        random_tensor(batch, reals(*(sizes[d] for d in dims)))
        for batch, dims in zip([batch1, batch2], inputs)
    ]
    actual = Einsum(equation, tuple(funsors))

    _equation = ','.join('...' + i for i in inputs) + '->...' + output
    inputs, tensors = align_tensors(*funsors)
    batch = tuple(v.size for v in inputs.values())
    tensors = [
        ops.expand(x, batch + f.shape) for (x, f) in zip(tensors, funsors)
    ]
    expected = Tensor(ops.einsum(_equation, *tensors), inputs)
    assert_close(actual, expected, atol=1e-5, rtol=None)
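
The test prefixes each operand with an ellipsis so the backend einsum broadcasts over the shared batch dims. The same trick in plain NumPy, with made-up sizes:

import numpy as np

x = np.random.randn(7, 2, 3)    # leading batch dim of size 7
y = np.random.randn(7, 3, 4)
z = np.einsum("...ab,...bc->...ac", x, y)    # batched matrix product
assert z.shape == (7, 2, 4)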
Example #8
def eager_cat_homogeneous(name, part_name, *parts):
    # Plain-Tensor analogue of the Gaussian version in Example #1: align each
    # part, expand it to the common shape, then concatenate along part_name.
    assert parts
    output = parts[0].output
    inputs = OrderedDict([(part_name, None)])
    for part in parts:
        assert part.output == output
        assert part_name in part.inputs
        inputs.update(part.inputs)

    tensors = []
    for part in parts:
        inputs[part_name] = part.inputs[part_name]
        shape = tuple(d.size for d in inputs.values()) + output.shape
        tensors.append(ops.expand(align_tensor(inputs, part), shape))
    del inputs[part_name]

    dim = 0
    tensor = ops.cat(dim, *tensors)
    inputs = OrderedDict([(name, Bint[tensor.shape[dim]])] +
                         list(inputs.items()))
    return Tensor(tensor, inputs, dtype=output.dtype)
Example #9
def test_ops_expand(expand_shape):
    x = randn((3, 2))
    actual = ops.expand(x, expand_shape)
    assert actual.shape == (4, 3, 2)
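
Assuming the parametrized expand_shape is (4, 3, 2) (or an equivalent form using -1 placeholders), the NumPy analogue of this expansion is broadcasting in a new leading dim:

import numpy as np

x = np.random.randn(3, 2)
y = np.broadcast_to(x, (4, 3, 2))    # the new leading dim is broadcast in
assert y.shape == (4, 3, 2)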
Example #10
    def _eager_subs_real(self, subs, remaining_subs):
        # Broadcast all component tensors.
        subs = OrderedDict(subs)
        int_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        tensors = [
            Tensor(self.info_vec, int_inputs),
            Tensor(self.precision, int_inputs)
        ]
        tensors.extend(subs.values())
        int_inputs, tensors = align_tensors(*tensors)
        batch_dim = len(tensors[0].shape) - 1
        batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
        (info_vec, precision), values = tensors[:2], tensors[2:]
        offsets, event_size = _compute_offsets(self.inputs)
        slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
                  for k, offset in offsets.items()]

        # Expand all substituted values.
        values = OrderedDict(zip(subs, values))
        for k, value in values.items():
            value = value.reshape(value.shape[:batch_dim] + (-1, ))
            if not get_tracing_state():
                assert value.shape[-1] == self.inputs[k].num_elements
            values[k] = ops.expand(value, batch_shape + value.shape[-1:])

        # Try to perform a complete substitution of all real variables, resulting in a Tensor.
        if all(k in subs for k, d in self.inputs.items() if d.dtype == 'real'):
            # Form the concatenated value.
            value = BlockVector(batch_shape + (event_size, ))
            for k, i in slices:
                if k in values:
                    value[..., i] = values[k]
            value = value.as_tensor()

            # Evaluate the non-normalized log density.
            result = _vv(value, info_vec - 0.5 * _mv(precision, value))

            result = Tensor(result, int_inputs)
            assert result.output == Real
            return Subs(result, remaining_subs) if remaining_subs else result

        # Perform a partial substitution of a subset of real variables, resulting in a Joint.
        # We split the real inputs into two sets: "a" for preserved variables and "b" for substituted ones.
        b = frozenset(k for k, v in subs.items())
        a = frozenset(k for k, d in self.inputs.items()
                      if d.dtype == 'real' and k not in b)
        prec_aa = ops.cat(-2, *[
            ops.cat(-1, *[precision[..., i1, i2] for k2, i2 in slices if k2 in a])
            for k1, i1 in slices if k1 in a
        ])
        prec_ab = ops.cat(-2, *[
            ops.cat(-1, *[precision[..., i1, i2] for k2, i2 in slices if k2 in b])
            for k1, i1 in slices if k1 in a
        ])
        prec_bb = ops.cat(-2, *[
            ops.cat(-1, *[precision[..., i1, i2] for k2, i2 in slices if k2 in b])
            for k1, i1 in slices if k1 in b
        ])
        info_a = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in a])
        info_b = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in b])
        value_b = ops.cat(-1, *[values[k] for k, i in slices if k in b])
        info_vec = info_a - _mv(prec_ab, value_b)
        log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
        precision = ops.expand(prec_aa, info_vec.shape + info_vec.shape[-1:])
        inputs = int_inputs.copy()
        for k, d in self.inputs.items():
            if k not in subs:
                inputs[k] = d
        result = Gaussian(info_vec, precision, inputs) + Tensor(
            log_scale, int_inputs)
        return Subs(result, remaining_subs) if remaining_subs else result
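
The partial-substitution branch relies on the identity x·info - 0.5 x·P·x = x_a·(info_a - P_ab v) - 0.5 x_a·P_aa x_a + v·(info_b - 0.5 P_bb v) once x_b is fixed at v: the first two terms define the new Gaussian over x_a and the last term becomes log_scale. A tiny NumPy check of that identity with a 2x2 precision split into 1+1 blocks (all numbers invented):

import numpy as np

P = np.array([[2.0, 0.5], [0.5, 1.0]])    # blocks: "a" is index 0, "b" is index 1
info = np.array([1.0, -1.0])
v = np.array([0.3])     # substituted value for x_b
x_a = np.array([0.7])   # an arbitrary value for the preserved variable

info_a = info[:1] - P[:1, 1:] @ v                   # new info_vec over x_a
log_scale = v @ (info[1:] - 0.5 * (P[1:, 1:] @ v))  # Tensor(log_scale, ...)

full = np.concatenate([x_a, v])
lhs = full @ info - 0.5 * full @ P @ full           # original log density
rhs = x_a @ info_a - 0.5 * x_a @ (P[:1, :1] @ x_a) + log_scale
assert np.allclose(lhs, rhs)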
Example #11
def dummy_numeric_array(domain):
    value = 0.1 if domain.dtype == 'real' else 1
    if not domain.shape:
        return value
    return ops.expand(numeric_array(value), domain.shape)