示例#1
0
def eager_cat_homogeneous(name, part_name, *parts):
    """Eagerly concatenate Gaussian-like funsors along one of their inputs.

    Every part must have an input called ``part_name``. The parts are aligned
    and expanded to the union of all parts' inputs, concatenated along the
    ``part_name`` axis, and that axis is exposed in the result under ``name``
    with domain ``Bint[total]``, where ``total`` is the size of dim 0 of the
    concatenated info vector.

    Args:
        name: Name of the concatenated input in the result.
        part_name: Name of the input, present in every part, along which the
            parts are concatenated.
        *parts: One or more ``Gaussian`` or ``GaussianMixture`` funsors that
            all have the same ``output``.

    Returns:
        A ``Gaussian`` over the merged inputs. If at least one part is a
        mixture, the concatenated discrete terms are added to it as a
        ``Tensor`` (parts without a discrete term contribute zeros).

    Raises:
        NotImplementedError: If a part is neither a ``Gaussian`` nor a
            ``GaussianMixture``.
    """
    assert parts
    output = parts[0].output
    # Seed the ordering with ``part_name`` so that it stays the first key and
    # hence the leading batch dimension of every aligned array below.
    # NOTE(review): this relies on the ``part_name`` input not being
    # real-valued -- confirm against callers.
    inputs = OrderedDict([(part_name, None)])
    for part in parts:
        assert part.output == output
        assert part_name in part.inputs
        inputs.update(part.inputs)

    # Reorder the merged inputs: non-real inputs first, then real inputs.
    int_inputs = OrderedDict(
        (k, v) for k, v in inputs.items() if v.dtype != "real")
    real_inputs = OrderedDict(
        (k, v) for k, v in inputs.items() if v.dtype == "real")
    inputs = int_inputs.copy()
    inputs.update(real_inputs)

    discretes = []
    info_vecs = []
    precisions = []
    for part in parts:
        # Each part may have its own size along ``part_name``; use that size
        # while aligning and expanding this part.
        inputs[part_name] = part.inputs[part_name]
        int_inputs[part_name] = inputs[part_name]
        shape = tuple(d.size for d in int_inputs.values())
        if isinstance(part, Gaussian):
            discrete = None
            gaussian = part
        elif issubclass(type(part), GaussianMixture
                        ):  # TODO figure out why isinstance isn't working
            discrete, gaussian = part.terms[0], part.terms[1]
            discrete = ops.expand(align_tensor(int_inputs, discrete), shape)
        else:
            raise NotImplementedError("TODO")
        discretes.append(discrete)
        info_vec, precision = align_gaussian(inputs, gaussian)
        info_vecs.append(ops.expand(info_vec, shape + (-1, )))
        precisions.append(ops.expand(precision, shape + (-1, -1)))

    dim = 0
    info_vec = ops.cat(dim, *info_vecs)
    precision = ops.cat(dim, *precisions)

    # Replace the per-part input by the concatenated one. The concatenation
    # ran along dim 0, so ``name`` has to be the FIRST input. Merely assigning
    # ``inputs[name]`` after deleting ``part_name`` would append it at the
    # end and mislabel the batch dimensions whenever ``name != part_name``
    # and other non-real inputs are present.
    del inputs[part_name]
    del int_inputs[part_name]
    cat_domain = Bint[info_vec.shape[dim]]
    inputs = OrderedDict(
        [(name, cat_domain)] +
        [(k, v) for k, v in inputs.items() if k != name])
    int_inputs = OrderedDict(
        [(name, cat_domain)] +
        [(k, v) for k, v in int_inputs.items() if k != name])

    result = Gaussian(info_vec, precision, inputs)
    if any(d is not None for d in discretes):
        # Pure Gaussians carry no discrete term; stand in zeros shaped like
        # the batch dimensions of their expanded info vector.
        discretes = [
            ops.new_zeros(vec, vec.shape[:-1]) if d is None else d
            for d, vec in zip(discretes, info_vecs)
        ]
        discrete = ops.cat(dim, *discretes)
        result = result + Tensor(discrete, int_inputs)
    return result
