示例#1
0
    def eager_subs(self, subs):
        """
        Substitute ``subs`` into this funsor, eagerly wherever a handler exists.

        Pairs whose name is not one of ``self.inputs`` are discarded. Values
        that are neither ``Variable`` nor ``Slice`` are passed through
        ``materialize``. The surviving pairs are then grouped by the kind of
        value, and the first non-empty group (variables, integer-typed
        constants, real-typed constants, affine expressions -- in that order)
        is handed to the matching ``_eager_subs_*`` method together with all
        later groups. When only lazily-substituted pairs are left, a reflected
        ``Subs`` is returned.

        :param tuple subs: ``(name, value)`` pairs.
        :return: ``self`` if no pair applies, otherwise the handler's result.
        """
        assert isinstance(subs, tuple)

        # Keep only substitutions that target one of our inputs.
        relevant = []
        for name, value in subs:
            if name not in self.inputs:
                continue
            if not isinstance(value, (Variable, Slice)):
                value = materialize(value)
            relevant.append((name, value))
        subs = tuple(relevant)
        if not subs:
            return self

        def has_affine_inputs(value):
            # Truthy only for affine funsors that still depend on some input.
            return is_affine(value) and affine_inputs(value)

        def pick(predicate):
            # Order-preserving filter over the (name, value) pairs.
            return tuple(pair for pair in subs if predicate(pair[1]))

        # Constants and affine funsors are substituted eagerly;
        # anything else is deferred to a lazy Subs.
        var_subs = pick(lambda v: isinstance(v, Variable))
        int_subs = pick(
            lambda v: isinstance(v, (Number, Tensor, Slice))
            and v.dtype != 'real')
        real_subs = pick(
            lambda v: isinstance(v, (Number, Tensor)) and v.dtype == 'real')
        affine_subs = pick(
            lambda v: has_affine_inputs(v) and not isinstance(v, Variable))
        lazy_subs = pick(
            lambda v: not isinstance(v, (Number, Tensor, Variable, Slice))
            and not has_affine_inputs(v))

        # What each stage hands on to the stages after it.
        after_affine = lazy_subs
        after_real = affine_subs + after_affine
        after_int = real_subs + after_real
        after_var = int_subs + after_int

        if var_subs:
            return self._eager_subs_var(var_subs, after_var)
        if int_subs:
            return self._eager_subs_int(int_subs, after_int)
        if real_subs:
            return self._eager_subs_real(real_subs, after_real)
        if affine_subs:
            return self._eager_subs_affine(affine_subs, after_affine)
        return reflect(Subs, self, lazy_subs)
示例#2
0
def normalize_with_subs(cls, *args):
    """
    Interpretation that tries ``normalize`` first and ``lazy`` second.

    The handler that ``normalize`` dispatches for ``(cls, *args)`` is called
    first. If it returns ``None`` the handler dispatched by ``lazy`` is tried,
    and if that also returns ``None`` the term is reflected unchanged.

    Per its original description this behaves like ``normalize`` except that
    it also evaluates ``Subs``. Some tests need distribution expressions in
    normal form without triggering the eager patterns that rewrite
    distributions (e.g. Normal to Gaussian), because those tests are meant to
    exercise ``funsor.distribution.Distribution`` itself.
    """
    for interpretation in (normalize, lazy):
        outcome = interpretation.dispatch(cls, *args)(*args)
        if outcome is not None:
            return outcome
    return reflect(cls, *args)
