Example 1
    def eager_subs(self, subs):
        assert isinstance(subs, tuple)
        subs = tuple(
            (k, v if isinstance(v, (Variable, Slice)) else materialize(v))
            for k, v in subs if k in self.inputs)
        if not subs:
            return self

        # Constants, Variables, Slices, and affine funsors are eagerly
        # substituted; everything else is lazily substituted.
        lazy_subs = tuple(
            (k, v) for k, v in subs
            if not isinstance(v, (Number, Tensor, Variable, Slice))
            and not (is_affine(v) and affine_inputs(v)))
        var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
        int_subs = tuple((k, v) for k, v in subs
                         if isinstance(v, (Number, Tensor, Slice))
                         if v.dtype != 'real')
        real_subs = tuple((k, v) for k, v in subs
                          if isinstance(v, (Number, Tensor))
                          if v.dtype == 'real')
        affine_subs = tuple((k, v) for k, v in subs
                            if is_affine(v) and affine_inputs(v)
                            and not isinstance(v, Variable))
        if var_subs:
            return self._eager_subs_var(
                var_subs, int_subs + real_subs + affine_subs + lazy_subs)
        if int_subs:
            return self._eager_subs_int(int_subs,
                                        real_subs + affine_subs + lazy_subs)
        if real_subs:
            return self._eager_subs_real(real_subs, affine_subs + lazy_subs)
        if affine_subs:
            return self._eager_subs_affine(affine_subs, lazy_subs)
        return reflect(Subs, self, lazy_subs)
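
This method partitions the incoming substitutions by the type of each value and dispatches them in a fixed priority order: variables first, then integers, then reals, then affine expressions, deferring anything else as a lazy Subs. A minimal sketch of that partitioning pattern, using hypothetical stand-in classes (Var, IntConst, RealConst) rather than funsor's own types:

from collections import namedtuple

# Hypothetical stand-ins for funsor value types.
Var = namedtuple('Var', ['name'])
IntConst = namedtuple('IntConst', ['value'])
RealConst = namedtuple('RealConst', ['value'])

def partition_subs(subs):
    # Split (name, value) pairs into the buckets eager_subs dispatches on.
    var_subs = tuple((k, v) for k, v in subs if isinstance(v, Var))
    int_subs = tuple((k, v) for k, v in subs if isinstance(v, IntConst))
    real_subs = tuple((k, v) for k, v in subs if isinstance(v, RealConst))
    handled = (Var, IntConst, RealConst)
    lazy_subs = tuple((k, v) for k, v in subs if not isinstance(v, handled))
    return var_subs, int_subs, real_subs, lazy_subs

subs = (('i', IntConst(2)), ('x', RealConst(0.5)), ('y', Var('z')))
var_subs, int_subs, real_subs, lazy_subs = partition_subs(subs)
assert var_subs == (('y', Var('z')),) and lazy_subs == ()
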
Example 2
def test_gaussian_mixture_distribution(batch_inputs, event_inputs):
    num_samples = 100000
    sample_inputs = OrderedDict(particle=bint(num_samples))
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    int_inputs = OrderedDict(
        (k, d) for k, d in be_inputs.items() if d.dtype != 'real')
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    sampled_vars = frozenset(['e'])
    p = random_gaussian(be_inputs) + 0.5 * random_tensor(int_inputs)
    p_marginal = p.reduce(ops.logaddexp, 'e')
    assert isinstance(p_marginal, Tensor)

    q = p.sample(sampled_vars, sample_inputs)
    q_marginal = q.reduce(ops.logaddexp, 'e')
    q_marginal = materialize(q_marginal).reduce(ops.logaddexp, 'particle')
    assert isinstance(q_marginal, Tensor)
    q_marginal = q_marginal.align(tuple(p_marginal.inputs))
    assert_close(q_marginal, p_marginal, atol=0.1, rtol=None)
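
The test asserts that sampling 'e' and then averaging over the particle dimension reproduces the exact marginal within atol=0.1. The same Monte Carlo argument in plain torch, a hedged sketch (log_joint, the proposal, and the tolerance are illustrative, not funsor code): estimate a log normalizer by importance sampling and compare it to the closed form.

import math

import torch

torch.manual_seed(0)
num_samples = 100000
loc, scale = 0.3, 1.2

def log_joint(x):
    # Unnormalized Gaussian log density, standing in for the funsor p.
    return -0.5 * ((x - loc) / scale) ** 2

proposal = torch.distributions.Normal(loc, 2.0)
x = proposal.sample((num_samples,))
log_weights = log_joint(x) - proposal.log_prob(x)

# Monte Carlo estimate of log Z, analogous to reducing 'particle' above.
estimate = log_weights.logsumexp(0) - math.log(num_samples)
exact = math.log(scale) + 0.5 * math.log(2 * math.pi)
assert abs(estimate.item() - exact) < 0.1
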
Example 3
def test_tensor_distribution(event_inputs, batch_inputs, test_grad):
    num_samples = 50000
    sample_inputs = OrderedDict(n=bint(num_samples))
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    sampled_vars = frozenset(event_inputs)
    p = random_tensor(be_inputs)
    p.data.requires_grad_(test_grad)

    q = p.sample(sampled_vars, sample_inputs)
    mq = materialize(q).reduce(ops.logaddexp, 'n')
    mq = mq.align(tuple(p.inputs))
    assert_close(mq, p, atol=0.1, rtol=None)

    if test_grad:
        _, (p_data, mq_data) = align_tensors(p, mq)
        assert p_data.shape == mq_data.shape
        probe = torch.randn(p_data.shape)
        expected = grad((p_data.exp() * probe).sum(), [p.data])[0]
        actual = grad((mq_data.exp() * probe).sum(), [p.data])[0]
        assert_close(actual, expected, atol=0.1, rtol=None)
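
For the gradient check to pass, the sample-based estimate must carry the same gradients as the exact expectation. One standard way to arrange this for discrete samples is a "dice" factor exp(logp - logp.detach()), which evaluates to one but contributes d(logp) to the backward pass; a plain-torch sketch of that property (an analogue of the test above, not funsor's sampler):

import torch
from torch.autograd import grad

torch.manual_seed(0)
num_samples = 50000
logits = torch.randn(4, requires_grad=True)
logp = logits - logits.logsumexp(0)   # normalized log-probabilities
probe = torch.randn(4)

# Exact expectation sum_i p_i * probe_i and its gradient w.r.t. logits.
exact = (logp.exp() * probe).sum()
expected = grad(exact, [logits], retain_graph=True)[0]

# Sample-based estimate: the dice factor equals 1 but carries d(logp).
idx = torch.distributions.Categorical(logits=logits.detach()).sample((num_samples,))
dice = (logp[idx] - logp[idx].detach()).exp()
estimate = (dice * probe[idx]).mean()
actual = grad(estimate, [logits])[0]
assert torch.allclose(actual, expected, atol=0.1)
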
Example 4
def eager_categorical(probs, value):
    value = materialize(value)
    return Categorical.eager_log_prob(probs=probs, value=value)
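
Once value is materialized to concrete integers, a categorical log probability reduces to a lookup into log(probs). A small plain-torch sketch of that equivalence (illustrative, not Categorical.eager_log_prob itself):

import torch

probs = torch.tensor([0.1, 0.2, 0.7])
value = torch.tensor([2, 0, 1])

# Eager log-prob is just a lookup into log(probs) at the materialized values.
log_prob = probs.log()[value]
expected = torch.distributions.Categorical(probs).log_prob(value)
assert torch.allclose(log_prob, expected)
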
Example 5
def _scatter(src, res, subs):
    # inverse of advanced indexing
    # TODO check types of subs, in case some logic from eager_subs was accidentally left out?

    # use advanced indexing logic copied from Tensor.eager_subs:

    # materialize after checking for renaming case
    subs = OrderedDict((k, materialize(v)) for k, v in subs)

    # Compute result shapes.
    inputs = OrderedDict()
    for k, domain in res.inputs.items():
        inputs[k] = domain

    # Construct a dict with each input's positional dim,
    # counting from the right so as to support broadcasting.
    total_size = len(inputs) + len(res.output.shape)  # Assumes only scalar indices.
    new_dims = {}
    for k, domain in inputs.items():
        assert not domain.shape
        new_dims[k] = len(new_dims) - total_size

    # Use advanced indexing to construct a simultaneous substitution.
    index = []
    for k, domain in res.inputs.items():
        if k in subs:
            v = subs[k]
            if isinstance(v, Number):
                index.append(int(v.data))
            else:
                # Permute and expand v.data to end up at new_dims.
                assert isinstance(v, Tensor)
                v = v.align(tuple(k2 for k2 in inputs if k2 in v.inputs))
                assert isinstance(v, Tensor)
                v_shape = [1] * total_size
                for k2, size in zip(v.inputs, v.data.shape):
                    v_shape[new_dims[k2]] = size
                index.append(v.data.reshape(tuple(v_shape)))
        else:
            # Construct a [:] slice for this preserved input.
            offset_from_right = -1 - new_dims[k]
            index.append(torch.arange(domain.dtype).reshape(
                (-1,) + (1,) * offset_from_right))

    # Construct a [:] slice for the output.
    for i, size in enumerate(res.output.shape):
        offset_from_right = len(res.output.shape) - i - 1
        index.append(torch.arange(size).reshape(
            (-1,) + (1,) * offset_from_right))

    # the only difference from Tensor.eager_subs is here:
    # instead of indexing the rhs (lhs = rhs[index]), we index the lhs (lhs[index] = rhs)

    # unsqueeze to make broadcasting work
    src_inputs, src_data = src.inputs.copy(), src.data
    for k, v in res.inputs.items():
        if k not in src.inputs and isinstance(subs[k], Number):
            src_inputs[k] = bint(1)
            src_data = src_data.unsqueeze(-1 - len(src.output.shape))
    src = Tensor(src_data, src_inputs, src.output.dtype).align(tuple(res.inputs.keys()))

    # NB: advanced-index assignment writes into res.data in place.
    data = res.data
    data[tuple(index)] = src.data
    return Tensor(data, inputs, res.dtype)
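
The trick here is that _scatter builds the same advanced index that Tensor.eager_subs would use for reading, then assigns through it instead. The core identity in plain torch, a sketch under the same scalar-indices-plus-[:]-slices assumption:

import torch

res = torch.zeros(3, 4)
rows = torch.tensor([0, 2]).reshape(-1, 1)   # substituted input, at its new dim
cols = torch.arange(4).reshape(1, -1)        # a [:] slice for the preserved dim
index = (rows, cols)

src = torch.randn(2, 4)
res[index] = src                    # write through the index (lhs[index] = rhs)
assert torch.equal(res[index], src)          # reading back inverts the write
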
Example 6
    def eager_subs(self, subs):
        assert isinstance(subs, tuple)
        subs = tuple((k, materialize(to_funsor(v, self.inputs[k])))
                     for k, v in subs if k in self.inputs)
        if not subs:
            return self

        # Constants and Variables are eagerly substituted;
        # everything else is lazily substituted.
        lazy_subs = tuple((k, v) for k, v in subs
                          if not isinstance(v, (Number, Tensor, Variable)))
        var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
        int_subs = tuple((k, v) for k, v in subs
                         if isinstance(v, (Number, Tensor))
                         if v.dtype != 'real')
        real_subs = tuple((k, v) for k, v in subs
                          if isinstance(v, (Number, Tensor))
                          if v.dtype == 'real')
        if not (var_subs or int_subs or real_subs):
            return reflect(Subs, self, lazy_subs)

        # First perform any variable substitutions.
        if var_subs:
            rename = {k: v.name for k, v in var_subs}
            inputs = OrderedDict(
                (rename.get(k, k), d) for k, d in self.inputs.items())
            if len(inputs) != len(self.inputs):
                raise ValueError("Variable substitution name conflict")
            var_result = Gaussian(self.loc, self.precision, inputs)
            return Subs(var_result, int_subs + real_subs + lazy_subs)

        # Next perform any integer substitution, i.e. slicing into a batch.
        if int_subs:
            int_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
            real_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype == 'real')
            tensors = [self.loc, self.precision]
            funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors]
            inputs = funsors[0].inputs.copy()
            inputs.update(real_inputs)
            int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
            return Subs(int_result, real_subs + lazy_subs)

        # Try to perform a complete substitution of all real variables, resulting in a Tensor.
        real_subs = OrderedDict(real_subs)
        assert real_subs and not int_subs
        if all(k in real_subs for k, d in self.inputs.items()
               if d.dtype == 'real'):
            # Broadcast all component tensors.
            int_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
            tensors = [
                Tensor(self.loc, int_inputs),
                Tensor(self.precision, int_inputs)
            ]
            tensors.extend(real_subs.values())
            inputs, tensors = align_tensors(*tensors)
            batch_dim = tensors[0].dim() - 1
            batch_shape = broadcast_shape(*(x.shape[:batch_dim]
                                            for x in tensors))
            (loc, precision), values = tensors[:2], tensors[2:]

            # Form the concatenated value.
            offsets, event_size = _compute_offsets(self.inputs)
            value = BlockVector(batch_shape + (event_size, ))
            for k, value_k in zip(real_subs, values):
                offset = offsets[k]
                value_k = value_k.reshape(value_k.shape[:batch_dim] + (-1, ))
                if not torch._C._get_tracing_state():
                    assert value_k.size(-1) == self.inputs[k].num_elements
                value_k = value_k.expand(batch_shape + value_k.shape[-1:])
                value[...,
                      offset:offset + self.inputs[k].num_elements] = value_k
            value = value.as_tensor()

            # Evaluate the non-normalized log density.
            result = -0.5 * _vmv(precision, value - loc)
            result = Tensor(result, inputs)
            assert result.output == reals()
            return Subs(result, lazy_subs)

        # Perform a partial substitution of a subset of real variables, resulting in a Joint.
        # See "The Matrix Cookbook" (November 15, 2012) ss. 8.1.3 eq. 353.
        # http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
        raise NotImplementedError(
            'TODO implement partial substitution of real variables')
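
The complete-substitution branch evaluates the unnormalized log density -0.5 (x - loc)^T P (x - loc). A plain-torch sketch of that quadratic form (what _vmv computes, with the normalizer omitted), checked against MultivariateNormal by subtracting the density at the mode:

import torch

loc = torch.tensor([0.0, 1.0])
precision = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
x = torch.tensor([0.3, -0.2])

delta = x - loc
log_density = -0.5 * delta @ precision @ delta

# The same quantity from MultivariateNormal, after removing the normalizer
# (the density at the mode, where the quadratic term vanishes).
mvn = torch.distributions.MultivariateNormal(loc, precision_matrix=precision)
assert torch.allclose(log_density, mvn.log_prob(x) - mvn.log_prob(loc), atol=1e-5)
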
Example 7
    def eager_subs(self, subs):
        assert isinstance(subs, tuple)
        subs = tuple(
            (k, v if isinstance(v, (Variable, Slice)) else materialize(v))
            for k, v in subs if k in self.inputs)
        if not subs:
            return self

        # Constants, Variables, and Slices are eagerly substituted;
        # everything else is lazily substituted.
        lazy_subs = tuple(
            (k, v) for k, v in subs
            if not isinstance(v, (Number, Tensor, Variable, Slice)))
        var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
        int_subs = tuple((k, v) for k, v in subs
                         if isinstance(v, (Number, Tensor, Slice))
                         if v.dtype != 'real')
        real_subs = tuple((k, v) for k, v in subs
                          if isinstance(v, (Number, Tensor))
                          if v.dtype == 'real')
        if not (var_subs or int_subs or real_subs):
            return reflect(Subs, self, lazy_subs)

        # First perform any variable substitutions.
        if var_subs:
            rename = {k: v.name for k, v in var_subs}
            inputs = OrderedDict(
                (rename.get(k, k), d) for k, d in self.inputs.items())
            if len(inputs) != len(self.inputs):
                raise ValueError("Variable substitution name conflict")
            var_result = Gaussian(self.info_vec, self.precision, inputs)
            return Subs(var_result, int_subs + real_subs + lazy_subs)

        # Next perform any integer substitution, i.e. slicing into a batch.
        if int_subs:
            int_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
            real_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype == 'real')
            tensors = [self.info_vec, self.precision]
            funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors]
            inputs = funsors[0].inputs.copy()
            inputs.update(real_inputs)
            int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
            return Subs(int_result, real_subs + lazy_subs)

        # Broadcast all component tensors.
        real_subs = OrderedDict(real_subs)
        assert real_subs and not int_subs
        int_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        tensors = [
            Tensor(self.info_vec, int_inputs),
            Tensor(self.precision, int_inputs)
        ]
        tensors.extend(real_subs.values())
        int_inputs, tensors = align_tensors(*tensors)
        batch_dim = tensors[0].dim() - 1
        batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
        (info_vec, precision), values = tensors[:2], tensors[2:]
        offsets, event_size = _compute_offsets(self.inputs)
        slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
                  for k, offset in offsets.items()]

        # Expand all substituted values.
        values = OrderedDict(zip(real_subs, values))
        for k, value in values.items():
            value = value.reshape(value.shape[:batch_dim] + (-1, ))
            if not torch._C._get_tracing_state():
                assert value.size(-1) == self.inputs[k].num_elements
            values[k] = value.expand(batch_shape + value.shape[-1:])

        # Try to perform a complete substitution of all real variables, resulting in a Tensor.
        if all(k in real_subs for k, d in self.inputs.items()
               if d.dtype == 'real'):
            # Form the concatenated value.
            value = BlockVector(batch_shape + (event_size, ))
            for k, i in slices:
                if k in values:
                    value[..., i] = values[k]
            value = value.as_tensor()

            # Evaluate the non-normalized log density.
            result = _vv(value, info_vec - 0.5 * _mv(precision, value))

            result = Tensor(result, int_inputs)
            assert result.output == reals()
            return Subs(result, lazy_subs)

        # Perform a partial substitution of a subset of real variables, resulting in a Joint.
        # We split real inputs into two sets: a for the preserved and b for the substituted.
        b = frozenset(k for k, v in real_subs.items())
        a = frozenset(k for k, d in self.inputs.items()
                      if d.dtype == 'real' and k not in b)
        prec_aa = torch.cat([
            torch.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in a], dim=-1)
            for k1, i1 in slices if k1 in a], dim=-2)
        prec_ab = torch.cat([
            torch.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b], dim=-1)
            for k1, i1 in slices if k1 in a], dim=-2)
        prec_bb = torch.cat([
            torch.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b], dim=-1)
            for k1, i1 in slices if k1 in b], dim=-2)
        info_a = torch.cat([info_vec[..., i] for k, i in slices if k in a],
                           dim=-1)
        info_b = torch.cat([info_vec[..., i] for k, i in slices if k in b],
                           dim=-1)
        value_b = torch.cat([values[k] for k, i in slices if k in b], dim=-1)
        info_vec = info_a - _mv(prec_ab, value_b)
        log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
        precision = prec_aa.expand(info_vec.shape + (-1, ))
        inputs = int_inputs.copy()
        for k, d in self.inputs.items():
            if k not in real_subs:
                inputs[k] = d
        return Gaussian(info_vec, precision, inputs) + Tensor(
            log_scale, int_inputs)
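
This last branch is Gaussian conditioning in information form (the Matrix Cookbook eq. 353 reference above): the kept block a gets information vector info_a - P_ab x_b and precision P_aa, while the substituted block b contributes the constant x_b^T (info_b - 0.5 P_bb x_b). A plain-torch numeric check of that identity, assuming the same x^T info - 0.5 x^T P x convention used in the complete-substitution branch:

import torch

torch.manual_seed(0)
A = torch.randn(4, 4)
P = A @ A.T + 4 * torch.eye(4)   # symmetric positive definite precision
mean = torch.randn(4)
info = P @ mean                  # information vector
a, b = slice(0, 2), slice(2, 4)  # kept block a, substituted block b
x_b = torch.randn(2)

info_a = info[a] - P[a, b] @ x_b                   # conditioned info vector
log_scale = x_b @ (info[b] - 0.5 * P[b, b] @ x_b)  # constant in x_a

def log_density(x, info, P):
    return x @ info - 0.5 * x @ P @ x

x_a = torch.randn(2)
joint = log_density(torch.cat([x_a, x_b]), info, P)
conditioned = log_density(x_a, info_a, P[a, a]) + log_scale
assert torch.allclose(joint, conditioned, atol=1e-4)
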