示例#2
0
    def as_tensor(self):
        """Join the stored blocks into a single array along the last axis.

        Intervals of the last dimension that have no stored block are first
        filled with zero blocks. Those zero blocks are written back into
        ``self.parts``, so calling this method mutates the instance.
        """
        # Any stored block can serve as the template passed to new_zeros.
        template = next(iter(self.parts.values()))
        leading_shape = self.shape[:-1]
        for interval in _find_intervals(self.parts.keys(), self.shape[-1]):
            if interval in self.parts:
                continue
            width = interval[1] - interval[0]
            self.parts[interval] = ops.new_zeros(
                template, leading_shape + (width, ))

        # Concatenate the blocks in ascending key order.
        ordered_blocks = [self.parts[key] for key in sorted(self.parts)]
        dense = ops.cat(-1, *ordered_blocks)
        if get_tracing_state():
            return dense
        assert dense.shape == self.shape
        return dense
示例#3
0
    def as_tensor(self):
        """Join the stored blocks into a single array over the last two axes.

        ``self.parts`` maps a row key to a mapping from column key to block.
        Any (row interval, column interval) pair without a stored block is
        filled with zeros first; the zero blocks are stored in ``self.parts``,
        so calling this method mutates the instance.
        """
        # Any stored block can serve as the template passed to new_zeros.
        some_row = next(iter(self.parts.values()))
        template = next(iter(some_row.values()))
        leading_shape = self.shape[:-2]

        col_keys = {col for blocks in self.parts.values() for col in blocks}
        row_intervals = _find_intervals(self.parts.keys(), self.shape[-2])
        col_intervals = _find_intervals(col_keys, self.shape[-1])
        for row in row_intervals:
            for col in col_intervals:
                blocks = self.parts[row]
                if col in blocks:
                    continue
                block_shape = leading_shape + (row[1] - row[0],
                                               col[1] - col[0])
                blocks[col] = ops.new_zeros(template, block_shape)

        # Concatenate parts: first each row of blocks along the last axis,
        # then the resulting strips along the second-to-last axis.
        # TODO This could be optimized into a single .reshape().cat().reshape() if
        #   all inputs are contiguous, thereby saving a memcopy.
        row_strips = []
        for row in sorted(self.parts):
            blocks = self.parts[row]
            row_strips.append(
                ops.cat(-1, *[blocks[col] for col in sorted(blocks)]))
        dense = ops.cat(-2, *row_strips)
        if get_tracing_state():
            return dense
        assert dense.shape == self.shape
        return dense
示例#4
0
def eager_cat_homogeneous(name, part_name, *parts):
    """Eagerly concatenate tensor funsors along one of their inputs.

    Every part must have an input called ``part_name`` and the same
    ``output``. Each part is aligned and expanded to the union of all parts'
    inputs, the results are concatenated along the leading axis, and that
    axis becomes the first input of the result under ``name``.

    Args:
        name: Name of the concatenated input in the result.
        part_name: Name of the input, present in every part, along which the
            parts are concatenated.
        *parts: One or more funsors to concatenate.

    Returns:
        A ``Tensor`` whose first input is ``name`` with domain
        ``Bint[total]``, followed by the remaining merged inputs, and whose
        dtype is that of the parts' common output.
    """
    assert parts
    common_output = parts[0].output

    # ``part_name`` is inserted first so that it is the leading dimension.
    merged = OrderedDict([(part_name, None)])
    for piece in parts:
        assert piece.output == common_output
        assert part_name in piece.inputs
        merged.update(piece.inputs)

    expanded = []
    for piece in parts:
        # Use this piece's own size along the concatenation axis.
        merged[part_name] = piece.inputs[part_name]
        batch_shape = tuple(domain.size for domain in merged.values())
        aligned = align_tensor(merged, piece)
        expanded.append(
            ops.expand(aligned, batch_shape + common_output.shape))
    del merged[part_name]

    dim = 0
    data = ops.cat(dim, *expanded)
    result_inputs = OrderedDict([(name, Bint[data.shape[dim]])])
    result_inputs.update(merged)
    return Tensor(data, result_inputs, dtype=common_output.dtype)