示例#3
0
    def _eager_subs_affine(self, subs, remaining_subs):
        """
        Substitute affine expressions for real inputs of this Gaussian.

        Each value in ``subs`` is decomposed by ``extract_affine`` into a
        constant and a dict of per-variable coefficients. Substitutions whose
        constant and coefficients are all ``Tensor`` instances are folded into
        a new ``Gaussian`` plus a ``Tensor`` term; the others are appended to
        ``remaining_subs`` and left for ``Subs``.

        :param tuple subs: ``(name, value)`` pairs to substitute affinely.
            Presumably every name is a real input of ``self`` -- the name is
            deleted from the real inputs below; verify against the caller.
        :param tuple remaining_subs: ``(name, value)`` pairs to apply
            afterwards.
        :return: a reflected ``Subs`` when no pair could be handled here;
            otherwise the new funsor, wrapped in ``Subs`` if
            ``remaining_subs`` is non-empty.
        """
        # Extract an affine representation.
        affine = OrderedDict()
        for k, v in subs:
            const, coeffs = extract_affine(v)
            if (isinstance(const, Tensor) and all(
                    isinstance(coeff, Tensor)
                    for coeff, _ in coeffs.values())):
                affine[k] = const, coeffs
            else:
                # Not purely tensor-valued: defer this pair to the trailing Subs.
                remaining_subs += (k, v),
        if not affine:
            return reflect(Subs, self, remaining_subs)

        # Align integer dimensions.
        old_int_inputs = OrderedDict(
            (k, v) for k, v in self.inputs.items() if v.dtype != 'real')
        # Order matters: info_vec, precision, then for every substitution its
        # constant followed by its coefficients.
        tensors = [
            Tensor(self.info_vec, old_int_inputs),
            Tensor(self.precision, old_int_inputs)
        ]
        for const, coeffs in affine.values():
            tensors.append(const)
            tensors.extend(coeff for coeff, _ in coeffs.values())
        new_int_inputs, tensors = align_tensors(*tensors, expand=True)
        # Lazy generator: the next() calls below must consume it in exactly
        # the order in which the list above was built.
        tensors = (Tensor(x, new_int_inputs) for x in tensors)
        old_info_vec = next(tensors).data
        old_precision = next(tensors).data
        # Replace every constant/coefficient by its aligned counterpart.
        # Only existing keys are reassigned, so the dicts keep their size
        # while being iterated.
        for old_k, (const, coeffs) in affine.items():
            const = next(tensors)
            for new_k, (coeff, eqn) in coeffs.items():
                coeff = next(tensors)
                coeffs[new_k] = coeff, eqn
            affine[old_k] = const, coeffs
        # Everything but the trailing event dimension of info_vec.
        batch_shape = old_info_vec.shape[:-1]

        # Align real dimensions.
        old_real_inputs = OrderedDict(
            (k, v) for k, v in self.inputs.items() if v.dtype == 'real')
        new_real_inputs = old_real_inputs.copy()
        for old_k, (const, coeffs) in affine.items():
            # The substituted input disappears ...
            del new_real_inputs[old_k]
            # ... and every variable appearing in its expression becomes an
            # input.
            for new_k, (coeff, eqn) in coeffs.items():
                # The number of leading dims of coeff that belong to new_k is
                # the length of the second comma-separated term left of '->'
                # in eqn.
                # NOTE(review): presumably eqn is an einsum-style string
                # produced by extract_affine -- confirm.
                new_shape = coeff.shape[:len(eqn.split('->')[0].split(',')[1])]
                new_real_inputs[new_k] = Reals[new_shape]
        old_offsets, old_dim = _compute_offsets(old_real_inputs)
        new_offsets, new_dim = _compute_offsets(new_real_inputs)
        new_inputs = new_int_inputs.copy()
        new_inputs.update(new_real_inputs)

        # Construct a blockwise affine representation of the substitution.
        subs_vector = BlockVector(batch_shape + (old_dim, ))
        subs_matrix = BlockMatrix(batch_shape + (new_dim, old_dim))
        for old_k, old_offset in old_offsets.items():
            old_size = old_real_inputs[old_k].num_elements
            old_slice = slice(old_offset, old_offset + old_size)
            if old_k in new_real_inputs:
                # Input is still present afterwards: carry it over unchanged.
                # NOTE(review): an input replaced by an expression that
                # mentions its own name would also take this branch -- confirm
                # that cannot happen.
                new_offset = new_offsets[old_k]
                new_slice = slice(new_offset, new_offset + old_size)
                # NOTE(review): presumably new_eye builds a batched identity
                # matching self.info_vec's dtype/device -- confirm in ops.
                subs_matrix[..., new_slice, old_slice] = \
                    ops.new_eye(self.info_vec, batch_shape + (old_size,))
                continue
            # Substituted input: the constant fills the vector block and each
            # coefficient fills one matrix block.
            const, coeffs = affine[old_k]
            old_shape = old_real_inputs[old_k].shape
            assert const.data.shape == batch_shape + old_shape
            subs_vector[..., old_slice] = const.data.reshape(batch_shape +
                                                             (old_size, ))
            for new_k, new_offset in new_offsets.items():
                if new_k in coeffs:
                    coeff, eqn = coeffs[new_k]
                    new_size = new_real_inputs[new_k].num_elements
                    new_slice = slice(new_offset, new_offset + new_size)
                    assert coeff.shape == new_real_inputs[
                        new_k].shape + old_shape
                    subs_matrix[..., new_slice, old_slice] = \
                        coeff.data.reshape(batch_shape + (new_size, old_size))
        subs_vector = subs_vector.as_tensor()
        subs_matrix = subs_matrix.as_tensor()
        subs_matrix_t = ops.transpose(subs_matrix, -1, -2)

        # Construct the new funsor. Suppose the old Gaussian funsor g has density
        #   g(x) = < x | i - 1/2 P x>
        # Now define a new funsor f by substituting x = A y + B:
        #   f(y) = g(A y + B)
        #        = < A y + B | i - 1/2 P (A y + B) >
        #        = < y | At (i - P B) - 1/2 At P A y > + < B | i - 1/2 P B >
        #        =: < y | i' - 1/2 P' y > + C
        # where  P' = At P A  and  i' = At (i - P B)  parametrize a new Gaussian
        # and  C = < B | i - 1/2 P B >  parametrize a new Tensor.
        # In the code below subs_matrix (new_dim x old_dim) takes the place of
        # At, subs_matrix_t of A and subs_vector of B.
        precision = subs_matrix @ old_precision @ subs_matrix_t
        info_vec = _mv(subs_matrix,
                       old_info_vec - _mv(old_precision, subs_vector))
        const = _vv(subs_vector,
                    old_info_vec - 0.5 * _mv(old_precision, subs_vector))
        result = Gaussian(info_vec, precision, new_inputs) + Tensor(
            const, new_int_inputs)
        return Subs(result, remaining_subs) if remaining_subs else result
