Example #1
def eager_getitem_tensor_tensor(op, lhs, rhs):
    assert op.offset < len(lhs.output.shape)
    assert rhs.output == Bint[lhs.output.shape[op.offset]]

    # Compute inputs and outputs.
    if lhs.inputs == rhs.inputs:
        inputs, lhs_data, rhs_data = lhs.inputs, lhs.data, rhs.data
    else:
        inputs, (lhs_data, rhs_data) = align_tensors(lhs, rhs)
    if len(lhs.output.shape) > 1:
        rhs_data = rhs_data.reshape(rhs_data.shape + (1, ) *
                                    (len(lhs.output.shape) - 1))

    # Perform advanced indexing.
    lhs_data_dim = len(lhs_data.shape)
    target_dim = lhs_data_dim - len(lhs.output.shape) + op.offset
    index = [None] * lhs_data_dim
    for i in range(target_dim):
        index[i] = ops.new_arange(
            lhs_data,
            lhs_data.shape[i]).reshape((-1, ) + (1, ) * (lhs_data_dim - i - 2))
    index[target_dim] = rhs_data
    for i in range(1 + target_dim, lhs_data_dim):
        index[i] = ops.new_arange(
            lhs_data,
            lhs_data.shape[i]).reshape((-1, ) + (1, ) * (lhs_data_dim - i - 1))
    data = lhs_data[tuple(index)]
    return Tensor(data, inputs, lhs.dtype)
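
A minimal NumPy sketch of the advanced-indexing pattern above (shapes and names here are illustrative, not from the funsor source): every dimension other than the target gets a reshaped arange so that all index arrays broadcast together, and the target dimension is indexed by the gathered positions.

import numpy as np

# One batch dim of size B, output shape (S0, S1), and an index into the first
# output dim that shares the batch dim.
B, S0, S1 = 2, 3, 4
lhs_data = np.random.randn(B, S0, S1)
rhs_data = np.random.randint(S0, size=(B,))

rhs_ix = rhs_data.reshape(B, 1)            # padded for the trailing output dim
batch_ix = np.arange(B).reshape(-1, 1)     # dims left of the target
tail_ix = np.arange(S1)                    # dims right of the target
out = lhs_data[batch_ix, rhs_ix, tail_ix]  # shape (B, S1)

# The same gather, written as an explicit loop over the batch dim.
expected = np.stack([lhs_data[b, rhs_data[b]] for b in range(B)])
assert np.allclose(out, expected)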
Example #2
    def new_arange(self, name, *args, **kwargs):
        """
        Helper to create a named :func:`torch.arange` or :func:`numpy.arange` funsor.
        In some cases this can be replaced by a symbolic
        :class:`~funsor.terms.Slice` .

        :param str name: A variable name.
        :param int start:
        :param int stop:
        :param int step: Three args following :py:class:`slice` semantics.
        :param int dtype: An optional bounded integer type of this slice.
        :rtype: Tensor
        """
        start = 0
        step = 1
        dtype = None
        if len(args) == 1:
            stop = args[0]
            dtype = kwargs.pop("dtype", stop)
        elif len(args) == 2:
            start, stop = args
            dtype = kwargs.pop("dtype", stop)
        elif len(args) == 3:
            start, stop, step = args
            dtype = kwargs.pop("dtype", stop)
        elif len(args) == 4:
            start, stop, step, dtype = args
        else:
            raise ValueError("expected 1 to 4 positional args, got {}".format(len(args)))
        if step <= 0:
            raise ValueError("step must be a positive integer")
        stop = min(dtype, max(start, stop))
        data = ops.new_arange(self.data, start, stop, step)
        inputs = OrderedDict([(name, Bint[len(data)])])
        return Tensor(data, inputs, dtype=dtype)
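
A standalone sketch of the clamping arithmetic above, assuming a NumPy-style arange; arange_like_new_arange is a made-up name for illustration, not part of funsor.

import numpy as np

def arange_like_new_arange(start, stop, step, dtype):
    stop = min(dtype, max(start, stop))  # keep every value inside Bint[dtype]
    return np.arange(start, stop, step)

data = arange_like_new_arange(2, 100, 3, dtype=10)
assert data.tolist() == [2, 5, 8]  # 3 points, so the named input becomes Bint[3]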
Example #3
def test_sequential_sum_product_bias_2(num_steps, num_sensors, dim):
    time = Variable("time", bint(num_steps))
    bias = Variable("bias", reals(num_sensors, dim))
    bias_dist = random_gaussian(
        OrderedDict([
            ("bias", reals(num_sensors, dim)),
        ]))
    trans = random_gaussian(
        OrderedDict([
            ("time", bint(num_steps)),
            ("x_prev", reals(dim)),
            ("x_curr", reals(dim)),
        ]))
    obs = random_gaussian(
        OrderedDict([
            ("time", bint(num_steps)),
            ("x_curr", reals(dim)),
            ("bias", reals(dim)),
        ]))

    # Each time step only a single sensor observes x,
    # and each sensor has a different bias.
    sensor_id = Tensor(ops.new_arange(get_default_prototype(), num_steps) % 2,
                       OrderedDict(time=bint(num_steps)),
                       dtype=2)
    with interpretation(eager_or_die):
        factor = trans + obs(bias=bias[sensor_id]) + bias_dist
    assert set(factor.inputs) == {"time", "bias", "x_prev", "x_curr"}

    result = sequential_sum_product(ops.logaddexp, ops.add, factor, time,
                                    {"x_prev": "x_curr"})
    assert set(result.inputs) == {"bias", "x_prev", "x_curr"}
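
A plain-NumPy picture of what bias[sensor_id] computes above (illustrative shapes, not the funsor API): each time step selects one sensor's row of the bias matrix.

import numpy as np

num_steps, num_sensors, dim = 5, 2, 3
bias = np.random.randn(num_sensors, dim)
sensor_id = np.arange(num_steps) % 2  # sensors alternate over time
per_step_bias = bias[sensor_id]       # shape (num_steps, dim)
assert per_step_bias.shape == (num_steps, dim)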
Example #4
    def eager_reduce(self, op, reduced_vars):
        if op is ops.logaddexp:
            # Marginalize out real variables, but keep mixtures lazy.
            assert all(v in self.inputs for v in reduced_vars)
            real_vars = frozenset(k for k, d in self.inputs.items()
                                  if d.dtype == "real")
            reduced_reals = reduced_vars & real_vars
            reduced_ints = reduced_vars - real_vars
            if not reduced_reals:
                return None  # defer to default implementation

            inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                                 if k not in reduced_reals)
            if reduced_reals == real_vars:
                result = self.log_normalizer
            else:
                int_inputs = OrderedDict(
                    (k, v) for k, v in inputs.items() if v.dtype != 'real')
                offsets, _ = _compute_offsets(self.inputs)
                a = []
                b = []
                for key, domain in self.inputs.items():
                    if domain.dtype == 'real':
                        block = ops.new_arange(
                            self.info_vec, offsets[key],
                            offsets[key] + domain.num_elements, 1)
                        (b if key in reduced_vars else a).append(block)
                a = ops.cat(-1, *a)
                b = ops.cat(-1, *b)
                prec_aa = self.precision[..., a[..., None], a]
                prec_ba = self.precision[..., b[..., None], a]
                prec_bb = self.precision[..., b[..., None], b]
                prec_b = ops.cholesky(prec_bb)
                prec_a = ops.triangular_solve(prec_ba, prec_b)
                prec_at = ops.transpose(prec_a, -1, -2)
                precision = prec_aa - ops.matmul(prec_at, prec_a)

                info_a = self.info_vec[..., a]
                info_b = self.info_vec[..., b]
                b_tmp = ops.triangular_solve(info_b[..., None], prec_b)
                info_vec = info_a - ops.matmul(prec_at, b_tmp)[..., 0]

                log_prob = Tensor(
                    0.5 * len(b) * math.log(2 * math.pi) -
                    _log_det_tri(prec_b) + 0.5 * (b_tmp[..., 0]**2).sum(-1),
                    int_inputs)
                result = log_prob + Gaussian(info_vec, precision, inputs)

            return result.reduce(ops.logaddexp, reduced_ints)

        elif op is ops.add:
            for v in reduced_vars:
                if self.inputs[v].dtype == 'real':
                    raise ValueError(
                        "Cannot sum along a real dimension: {}".format(
                            repr(v)))

            # Fuse Gaussians along a plate. Compare to eager_add_gaussian_gaussian().
            old_ints = OrderedDict(
                (k, v) for k, v in self.inputs.items() if v.dtype != 'real')
            new_ints = OrderedDict(
                (k, v) for k, v in old_ints.items() if k not in reduced_vars)
            inputs = OrderedDict((k, v) for k, v in self.inputs.items()
                                 if k not in reduced_vars)

            info_vec = Tensor(self.info_vec,
                              old_ints).reduce(ops.add, reduced_vars)
            precision = Tensor(self.precision,
                               old_ints).reduce(ops.add, reduced_vars)
            assert info_vec.inputs == new_ints
            assert precision.inputs == new_ints
            return Gaussian(info_vec.data, precision.data, inputs)

        return None  # defer to default implementation
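
A dense, unbatched NumPy sketch of the Schur-complement step above (illustrative sizes, ignoring the scalar log-normalizer term; not the funsor implementation): eliminate the b block of an information-form Gaussian and check the result against the covariance-form marginal.

import numpy as np

na, nb = 2, 3
A = np.random.randn(na + nb, na + nb)
precision = A @ A.T + (na + nb) * np.eye(na + nb)  # well-conditioned SPD matrix
info_vec = np.random.randn(na + nb)
a, b = np.arange(na), np.arange(na, na + nb)

P_aa = precision[np.ix_(a, a)]
P_ab = precision[np.ix_(a, b)]
P_bb = precision[np.ix_(b, b)]
marg_precision = P_aa - P_ab @ np.linalg.solve(P_bb, P_ab.T)
marg_info = info_vec[a] - P_ab @ np.linalg.solve(P_bb, info_vec[b])

# Covariance-form check: the marginal precision inverts to the covariance
# block over `a`, and the marginal mean equals the joint mean restricted to `a`.
cov = np.linalg.inv(precision)
mean = cov @ info_vec
assert np.allclose(np.linalg.inv(marg_precision), cov[np.ix_(a, a)])
assert np.allclose(np.linalg.solve(marg_precision, marg_info), mean[a])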
Example #5
    def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
        assert self.output == Real
        sampled_vars = sampled_vars.intersection(self.inputs)
        if not sampled_vars:
            return self

        # Partition inputs into sample_inputs + batch_inputs + event_inputs.
        sample_inputs = OrderedDict(
            (k, d) for k, d in sample_inputs.items() if k not in self.inputs)
        sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
        batch_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if k not in sampled_vars)
        event_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if k in sampled_vars)
        be_inputs = batch_inputs.copy()
        be_inputs.update(event_inputs)
        sb_inputs = sample_inputs.copy()
        sb_inputs.update(batch_inputs)

        # Sample all variables in a single Categorical call.
        logits = align_tensor(be_inputs, self)
        batch_shape = logits.shape[:len(batch_inputs)]
        flat_logits = logits.reshape(batch_shape + (-1, ))

        backend = get_backend()
        if backend != "numpy":
            from importlib import import_module
            dist = import_module(
                funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[backend])
            sample_args = (sample_shape, ) if rng_key is None else (
                rng_key, sample_shape)
            flat_sample = dist.CategoricalLogits.dist_class(
                logits=flat_logits).sample(*sample_args)
        else:  # default numpy backend
            assert backend == "numpy"
            shape = sample_shape + flat_logits.shape[:-1]
            logit_max = np.amax(flat_logits, -1, keepdims=True)
            probs = np.exp(flat_logits - logit_max)
            probs = probs / np.sum(probs, -1, keepdims=True)
            s = np.cumsum(probs, -1)
            r = np.random.rand(*shape)
            flat_sample = np.sum(s < np.expand_dims(r, -1), axis=-1)

        assert flat_sample.shape == sample_shape + batch_shape
        results = []
        mod_sample = flat_sample
        for name, domain in reversed(list(event_inputs.items())):
            size = domain.dtype
            point = Tensor(mod_sample % size, sb_inputs, size)
            mod_sample = mod_sample // size
            results.append(Delta(name, point))

        # Account for the log normalizer factor.
        # Derivation: Let f be a nonnormalized distribution (a funsor), and
        #   consider operations in linear space (source code is in log space).
        #   Let x0 ~ f/|f| be a monte carlo sample from a normalized f/|f|.
        #                              f(x0) / |f|      # dice numerator
        #   Let g = delta(x=x0) |f| -----------------
        #                           detach(f(x0)/|f|)   # dice denominator
        #                       |detach(f)| f(x0)
        #         = delta(x=x0) -----------------  be a dice approximation of f.
        #                         detach(f(x0))
        #   Then g is an unbiased estimator of f in value and all derivatives.
        #   In the special case f = detach(f), we can simplify to
        #       g = delta(x=x0) |f|.
        if (backend == "torch"
                and flat_logits.requires_grad) or backend == "jax":
            # Apply a dice factor to preserve differentiability.
            index = [
                ops.new_arange(self.data,
                               n).reshape((n, ) + (1, ) *
                                          (len(flat_logits.shape) - i - 2))
                for i, n in enumerate(flat_logits.shape[:-1])
            ]
            index.append(flat_sample)
            log_prob = flat_logits[tuple(index)]
            assert log_prob.shape == flat_sample.shape
            results.append(
                Tensor(
                    ops.logsumexp(ops.detach(flat_logits), -1) +
                    (log_prob - ops.detach(log_prob)), sb_inputs))
        else:
            # This is the special case f = detach(f).
            results.append(Tensor(ops.logsumexp(flat_logits, -1),
                                  batch_inputs))

        return reduce(ops.add, results)
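
A self-contained version of the NumPy fallback above, written as a free function for illustration (sample_categorical_logits is a made-up name): inverse-CDF sampling of a categorical distribution from unnormalized logits.

import numpy as np

def sample_categorical_logits(flat_logits, sample_shape, rng=np.random):
    shape = sample_shape + flat_logits.shape[:-1]
    probs = np.exp(flat_logits - np.amax(flat_logits, -1, keepdims=True))
    probs = probs / probs.sum(-1, keepdims=True)
    cdf = np.cumsum(probs, -1)
    u = rng.rand(*shape)
    # Count how many cdf entries fall below u; that count is the sampled index.
    return (cdf < u[..., None]).sum(-1)

samples = sample_categorical_logits(np.random.randn(4, 3), sample_shape=(1000,))
assert samples.shape == (1000, 4)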
Example #6
    def eager_subs(self, subs):
        assert isinstance(subs, tuple)
        subs = OrderedDict((k, to_funsor(v, self.inputs[k])) for k, v in subs
                           if k in self.inputs)
        if not subs:
            return self

        # Handle diagonal variable substitution
        var_counts = Counter(v for v in subs.values()
                             if isinstance(v, Variable))
        subs = OrderedDict((k, self.materialize(v) if var_counts[v] > 1 else v)
                           for k, v in subs.items())

        # Handle renaming to enable cons hashing, and
        # handle slicing to avoid copying data.
        if any(isinstance(v, (Variable, Slice)) for v in subs.values()):
            slices = None
            inputs = OrderedDict()
            for i, (k, d) in enumerate(self.inputs.items()):
                if k in subs:
                    v = subs[k]
                    if isinstance(v, Variable):
                        del subs[k]
                        k = v.name
                    elif isinstance(v, Slice):
                        del subs[k]
                        k = v.name
                        d = v.inputs[v.name]
                        if slices is None:
                            slices = [slice(None)] * len(self.data.shape)
                        slices[i] = v.slice
                inputs[k] = d
            data = self.data[tuple(slices)] if slices else self.data
            result = Tensor(data, inputs, self.dtype)
            return result.eager_subs(tuple(subs.items()))

        # materialize after checking for renaming case
        subs = OrderedDict((k, self.materialize(v)) for k, v in subs.items())

        # Compute result shapes.
        inputs = OrderedDict()
        for k, domain in self.inputs.items():
            if k in subs:
                inputs.update(subs[k].inputs)
            else:
                inputs[k] = domain

        # Construct a dict with each input's positional dim,
        # counting from the right so as to support broadcasting.
        total_size = len(inputs) + len(
            self.output.shape)  # Assumes only scalar indices.
        new_dims = {}
        for k, domain in inputs.items():
            assert not domain.shape
            new_dims[k] = len(new_dims) - total_size

        # Use advanced indexing to construct a simultaneous substitution.
        index = []
        for k, domain in self.inputs.items():
            if k in subs:
                v = subs.get(k)
                if isinstance(v, Number):
                    index.append(int(v.data))
                else:
                    # Permute and expand v.data to end up at new_dims.
                    assert isinstance(v, Tensor)
                    v = v.align(tuple(k2 for k2 in inputs if k2 in v.inputs))
                    assert isinstance(v, Tensor)
                    v_shape = [1] * total_size
                    for k2, size in zip(v.inputs, v.data.shape):
                        v_shape[new_dims[k2]] = size
                    index.append(v.data.reshape(tuple(v_shape)))
            else:
                # Construct a [:] slice for this preserved input.
                offset_from_right = -1 - new_dims[k]
                index.append(
                    ops.new_arange(
                        self.data,
                        domain.dtype).reshape((-1, ) +
                                              (1, ) * offset_from_right))

        # Construct a [:] slice for the output.
        for i, size in enumerate(self.output.shape):
            offset_from_right = len(self.output.shape) - i - 1
            index.append(
                ops.new_arange(self.data,
                               size).reshape((-1, ) +
                                             (1, ) * offset_from_right))

        data = self.data[tuple(index)]
        return Tensor(data, inputs, self.dtype)
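
A NumPy picture of the simultaneous-substitution indexing above (illustrative names, not funsor code): a tensor over inputs i and j has j substituted by a tensor over a new input k, while i is preserved via a broadcast arange.

import numpy as np

data = np.random.randn(3, 4)              # axes: (i, j)
j_of_k = np.random.randint(4, size=(5,))  # the substituting tensor, indexed by k

i_ix = np.arange(3).reshape(-1, 1)        # preserved input -> a broadcast "[:]"
j_ix = j_of_k.reshape(1, 5)               # substituted input, placed at dim of k
result = data[i_ix, j_ix]                 # axes: (i, k)

assert result.shape == (3, 5)
assert np.allclose(result, data[:, j_of_k])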
Example #7
def _scatter(src, res, subs):
    # inverse of advanced indexing
    # TODO check types of subs, in case some logic from eager_subs was accidentally left out?

    # use advanced indexing logic copied from Tensor.eager_subs:

    # materialize after checking for renaming case
    subs = OrderedDict((k, res.materialize(v)) for k, v in subs)

    # Compute result shapes.
    inputs = OrderedDict()
    for k, domain in res.inputs.items():
        inputs[k] = domain

    # Construct a dict with each input's positional dim,
    # counting from the right so as to support broadcasting.
    total_size = len(inputs) + len(
        res.output.shape)  # Assumes only scalar indices.
    new_dims = {}
    for k, domain in inputs.items():
        assert not domain.shape
        new_dims[k] = len(new_dims) - total_size

    # Use advanced indexing to construct a simultaneous substitution.
    index = []
    for k, domain in res.inputs.items():
        if k in subs:
            v = subs.get(k)
            if isinstance(v, Number):
                index.append(int(v.data))
            else:
                # Permute and expand v.data to end up at new_dims.
                assert isinstance(v, Tensor)
                v = v.align(tuple(k2 for k2 in inputs if k2 in v.inputs))
                assert isinstance(v, Tensor)
                v_shape = [1] * total_size
                for k2, size in zip(v.inputs, v.data.shape):
                    v_shape[new_dims[k2]] = size
                index.append(v.data.reshape(tuple(v_shape)))
        else:
            # Construct a [:] slice for this preserved input.
            offset_from_right = -1 - new_dims[k]
            index.append(
                ops.new_arange(
                    res.data,
                    domain.dtype).reshape((-1, ) + (1, ) * offset_from_right))

    # Construct a [:] slice for the output.
    for i, size in enumerate(res.output.shape):
        offset_from_right = len(res.output.shape) - i - 1
        index.append(
            ops.new_arange(res.data,
                           size).reshape((-1, ) + (1, ) * offset_from_right))

    # the only difference from Tensor.eager_subs is here:
    # instead of indexing the rhs (lhs = rhs[index]), we index the lhs (lhs[index] = rhs)

    # unsqueeze to make broadcasting work
    src_inputs, src_data = src.inputs.copy(), src.data
    for k, v in res.inputs.items():
        if k not in src.inputs and isinstance(subs[k], Number):
            src_inputs[k] = bint(1)
            src_data = src_data.unsqueeze(-1 - len(src.output.shape))
    src = Tensor(src_data, src_inputs,
                 src.output.dtype).align(tuple(res.inputs.keys()))

    data = res.data
    data[tuple(index)] = src.data
    return Tensor(data, inputs, res.dtype)
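
The scatter counterpart in plain NumPy (illustrative shapes): instead of reading src = res[index] as Tensor.eager_subs does, the same kind of broadcast index arrays are used to write res[index] = src, mirroring data[tuple(index)] = src.data above.

import numpy as np

res = np.zeros((3, 4))        # axes: (i, j)
j_of_i = np.array([2, 0, 3])  # target j slot for each i
src = np.random.randn(3)

i_ix = np.arange(3)
res[i_ix, j_of_i] = src       # in-place scatter along j
assert np.allclose(res[np.arange(3), j_of_i], src)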