示例#5
0
    def eager_reduce(self, op, reduced_vars):
        """Eagerly reduce this funsor over ``reduced_vars`` where possible.

        Two reductions are handled:

        * ``ops.logaddexp``: the reduced real inputs are marginalized out
          here; any reduced non-real inputs are then handed to
          ``.reduce(ops.logaddexp, ...)`` on the intermediate result.
        * ``ops.add``: ``info_vec`` and ``precision`` are summed over the
          reduced (non-real) inputs.

        Args:
            op: The reduction operation.
            reduced_vars: Names of the inputs to reduce over. It is combined
                with a ``frozenset`` using ``&`` and ``-``, so a set-like
                value is expected.

        Returns:
            The reduced funsor, or ``None`` to defer to the default
            implementation (any other ``op``, or a ``logaddexp`` reduction
            that involves no real input).

        Raises:
            ValueError: If ``op`` is ``ops.add`` and one of the reduced
                variables is real-valued.
        """
        if op is ops.logaddexp:
            # Marginalize out real variables, but keep mixtures lazy.
            assert all(v in self.inputs for v in reduced_vars)
            real_vars = frozenset(k for k, d in self.inputs.items()
                                  if d.dtype == "real")
            reduced_reals = reduced_vars & real_vars
            reduced_ints = reduced_vars - real_vars
            if not reduced_reals:
                return None  # defer to default implementation

            # Inputs that survive the real-variable marginalization.
            inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                                 if k not in reduced_reals)
            if reduced_reals == real_vars:
                # Every real input is reduced, so no Gaussian part remains.
                result = self.log_normalizer
            else:
                int_inputs = OrderedDict(
                    (k, v) for k, v in inputs.items() if v.dtype != 'real')
                offsets, _ = _compute_offsets(self.inputs)
                # Build index vectors into the event dimension:
                # ``a`` for the kept real inputs, ``b`` for the reduced ones.
                a = []
                b = []
                for key, domain in self.inputs.items():
                    if domain.dtype == 'real':
                        block = ops.new_arange(
                            self.info_vec, offsets[key],
                            offsets[key] + domain.num_elements, 1)
                        (b if key in reduced_vars else a).append(block)
                a = ops.cat(-1, *a)
                b = ops.cat(-1, *b)
                # Extract the sub-blocks of the precision matrix.
                prec_aa = self.precision[..., a[..., None], a]
                prec_ba = self.precision[..., b[..., None], a]
                prec_bb = self.precision[..., b[..., None], b]
                # NOTE(review): assuming ops.triangular_solve(x, L) solves
                # L y = x, ``prec_at @ prec_a`` equals
                # prec_ab @ inv(prec_bb) @ prec_ba, which makes ``precision``
                # the Schur complement of prec_bb -- confirm against ops.
                prec_b = ops.cholesky(prec_bb)
                prec_a = ops.triangular_solve(prec_ba, prec_b)
                prec_at = ops.transpose(prec_a, -1, -2)
                precision = prec_aa - ops.matmul(prec_at, prec_a)

                info_a = self.info_vec[..., a]
                info_b = self.info_vec[..., b]
                # The same Cholesky factor is reused for the info vector.
                b_tmp = ops.triangular_solve(info_b[..., None], prec_b)
                info_vec = info_a - ops.matmul(prec_at, b_tmp)[..., 0]

                # Scalar (per batch element) term produced by integrating out
                # the ``b`` block; ``len(b)`` is the number of reduced real
                # elements.
                log_prob = Tensor(
                    0.5 * len(b) * math.log(2 * math.pi) -
                    _log_det_tri(prec_b) + 0.5 * (b_tmp[..., 0]**2).sum(-1),
                    int_inputs)
                result = log_prob + Gaussian(info_vec, precision, inputs)

            # Any reduced non-real inputs are handled by the result's own
            # reduce().
            return result.reduce(ops.logaddexp, reduced_ints)

        elif op is ops.add:
            # Summation is only supported over non-real inputs.
            for v in reduced_vars:
                if self.inputs[v].dtype == 'real':
                    raise ValueError(
                        "Cannot sum along a real dimension: {}".format(
                            repr(v)))

            # Fuse Gaussians along a plate. Compare to eager_add_gaussian_gaussian().
            old_ints = OrderedDict(
                (k, v) for k, v in self.inputs.items() if v.dtype != 'real')
            new_ints = OrderedDict(
                (k, v) for k, v in old_ints.items() if k not in reduced_vars)
            inputs = OrderedDict((k, v) for k, v in self.inputs.items()
                                 if k not in reduced_vars)

            # Wrap the raw arrays as Tensors over the non-real inputs so the
            # Tensor reduction can sum out the reduced batch dimensions.
            info_vec = Tensor(self.info_vec,
                              old_ints).reduce(ops.add, reduced_vars)
            precision = Tensor(self.precision,
                               old_ints).reduce(ops.add, reduced_vars)
            assert info_vec.inputs == new_ints
            assert precision.inputs == new_ints
            return Gaussian(info_vec.data, precision.data, inputs)

        return None  # defer to default implementation