示例#4
0
    def eager_subs(self, subs):
        """
        Substitute ``subs`` into this Gaussian, eagerly where possible.

        Pairs whose name is not one of ``self.inputs`` are ignored; the other
        values are converted with ``to_funsor`` and ``materialize``. The
        substitutions are then handled one kind at a time, each stage
        delegating the rest to ``Subs``:

        1. ``Variable`` values rename inputs.
        2. Integer-typed ``Number``/``Tensor`` values index into the batch.
        3. Real-typed ``Number``/``Tensor`` values, when they cover every real
           input, evaluate the non-normalized log density to a ``Tensor``.

        Any other value is applied lazily through ``Subs``.

        :param tuple subs: ``(name, value)`` pairs.
        :return: ``self`` if no pair applies, otherwise the substituted funsor.
        :raises ValueError: if renaming makes two inputs share a name.
        :raises NotImplementedError: if only some of the real inputs are
            given real-valued constants.
        """
        assert isinstance(subs, tuple)
        subs = tuple((k, materialize(to_funsor(v, self.inputs[k])))
                     for k, v in subs if k in self.inputs)
        if not subs:
            return self

        # Constants and Variables are eagerly substituted;
        # everything else is lazily substituted.
        lazy_subs = tuple((k, v) for k, v in subs
                          if not isinstance(v, (Number, Tensor, Variable)))
        var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
        int_subs = tuple((k, v) for k, v in subs
                         if isinstance(v, (Number, Tensor))
                         if v.dtype != 'real')
        real_subs = tuple((k, v) for k, v in subs
                          if isinstance(v, (Number, Tensor))
                          if v.dtype == 'real')
        if not (var_subs or int_subs or real_subs):
            return reflect(Subs, self, lazy_subs)

        # First perform any variable substitutions.
        if var_subs:
            rename = {k: v.name for k, v in var_subs}
            inputs = OrderedDict(
                (rename.get(k, k), d) for k, d in self.inputs.items())
            # Two inputs collapsing onto one name shrinks the dict.
            if len(inputs) != len(self.inputs):
                raise ValueError("Variable substitution name conflict")
            var_result = Gaussian(self.loc, self.precision, inputs)
            return Subs(var_result, int_subs + real_subs + lazy_subs)

        # Next perform any integer substitution, i.e. slicing into a batch.
        if int_subs:
            int_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
            real_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype == 'real')
            tensors = [self.loc, self.precision]
            funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors]
            inputs = funsors[0].inputs.copy()
            inputs.update(real_inputs)
            int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
            return Subs(int_result, real_subs + lazy_subs)

        # Try to perform a complete substitution of all real variables, resulting in a Tensor.
        # Only the real-valued constants take part here. Lazy values are not
        # tensors, so they must neither be aligned below nor count as
        # "substituted"; they are applied afterwards through Subs.
        real_subs = OrderedDict(real_subs)
        assert real_subs and not int_subs
        if all(k in real_subs for k, d in self.inputs.items()
               if d.dtype == 'real'):
            # Broadcast all component tensors.
            int_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
            tensors = [
                Tensor(self.loc, int_inputs),
                Tensor(self.precision, int_inputs)
            ]
            tensors.extend(real_subs.values())
            inputs, tensors = align_tensors(*tensors)
            batch_dim = tensors[0].dim() - 1
            batch_shape = broadcast_shape(*(x.shape[:batch_dim]
                                            for x in tensors))
            (loc, precision), values = tensors[:2], tensors[2:]

            # Form the concatenated value.
            offsets, event_size = _compute_offsets(self.inputs)
            value = BlockVector(batch_shape + (event_size, ))
            for k, value_k in zip(real_subs, values):
                offset = offsets[k]
                # Flatten the event dimensions of this value.
                value_k = value_k.reshape(value_k.shape[:batch_dim] + (-1, ))
                # Skip the size check while tracing.
                if not torch._C._get_tracing_state():
                    assert value_k.size(-1) == self.inputs[k].num_elements
                value_k = value_k.expand(batch_shape + value_k.shape[-1:])
                value[...,
                      offset:offset + self.inputs[k].num_elements] = value_k
            value = value.as_tensor()

            # Evaluate the non-normalized log density.
            result = -0.5 * _vmv(precision, value - loc)
            result = Tensor(result, inputs)
            assert result.output == reals()
            return Subs(result, lazy_subs)

        # Perform a partial substitution of a subset of real variables, resulting in a Joint.
        # See "The Matrix Cookbook" (November 15, 2012) ss. 8.1.3 eq. 353.
        # http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
        raise NotImplementedError(
            'TODO implement partial substitution of real variables')
示例#5
0
    def eager_subs(self, subs):
        """
        Substitute ``subs`` into this Gaussian, eagerly where possible.

        Pairs whose name is not one of ``self.inputs`` are ignored; values
        that are neither ``Variable`` nor ``Slice`` are passed through
        ``materialize``. The substitutions are then handled one kind at a
        time, each stage delegating the rest to ``Subs``:

        1. ``Variable`` values rename inputs.
        2. Integer-typed ``Number``/``Tensor``/``Slice`` values index into the
           batch.
        3. Real-typed ``Number``/``Tensor`` values are plugged into the
           density. Covering every real input gives a ``Tensor``; covering
           only some gives a smaller ``Gaussian`` plus a ``Tensor``.

        Any other value is applied lazily through ``Subs``.

        :param tuple subs: ``(name, value)`` pairs.
        :return: ``self`` if no pair applies, otherwise the substituted funsor.
        :raises ValueError: if renaming makes two inputs share a name.
        """
        assert isinstance(subs, tuple)
        subs = tuple(
            (k, v if isinstance(v, (Variable, Slice)) else materialize(v))
            for k, v in subs if k in self.inputs)
        if not subs:
            return self

        # Constants and Variables are eagerly substituted;
        # everything else is lazily substituted.
        lazy_subs = tuple(
            (k, v) for k, v in subs
            if not isinstance(v, (Number, Tensor, Variable, Slice)))
        var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
        int_subs = tuple((k, v) for k, v in subs
                         if isinstance(v, (Number, Tensor, Slice))
                         if v.dtype != 'real')
        real_subs = tuple((k, v) for k, v in subs
                          if isinstance(v, (Number, Tensor))
                          if v.dtype == 'real')
        if not (var_subs or int_subs or real_subs):
            return reflect(Subs, self, lazy_subs)

        # First perform any variable substitutions.
        if var_subs:
            rename = {k: v.name for k, v in var_subs}
            inputs = OrderedDict(
                (rename.get(k, k), d) for k, d in self.inputs.items())
            # Two inputs collapsing onto one name shrinks the dict.
            if len(inputs) != len(self.inputs):
                raise ValueError("Variable substitution name conflict")
            var_result = Gaussian(self.info_vec, self.precision, inputs)
            return Subs(var_result, int_subs + real_subs + lazy_subs)

        # Next perform any integer substitution, i.e. slicing into a batch.
        if int_subs:
            int_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
            real_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype == 'real')
            tensors = [self.info_vec, self.precision]
            funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors]
            inputs = funsors[0].inputs.copy()
            inputs.update(real_inputs)
            int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
            return Subs(int_result, real_subs + lazy_subs)

        # Broadcast all component tensors.
        # Only the real-valued constants take part from here on. Lazy values
        # are not tensors, so they must neither be aligned below nor count as
        # "substituted"; they are applied afterwards through Subs.
        real_subs = OrderedDict(real_subs)
        assert real_subs and not int_subs
        int_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        tensors = [
            Tensor(self.info_vec, int_inputs),
            Tensor(self.precision, int_inputs)
        ]
        tensors.extend(real_subs.values())
        int_inputs, tensors = align_tensors(*tensors)
        batch_dim = tensors[0].dim() - 1
        batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
        (info_vec, precision), values = tensors[:2], tensors[2:]
        offsets, event_size = _compute_offsets(self.inputs)
        # Position of every input's block inside the concatenated event dim.
        slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
                  for k, offset in offsets.items()]

        # Expand all substituted values.
        values = OrderedDict(zip(real_subs, values))
        for k, value in values.items():
            # Flatten the event dimensions of this value.
            value = value.reshape(value.shape[:batch_dim] + (-1, ))
            # Skip the size check while tracing.
            if not torch._C._get_tracing_state():
                assert value.size(-1) == self.inputs[k].num_elements
            values[k] = value.expand(batch_shape + value.shape[-1:])

        # Try to perform a complete substitution of all real variables, resulting in a Tensor.
        if all(k in real_subs for k, d in self.inputs.items()
               if d.dtype == 'real'):
            # Form the concatenated value.
            value = BlockVector(batch_shape + (event_size, ))
            for k, i in slices:
                if k in values:
                    value[..., i] = values[k]
            value = value.as_tensor()

            # Evaluate the non-normalized log density.
            result = _vv(value, info_vec - 0.5 * _mv(precision, value))

            result = Tensor(result, int_inputs)
            assert result.output == reals()
            return Subs(result, lazy_subs)

        # Perform a partial substitution of a subset of real variables, resulting in a Joint.
        # We split real inputs into two sets: a for the preserved and b for the substituted.
        b = frozenset(real_subs)
        a = frozenset(k for k, d in self.inputs.items()
                      if d.dtype == 'real' and k not in b)

        def precision_block(rows, cols):
            # Sub-matrix of the precision restricted to the given input names.
            return torch.cat([
                torch.cat(
                    [precision[..., i1, i2] for k2, i2 in slices if k2 in cols],
                    dim=-1) for k1, i1 in slices if k1 in rows
            ], dim=-2)

        prec_aa = precision_block(a, a)
        prec_ab = precision_block(a, b)
        prec_bb = precision_block(b, b)
        info_a = torch.cat([info_vec[..., i] for k, i in slices if k in a],
                           dim=-1)
        info_b = torch.cat([info_vec[..., i] for k, i in slices if k in b],
                           dim=-1)
        value_b = torch.cat([values[k] for k, i in slices if k in b], dim=-1)
        # Conditioning in information form: the preserved block keeps its
        # precision, its info vector is shifted by the cross term, and the
        # substituted block collapses to a scalar log-scale.
        new_info_vec = info_a - _mv(prec_ab, value_b)
        log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
        new_precision = prec_aa.expand(new_info_vec.shape + (-1, ))
        inputs = int_inputs.copy()
        for k, d in self.inputs.items():
            if k not in real_subs:
                inputs[k] = d
        result = Gaussian(new_info_vec, new_precision, inputs) + Tensor(
            log_scale, int_inputs)
        # Lazily-substituted inputs are still inputs of result; apply them now
        # instead of dropping them.
        return Subs(result, lazy_subs) if lazy_subs else result