示例#6
0
    def _eager_subs_real(self, subs, remaining_subs):
        """Substitute values for some or all of this funsor's real inputs.

        Args:
            subs: Substitutions for real inputs, as anything accepted by
                ``OrderedDict(...)`` that maps input names to values.
            remaining_subs: Further substitutions. When truthy they are
                applied lazily by wrapping the result in
                ``Subs(result, remaining_subs)``.

        Returns:
            If every real input is substituted, a ``Tensor`` holding the
            non-normalized log density evaluated at the given values.
            Otherwise a ``Gaussian`` over the remaining inputs plus a
            ``Tensor`` log-scale term. Either result is wrapped in ``Subs``
            when ``remaining_subs`` is truthy.
        """
        # Broadcast all component tensors.
        subs = OrderedDict(subs)
        int_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        tensors = [
            Tensor(self.info_vec, int_inputs),
            Tensor(self.precision, int_inputs)
        ]
        tensors.extend(subs.values())
        # From here on ``int_inputs`` and ``tensors`` are the aligned
        # versions returned by align_tensors.
        int_inputs, tensors = align_tensors(*tensors)
        # The first aligned tensor is the info vector; its last dimension is
        # the event dimension, all others are batch dimensions.
        batch_dim = len(tensors[0].shape) - 1
        batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
        (info_vec, precision), values = tensors[:2], tensors[2:]
        offsets, event_size = _compute_offsets(self.inputs)
        # One (name, slice) pair per entry of ``offsets``, addressing that
        # input's elements within the event dimension.
        slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
                  for k, offset in offsets.items()]

        # Expand all substituted values.
        values = OrderedDict(zip(subs, values))
        for k, value in values.items():
            # Flatten each value's event shape into a single trailing
            # dimension.
            value = value.reshape(value.shape[:batch_dim] + (-1, ))
            if not get_tracing_state():
                assert value.shape[-1] == self.inputs[k].num_elements
            values[k] = ops.expand(value, batch_shape + value.shape[-1:])

        # Try to perform a complete substitution of all real variables, resulting in a Tensor.
        if all(k in subs for k, d in self.inputs.items() if d.dtype == 'real'):
            # Form the concatenated value.
            value = BlockVector(batch_shape + (event_size, ))
            for k, i in slices:
                if k in values:
                    value[..., i] = values[k]
            value = value.as_tensor()

            # Evaluate the non-normalized log density.
            result = _vv(value, info_vec - 0.5 * _mv(precision, value))

            result = Tensor(result, int_inputs)
            assert result.output == Real
            return Subs(result, remaining_subs) if remaining_subs else result

        # Perform a partial substitution of a subset of real variables, resulting in a Joint.
        # We split real inputs into two sets: a for the preserved and b for the substituted.
        b = frozenset(k for k, v in subs.items())
        a = frozenset(k for k, d in self.inputs.items()
                      if d.dtype == 'real' and k not in b)
        # Assemble the (a, a), (a, b) and (b, b) blocks of the precision
        # matrix by concatenating per-input sub-blocks: along the last axis
        # within a row of blocks, then along the second-to-last axis.
        prec_aa = ops.cat(
            -2, *[
                ops.cat(
                    -1,
                    *[precision[..., i1, i2] for k2, i2 in slices if k2 in a])
                for k1, i1 in slices if k1 in a
            ])
        prec_ab = ops.cat(
            -2, *[
                ops.cat(
                    -1,
                    *[precision[..., i1, i2] for k2, i2 in slices if k2 in b])
                for k1, i1 in slices if k1 in a
            ])
        prec_bb = ops.cat(
            -2, *[
                ops.cat(
                    -1,
                    *[precision[..., i1, i2] for k2, i2 in slices if k2 in b])
                for k1, i1 in slices if k1 in b
            ])
        # Split the info vector the same way and gather the substituted
        # values in the order given by ``slices``.
        info_a = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in a])
        info_b = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in b])
        value_b = ops.cat(-1, *[values[k] for k, i in slices if k in b])
        # Condition on the substituted block: the info vector of the
        # preserved block is shifted, and the terms depending only on
        # value_b become a log-scale factor.
        info_vec = info_a - _mv(prec_ab, value_b)
        log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
        # Expand the precision to the (possibly larger) batch shape of the
        # new info vector.
        precision = ops.expand(prec_aa, info_vec.shape + info_vec.shape[-1:])
        # Result inputs: the aligned non-real inputs, followed by every
        # original input that was not substituted.
        inputs = int_inputs.copy()
        for k, d in self.inputs.items():
            if k not in subs:
                inputs[k] = d
        result = Gaussian(info_vec, precision, inputs) + Tensor(
            log_scale, int_inputs)
        return Subs(result, remaining_subs) if remaining_subs else result