def eager_delta(v, log_density, value):
    # This handles event_dim specially, and hence cannot use the
    # generic Delta.eager_log_prob() method.
    assert v.output == value.output
    event_dim = len(v.output.shape)
    inputs, (v, log_density, value) = align_tensors(v, log_density, value)
    data = dist.Delta(v, log_density, event_dim).log_prob(value)
    return Tensor(data, inputs)
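
A minimal sketch of what align_tensors() provides to eager_delta() above, assuming the funsor 0.1-era torch API these snippets use (funsor.torch.Tensor, funsor.domains.bint): the named inputs of all arguments are merged into one OrderedDict, and the raw torch data come back permuted to that shared order, with size-1 dims for any input a tensor lacks (pass expand=True to expand them).

import torch
from collections import OrderedDict
from funsor.domains import bint
from funsor.torch import Tensor, align_tensors

x = Tensor(torch.randn(3), OrderedDict(i=bint(3)))
y = Tensor(torch.randn(4), OrderedDict(j=bint(4)))
inputs, (x_data, y_data) = align_tensors(x, y)
print(list(inputs))                # ['i', 'j']
print(x_data.shape, y_data.shape)  # torch.Size([3, 1]) torch.Size([1, 4])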
Example #2
def eager_multinomial(total_count, probs, value):
    # Multinomial.log_prob() supports inhomogeneous total_count only by
    # avoiding passing total_count to the constructor.
    inputs, (total_count, probs, value) = align_tensors(total_count, probs, value)
    shape = broadcast_shape(total_count.shape + (1,), probs.shape, value.shape)
    probs = Tensor(probs.expand(shape), inputs)
    value = Tensor(value.expand(shape), inputs)
    total_count = Number(total_count.max().item())  # Used by distributions validation code.
    return Multinomial.eager_log_prob(total_count=total_count, probs=probs, value=value)
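
A hedged sketch of the workaround documented above: torch.distributions.Multinomial takes one scalar total_count at construction, but its log_prob() reads each row's count off value itself, so passing total_count.max() merely satisfies the constructor and its validation while rows keep inhomogeneous counts.

import torch

probs = torch.tensor([0.2, 0.3, 0.5])
value = torch.tensor([[1., 0., 1.],    # this row has total count 2
                      [2., 1., 1.]])   # this row has total count 4
d = torch.distributions.Multinomial(4, probs)  # 4 = max of the row totals
print(d.log_prob(value))  # one log-probability per row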
Example #3
def eager_mvn(loc, scale_tril, value):
    if isinstance(loc, Variable):
        loc, value = value, loc

    dim, = loc.output.shape
    inputs, (loc, scale_tril) = align_tensors(loc, scale_tril)
    inputs.update(value.inputs)
    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')

    log_prob = -0.5 * dim * math.log(2 * math.pi) - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1)
    inv_scale_tril = torch.inverse(scale_tril)
    precision = torch.matmul(inv_scale_tril.transpose(-1, -2), inv_scale_tril)
    return Tensor(log_prob, int_inputs) + Gaussian(loc, precision, inputs)
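
A quick numeric check of the precision construction above: with covariance = L @ L.T for a lower-triangular Cholesky factor L, the code's inv(L).T @ inv(L) equals inverse(covariance), and the normalizer's sum of log-diagonal terms equals half the covariance log-determinant.

import torch

L = torch.randn(3, 3).tril()
L.diagonal().abs_().add_(1.)   # force a valid Cholesky factor
cov = L @ L.t()
inv_L = torch.inverse(L)
precision = inv_L.t() @ inv_L
print(torch.allclose(precision, torch.inverse(cov), atol=1e-5))  # True
print(torch.allclose(L.diagonal().log().sum(),
                     0.5 * torch.logdet(cov), atol=1e-5))        # True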
Example #4
def eager_normal(loc, scale, value):
    if isinstance(loc, Variable):
        loc, value = value, loc

    inputs, (loc, scale) = align_tensors(loc, scale)
    loc, scale = torch.broadcast_tensors(loc, scale)
    inputs.update(value.inputs)
    int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')

    log_prob = -0.5 * math.log(2 * math.pi) - scale.log()
    loc = loc.unsqueeze(-1)
    precision = scale.pow(-2).unsqueeze(-1).unsqueeze(-1)
    return Tensor(log_prob, int_inputs) + Gaussian(loc, precision, inputs)
Example #5
def eager_normal(loc, scale, value):
    if isinstance(loc, Variable):
        loc, value = value, loc

    inputs, (loc, scale) = align_tensors(loc, scale, expand=True)
    inputs.update(value.inputs)
    int_inputs = OrderedDict(
        (k, v) for k, v in inputs.items() if v.dtype != 'real')

    precision = scale.pow(-2)
    info_vec = (precision * loc).unsqueeze(-1)
    precision = precision.unsqueeze(-1).unsqueeze(-1)
    log_prob = (-0.5 * math.log(2 * math.pi) - scale.log()
                - 0.5 * loc * info_vec.squeeze(-1))
    return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, inputs)
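
A numeric spot-check relating the two eager_normal variants above: with p = scale**-2 and info_vec = p * loc, the normal log density splits into the loc-dependent constant folded into log_prob plus the information-form terms <x|info_vec> - 0.5 <x|p x> carried by Gaussian.

import math
import torch

loc, scale, x = torch.tensor(0.7), torch.tensor(1.3), torch.tensor(-0.2)
p = scale.pow(-2)
info_vec = p * loc
lhs = torch.distributions.Normal(loc, scale).log_prob(x)
rhs = (-0.5 * math.log(2 * math.pi) - scale.log()
       - 0.5 * loc * info_vec              # constant, lives in log_prob
       + x * info_vec - 0.5 * x ** 2 * p)  # Gaussian(info_vec, precision)
print(torch.allclose(lhs, rhs))  # True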
Example #6
def test_batched_einsum(equation, batch1, batch2):
    inputs, output = equation.split('->')
    inputs = inputs.split(',')

    sizes = dict(a=2, b=3, c=4, i=5, j=6)
    batch1 = OrderedDict([(k, bint(sizes[k])) for k in batch1])
    batch2 = OrderedDict([(k, bint(sizes[k])) for k in batch2])
    funsors = [
        random_tensor(batch, reals(*(sizes[d] for d in dims)))
        for batch, dims in zip([batch1, batch2], inputs)
    ]
    actual = Einsum(equation, tuple(funsors))

    _equation = ','.join('...' + i for i in inputs) + '->...' + output
    inputs, tensors = align_tensors(*funsors)
    batch = tuple(v.size for v in inputs.values())
    tensors = [x.expand(batch + f.shape) for (x, f) in zip(tensors, funsors)]
    expected = Tensor(torch.einsum(_equation, tensors), inputs)
    assert_close(actual, expected, atol=1e-5, rtol=None)
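
A small sketch of the equation rewrite in the test above: each operand's named batch dims are folded behind an ellipsis, so torch.einsum broadcasts over whatever batch shape align_tensors produced.

equation = 'ab,bc->ac'
inputs, output = equation.split('->')
_equation = ','.join('...' + i for i in inputs.split(',')) + '->...' + output
print(_equation)  # ...ab,...bc->...ac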
Example #7
def test_binary_funsor_funsor(symbol, dims1, dims2):
    sizes = {'a': 3, 'b': 4, 'c': 5}
    shape1 = tuple(sizes[d] for d in dims1)
    shape2 = tuple(sizes[d] for d in dims2)
    inputs1 = OrderedDict((d, bint(sizes[d])) for d in dims1)
    inputs2 = OrderedDict((d, bint(sizes[d])) for d in dims2)
    data1 = torch.rand(shape1) + 0.5
    data2 = torch.rand(shape2) + 0.5
    dtype = 'real'
    if symbol in BOOLEAN_OPS:
        dtype = 2
        data1 = data1.byte()
        data2 = data2.byte()
    x1 = Tensor(data1, inputs1, dtype)
    x2 = Tensor(data2, inputs2, dtype)
    inputs, aligned = align_tensors(x1, x2)
    expected_data = binary_eval(symbol, aligned[0], aligned[1])

    actual = binary_eval(symbol, x1, x2)
    check_funsor(actual, inputs, Domain((), dtype), expected_data)
Example #8
def test_tensor_distribution(event_inputs, batch_inputs, test_grad):
    num_samples = 50000
    sample_inputs = OrderedDict(n=bint(num_samples))
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    sampled_vars = frozenset(event_inputs)
    p = random_tensor(be_inputs)
    p.data.requires_grad_(test_grad)

    q = p.sample(sampled_vars, sample_inputs)
    mq = materialize(q).reduce(ops.logaddexp, 'n')
    mq = mq.align(tuple(p.inputs))
    assert_close(mq, p, atol=0.1, rtol=None)

    if test_grad:
        _, (p_data, mq_data) = align_tensors(p, mq)
        assert p_data.shape == mq_data.shape
        probe = torch.randn(p_data.shape)
        expected = grad((p_data.exp() * probe).sum(), [p.data])[0]
        actual = grad((mq_data.exp() * probe).sum(), [p.data])[0]
        assert_close(actual, expected, atol=0.1, rtol=None)
Example #9
def eager_normal(loc, scale, value):
    affine = (loc - value) / scale
    assert isinstance(affine, Affine)
    real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items() if v.dtype == 'real')
    assert not any(v.shape for v in real_inputs.values())

    tensors = [affine.const] + [c for v, c in affine.coeffs.items()]
    inputs, tensors = align_tensors(*tensors)
    tensors = torch.broadcast_tensors(*tensors)
    const, coeffs = tensors[0], tensors[1:]

    dim = sum(d.num_elements for d in real_inputs.values())
    loc = BlockVector(const.shape + (dim,))
    loc[..., 0] = -const / coeffs[0]
    precision = BlockMatrix(const.shape + (dim, dim))
    for i, (v1, c1) in enumerate(zip(real_inputs, coeffs)):
        for j, (v2, c2) in enumerate(zip(real_inputs, coeffs)):
            precision[..., i, j] = c1 * c2
    loc = loc.as_tensor()
    precision = precision.as_tensor()

    log_prob = -0.5 * math.log(2 * math.pi) - scale.log()
    return log_prob + Gaussian(loc, precision, affine.inputs)
Example #10
def eager_normal(loc, scale, value):
    affine = (loc - value) / scale
    if not affine.is_affine:
        return None

    real_inputs = OrderedDict(
        (k, v) for k, v in affine.inputs.items() if v.dtype == 'real')
    int_inputs = OrderedDict(
        (k, v) for k, v in affine.inputs.items() if v.dtype != 'real')
    assert not any(v.shape for v in real_inputs.values())

    const = affine(**{k: 0. for k, v in real_inputs.items()})
    coeffs = OrderedDict()
    for c in real_inputs.keys():
        coeffs[c] = affine(
            **{k: 1. if c == k else 0.
               for k in real_inputs.keys()}) - const

    tensors = [const] + list(coeffs.values())
    inputs, tensors = align_tensors(*tensors, expand=True)
    const, coeffs = tensors[0], tensors[1:]

    dim = sum(d.num_elements for d in real_inputs.values())
    loc = BlockVector(const.shape + (dim, ))
    loc[..., 0] = -const / coeffs[0]
    precision = BlockMatrix(const.shape + (dim, dim))
    for i, (v1, c1) in enumerate(zip(real_inputs, coeffs)):
        for j, (v2, c2) in enumerate(zip(real_inputs, coeffs)):
            precision[..., i, j] = c1 * c2
    loc = loc.as_tensor()
    precision = precision.as_tensor()
    info_vec = precision.matmul(loc.unsqueeze(-1)).squeeze(-1)

    log_prob = -0.5 * math.log(
        2 * math.pi) - scale.data.log() - 0.5 * (loc * info_vec).sum(-1)
    return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision,
                                                   affine.inputs)
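
A minimal sketch of the probing trick used above to extract the affine representation, with a plain Python function standing in for the funsor: substituting zeros recovers the constant, and substituting a single 1 recovers that input's coefficient.

def f(x, y):  # stand-in for a funsor that is affine in real inputs x, y
    return 2.0 * x - 3.0 * y + 5.0

const = f(x=0., y=0.)
coeff_x = f(x=1., y=0.) - const
coeff_y = f(x=0., y=1.) - const
print(const, coeff_x, coeff_y)  # 5.0 2.0 -3.0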
Example #11
    def eager_log_prob(cls, **params):
        inputs, tensors = align_tensors(*params.values())
        params = dict(zip(params, tensors))
        value = params.pop('value')
        data = cls.dist_class(**params).log_prob(value)
        return Tensor(data, inputs)
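
A hedged sketch of the generic align-then-score pattern above, inlined with torch.distributions.Normal standing in for cls.dist_class and assuming the same funsor 0.1-era torch API: once aligned, a single log_prob call covers the union of all named inputs.

import torch
from collections import OrderedDict
from funsor.domains import bint
from funsor.torch import Tensor, align_tensors

loc = Tensor(torch.randn(3), OrderedDict(i=bint(3)))
value = Tensor(torch.randn(4), OrderedDict(j=bint(4)))
inputs, (loc_data, value_data) = align_tensors(loc, value)
data = torch.distributions.Normal(loc_data, 1.).log_prob(value_data)
print(Tensor(data, inputs).inputs)  # an OrderedDict over both i and j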
Example #12
def eager_mvn(loc, scale_tril, value):
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not is_affine(loc) or not is_affine(value):
        return None  # lazy

    # Extract an affine representation.
    eye = torch.eye(scale_tril.data.size(-1)).expand(scale_tril.data.shape)
    prec_sqrt = Tensor(
        eye.triangular_solve(scale_tril.data, upper=False).solution,
        scale_tril.inputs)
    affine = prec_sqrt @ (loc - value)
    const, coeffs = extract_affine(affine)
    if not isinstance(const, Tensor):
        return None  # lazy
    if not all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values()):
        return None  # lazy

    # Compute log_prob using funsors.
    scale_diag = Tensor(scale_tril.data.diagonal(dim1=-1, dim2=-2),
                        scale_tril.inputs)
    log_prob = (-0.5 * scale_diag.shape[0] * math.log(2 * math.pi) -
                scale_diag.log().sum() - 0.5 * (const**2).sum())

    # Dovetail to avoid variable name collision in einsum.
    equations1 = [
        ''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a')) for c in eqn)
        for _, eqn in coeffs.values()
    ]
    equations2 = [
        ''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a') + 1)
                for c in eqn) for _, eqn in coeffs.values()
    ]

    real_inputs = OrderedDict(
        (k, v) for k, v in affine.inputs.items() if v.dtype == 'real')
    assert tuple(real_inputs) == tuple(coeffs)

    # Align and broadcast tensors.
    neg_const = -const
    tensors = [neg_const] + [coeff for coeff, _ in coeffs.values()]
    inputs, tensors = align_tensors(*tensors, expand=True)
    neg_const, coeffs = tensors[0], tensors[1:]
    dim = sum(d.num_elements for d in real_inputs.values())
    batch_shape = neg_const.shape[:-1]

    info_vec = BlockVector(batch_shape + (dim, ))
    precision = BlockMatrix(batch_shape + (dim, dim))
    offset1 = 0
    for i1, (v1, c1) in enumerate(zip(real_inputs, coeffs)):
        size1 = real_inputs[v1].num_elements
        slice1 = slice(offset1, offset1 + size1)
        inputs1, output1 = equations1[i1].split('->')
        input11, input12 = inputs1.split(',')
        assert input11 == input12 + output1
        info_vec[..., slice1] = torch.einsum(
            f'...{input11},...{output1}->...{input12}', c1, neg_const) \
            .reshape(batch_shape + (size1,))
        offset2 = 0
        for i2, (v2, c2) in enumerate(zip(real_inputs, coeffs)):
            size2 = real_inputs[v2].num_elements
            slice2 = slice(offset2, offset2 + size2)
            inputs2, output2 = equations2[i2].split('->')
            input21, input22 = inputs2.split(',')
            assert input21 == input22 + output2
            precision[..., slice1, slice2] = torch.einsum(
                f'...{input11},...{input22}{output1}->...{input12}{input22}', c1, c2) \
                .reshape(batch_shape + (size1, size2))
            offset2 += size2
        offset1 += size1

    info_vec = info_vec.as_tensor()
    precision = precision.as_tensor()
    inputs.update(real_inputs)
    return log_prob + Gaussian(info_vec, precision, inputs)
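
A short sketch of the "dovetail" renaming above: the first mapping sends indices a,b,c,... to a,c,e,... and the second to b,d,f,..., so coefficient equations can be spliced into one torch.einsum call without index-name collisions.

eqn = 'ab,a->b'
dovetail1 = ''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a')) for c in eqn)
dovetail2 = ''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a') + 1) for c in eqn)
print(dovetail1)  # ac,a->c  (a -> a, b -> c)
print(dovetail2)  # bd,b->d  (a -> b, b -> d)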
Example #13
    def _eager_subs_affine(self, subs, remaining_subs):
        # Extract an affine representation.
        affine = OrderedDict()
        for k, v in subs:
            const, coeffs = extract_affine(v)
            if (isinstance(const, Tensor) and all(
                    isinstance(coeff, Tensor)
                    for coeff, _ in coeffs.values())):
                affine[k] = const, coeffs
            else:
                remaining_subs += (k, v),
        if not affine:
            return reflect(Subs, self, remaining_subs)

        # Align integer dimensions.
        old_int_inputs = OrderedDict(
            (k, v) for k, v in self.inputs.items() if v.dtype != 'real')
        tensors = [
            Tensor(self.info_vec, old_int_inputs),
            Tensor(self.precision, old_int_inputs)
        ]
        for const, coeffs in affine.values():
            tensors.append(const)
            tensors.extend(coeff for coeff, _ in coeffs.values())
        new_int_inputs, tensors = align_tensors(*tensors, expand=True)
        tensors = (Tensor(x, new_int_inputs) for x in tensors)
        old_info_vec = next(tensors).data
        old_precision = next(tensors).data
        for old_k, (const, coeffs) in affine.items():
            const = next(tensors)
            for new_k, (coeff, eqn) in coeffs.items():
                coeff = next(tensors)
                coeffs[new_k] = coeff, eqn
            affine[old_k] = const, coeffs
        batch_shape = old_info_vec.data.shape[:-1]

        # Align real dimensions.
        old_real_inputs = OrderedDict(
            (k, v) for k, v in self.inputs.items() if v.dtype == 'real')
        new_real_inputs = old_real_inputs.copy()
        for old_k, (const, coeffs) in affine.items():
            del new_real_inputs[old_k]
            for new_k, (coeff, eqn) in coeffs.items():
                new_shape = coeff.shape[:len(eqn.split('->')[0].split(',')[1])]
                new_real_inputs[new_k] = reals(*new_shape)
        old_offsets, old_dim = _compute_offsets(old_real_inputs)
        new_offsets, new_dim = _compute_offsets(new_real_inputs)
        new_inputs = new_int_inputs.copy()
        new_inputs.update(new_real_inputs)

        # Construct a blockwise affine representation of the substitution.
        subs_vector = BlockVector(batch_shape + (old_dim, ))
        subs_matrix = BlockMatrix(batch_shape + (new_dim, old_dim))
        for old_k, old_offset in old_offsets.items():
            old_size = old_real_inputs[old_k].num_elements
            old_slice = slice(old_offset, old_offset + old_size)
            if old_k in new_real_inputs:
                new_offset = new_offsets[old_k]
                new_slice = slice(new_offset, new_offset + old_size)
                subs_matrix[..., new_slice, old_slice] = \
                    torch.eye(old_size).expand(batch_shape + (-1, -1))
                continue
            const, coeffs = affine[old_k]
            old_shape = old_real_inputs[old_k].shape
            assert const.data.shape == batch_shape + old_shape
            subs_vector[..., old_slice] = const.data.reshape(batch_shape +
                                                             (old_size, ))
            for new_k, new_offset in new_offsets.items():
                if new_k in coeffs:
                    coeff, eqn = coeffs[new_k]
                    new_size = new_real_inputs[new_k].num_elements
                    new_slice = slice(new_offset, new_offset + new_size)
                    assert coeff.shape == new_real_inputs[
                        new_k].shape + old_shape
                    subs_matrix[..., new_slice, old_slice] = \
                        coeff.data.reshape(batch_shape + (new_size, old_size))
        subs_vector = subs_vector.as_tensor()
        subs_matrix = subs_matrix.as_tensor()
        subs_matrix_t = subs_matrix.transpose(-1, -2)

        # Construct the new funsor. Suppose the old Gaussian funsor g has density
        #   g(x) = < x | i - 1/2 P x>
        # Now define a new funsor f by substituting x = A y + B:
        #   f(y) = g(A y + B)
        #        = < A y + B | i - 1/2 P (A y + B) >
        #        = < y | At (i - P B) - 1/2 At P A y > + < B | i - 1/2 P B >
        #        =: < y | i' - 1/2 P' y > + C
        # where  P' = At P A  and  i' = At (i - P B)  parametrize a new Gaussian
        # and  C = < B | i - 1/2 P B >  parametrizes a new Tensor.
        precision = subs_matrix @ old_precision @ subs_matrix_t
        info_vec = _mv(subs_matrix,
                       old_info_vec - _mv(old_precision, subs_vector))
        const = _vv(subs_vector,
                    old_info_vec - 0.5 * _mv(old_precision, subs_vector))
        result = Gaussian(info_vec, precision, new_inputs) + Tensor(
            const, new_int_inputs)
        return Subs(result, remaining_subs) if remaining_subs else result
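
A numeric spot-check of the derivation in the comment above, written in plain torch with _mv and _vv expanded to @ products: substituting x = A y + B into g(x) = <x|i> - 0.5 <x|P x> reproduces <y|i'> - 0.5 <y|P' y> + C with P' = At P A, i' = At (i - P B), and C = <B|i - 0.5 P B>.

import torch

old_dim, new_dim = 3, 2
P = torch.randn(old_dim, old_dim)
P = P @ P.t() + torch.eye(old_dim)   # symmetric positive definite
i = torch.randn(old_dim)
A = torch.randn(old_dim, new_dim)    # x = A @ y + B
B = torch.randn(old_dim)
y = torch.randn(new_dim)

g = lambda x: x @ i - 0.5 * x @ P @ x
P_new = A.t() @ P @ A
i_new = A.t() @ (i - P @ B)
C = B @ (i - 0.5 * P @ B)
print(torch.allclose(g(A @ y + B),
                     y @ i_new - 0.5 * y @ P_new @ y + C, atol=1e-5))  # True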
Example #14
    def _eager_subs_real(self, subs, remaining_subs):
        # Broadcast all component tensors.
        subs = OrderedDict(subs)
        int_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        tensors = [
            Tensor(self.info_vec, int_inputs),
            Tensor(self.precision, int_inputs)
        ]
        tensors.extend(subs.values())
        int_inputs, tensors = align_tensors(*tensors)
        batch_dim = tensors[0].dim() - 1
        batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
        (info_vec, precision), values = tensors[:2], tensors[2:]
        offsets, event_size = _compute_offsets(self.inputs)
        slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
                  for k, offset in offsets.items()]

        # Expand all substituted values.
        values = OrderedDict(zip(subs, values))
        for k, value in values.items():
            value = value.reshape(value.shape[:batch_dim] + (-1, ))
            if not torch._C._get_tracing_state():
                assert value.size(-1) == self.inputs[k].num_elements
            values[k] = value.expand(batch_shape + value.shape[-1:])

        # Try to perform a complete substitution of all real variables, resulting in a Tensor.
        if all(k in subs for k, d in self.inputs.items() if d.dtype == 'real'):
            # Form the concatenated value.
            value = BlockVector(batch_shape + (event_size, ))
            for k, i in slices:
                if k in values:
                    value[..., i] = values[k]
            value = value.as_tensor()

            # Evaluate the non-normalized log density.
            result = _vv(value, info_vec - 0.5 * _mv(precision, value))

            result = Tensor(result, int_inputs)
            assert result.output == reals()
            return Subs(result, remaining_subs) if remaining_subs else result

        # Perform a partial substitution of a subset of real variables, resulting in a Joint.
        # We split real inputs into two sets: a for the preserved and b for the substituted.
        b = frozenset(k for k, v in subs.items())
        a = frozenset(k for k, d in self.inputs.items()
                      if d.dtype == 'real' and k not in b)
        prec_aa = torch.cat([
            torch.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in a],
                      dim=-1) for k1, i1 in slices if k1 in a
        ],
                            dim=-2)
        prec_ab = torch.cat([
            torch.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b],
                      dim=-1) for k1, i1 in slices if k1 in a
        ],
                            dim=-2)
        prec_bb = torch.cat([
            torch.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b],
                      dim=-1) for k1, i1 in slices if k1 in b
        ],
                            dim=-2)
        info_a = torch.cat([info_vec[..., i] for k, i in slices if k in a],
                           dim=-1)
        info_b = torch.cat([info_vec[..., i] for k, i in slices if k in b],
                           dim=-1)
        value_b = torch.cat([values[k] for k, i in slices if k in b], dim=-1)
        info_vec = info_a - _mv(prec_ab, value_b)
        log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
        precision = prec_aa.expand(info_vec.shape + (-1, ))
        inputs = int_inputs.copy()
        for k, d in self.inputs.items():
            if k not in subs:
                inputs[k] = d
        result = Gaussian(info_vec, precision, inputs) + Tensor(
            log_scale, int_inputs)
        return Subs(result, remaining_subs) if remaining_subs else result
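
A numeric spot-check of the partial substitution above, in plain torch: fixing the b-block of an information-form Gaussian to a value v leaves an a-block Gaussian with info vector i_a - P_ab v plus the scalar log_scale = <v|i_b> - 0.5 <v|P_bb v>, exactly the block formulas used in the code.

import torch

na, nb = 2, 3
n = na + nb
P = torch.randn(n, n)
P = P @ P.t() + torch.eye(n)   # symmetric positive definite
i = torch.randn(n)
v = torch.randn(nb)            # substituted value for the b variables
xa = torch.randn(na)           # remaining a variables

x = torch.cat([xa, v])
full = x @ i - 0.5 * x @ P @ x

i_new = i[:na] - P[:na, na:] @ v
log_scale = v @ i[na:] - 0.5 * v @ P[na:, na:] @ v
cond = xa @ i_new - 0.5 * xa @ P[:na, :na] @ xa + log_scale
print(torch.allclose(full, cond, atol=1e-5))  # True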
Example #15
    def eager_subs(self, subs):
        assert isinstance(subs, tuple)
        subs = tuple((k, materialize(to_funsor(v, self.inputs[k])))
                     for k, v in subs if k in self.inputs)
        if not subs:
            return self

        # Constants and Variables are eagerly substituted;
        # everything else is lazily substituted.
        lazy_subs = tuple((k, v) for k, v in subs
                          if not isinstance(v, (Number, Tensor, Variable)))
        var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
        int_subs = tuple((k, v) for k, v in subs
                         if isinstance(v, (Number, Tensor))
                         if v.dtype != 'real')
        real_subs = tuple((k, v) for k, v in subs
                          if isinstance(v, (Number, Tensor))
                          if v.dtype == 'real')
        if not (var_subs or int_subs or real_subs):
            return reflect(Subs, self, lazy_subs)

        # First perform any variable substitutions.
        if var_subs:
            rename = {k: v.name for k, v in var_subs}
            inputs = OrderedDict(
                (rename.get(k, k), d) for k, d in self.inputs.items())
            if len(inputs) != len(self.inputs):
                raise ValueError("Variable substitution name conflict")
            var_result = Gaussian(self.loc, self.precision, inputs)
            return Subs(var_result, int_subs + real_subs + lazy_subs)

        # Next perform any integer substitution, i.e. slicing into a batch.
        if int_subs:
            int_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
            real_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype == 'real')
            tensors = [self.loc, self.precision]
            funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors]
            inputs = funsors[0].inputs.copy()
            inputs.update(real_inputs)
            int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
            return Subs(int_result, real_subs + lazy_subs)

        # Try to perform a complete substitution of all real variables, resulting in a Tensor.
        real_subs = OrderedDict(subs)
        assert real_subs and not int_subs
        if all(k in real_subs for k, d in self.inputs.items()
               if d.dtype == 'real'):
            # Broadcast all component tensors.
            int_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
            tensors = [
                Tensor(self.loc, int_inputs),
                Tensor(self.precision, int_inputs)
            ]
            tensors.extend(real_subs.values())
            inputs, tensors = align_tensors(*tensors)
            batch_dim = tensors[0].dim() - 1
            batch_shape = broadcast_shape(*(x.shape[:batch_dim]
                                            for x in tensors))
            (loc, precision), values = tensors[:2], tensors[2:]

            # Form the concatenated value.
            offsets, event_size = _compute_offsets(self.inputs)
            value = BlockVector(batch_shape + (event_size, ))
            for k, value_k in zip(real_subs, values):
                offset = offsets[k]
                value_k = value_k.reshape(value_k.shape[:batch_dim] + (-1, ))
                if not torch._C._get_tracing_state():
                    assert value_k.size(-1) == self.inputs[k].num_elements
                value_k = value_k.expand(batch_shape + value_k.shape[-1:])
                value[...,
                      offset:offset + self.inputs[k].num_elements] = value_k
            value = value.as_tensor()

            # Evaluate the non-normalized log density.
            result = -0.5 * _vmv(precision, value - loc)
            result = Tensor(result, inputs)
            assert result.output == reals()
            return Subs(result, lazy_subs)

        # Perform a partial substitution of a subset of real variables, resulting in a Joint.
        # See "The Matrix Cookbook" (November 15, 2012) ss. 8.1.3 eq. 353.
        # http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
        raise NotImplementedError(
            'TODO implement partial substitution of real variables')
Example #16
    def eager_subs(self, subs):
        assert isinstance(subs, tuple)
        subs = tuple(
            (k, v if isinstance(v, (Variable, Slice)) else materialize(v))
            for k, v in subs if k in self.inputs)
        if not subs:
            return self

        # Constants and Variables are eagerly substituted;
        # everything else is lazily substituted.
        lazy_subs = tuple(
            (k, v) for k, v in subs
            if not isinstance(v, (Number, Tensor, Variable, Slice)))
        var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
        int_subs = tuple((k, v) for k, v in subs
                         if isinstance(v, (Number, Tensor, Slice))
                         if v.dtype != 'real')
        real_subs = tuple((k, v) for k, v in subs
                          if isinstance(v, (Number, Tensor))
                          if v.dtype == 'real')
        if not (var_subs or int_subs or real_subs):
            return reflect(Subs, self, lazy_subs)

        # First perform any variable substitutions.
        if var_subs:
            rename = {k: v.name for k, v in var_subs}
            inputs = OrderedDict(
                (rename.get(k, k), d) for k, d in self.inputs.items())
            if len(inputs) != len(self.inputs):
                raise ValueError("Variable substitution name conflict")
            var_result = Gaussian(self.info_vec, self.precision, inputs)
            return Subs(var_result, int_subs + real_subs + lazy_subs)

        # Next perform any integer substitution, i.e. slicing into a batch.
        if int_subs:
            int_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
            real_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype == 'real')
            tensors = [self.info_vec, self.precision]
            funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors]
            inputs = funsors[0].inputs.copy()
            inputs.update(real_inputs)
            int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
            return Subs(int_result, real_subs + lazy_subs)

        # Broadcast all component tensors.
        real_subs = OrderedDict(subs)
        assert real_subs and not int_subs
        int_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        tensors = [
            Tensor(self.info_vec, int_inputs),
            Tensor(self.precision, int_inputs)
        ]
        tensors.extend(real_subs.values())
        int_inputs, tensors = align_tensors(*tensors)
        batch_dim = tensors[0].dim() - 1
        batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
        (info_vec, precision), values = tensors[:2], tensors[2:]
        offsets, event_size = _compute_offsets(self.inputs)
        slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
                  for k, offset in offsets.items()]

        # Expand all substituted values.
        values = OrderedDict(zip(real_subs, values))
        for k, value in values.items():
            value = value.reshape(value.shape[:batch_dim] + (-1, ))
            if not torch._C._get_tracing_state():
                assert value.size(-1) == self.inputs[k].num_elements
            values[k] = value.expand(batch_shape + value.shape[-1:])

        # Try to perform a complete substitution of all real variables, resulting in a Tensor.
        if all(k in real_subs for k, d in self.inputs.items()
               if d.dtype == 'real'):
            # Form the concatenated value.
            value = BlockVector(batch_shape + (event_size, ))
            for k, i in slices:
                if k in values:
                    value[..., i] = values[k]
            value = value.as_tensor()

            # Evaluate the non-normalized log density.
            result = _vv(value, info_vec - 0.5 * _mv(precision, value))

            result = Tensor(result, int_inputs)
            assert result.output == reals()
            return Subs(result, lazy_subs)

        # Perform a partial substitution of a subset of real variables, resulting in a Joint.
        # We split real inputs into two sets: a for the preserved and b for the substituted.
        b = frozenset(k for k, v in real_subs.items())
        a = frozenset(k for k, d in self.inputs.items()
                      if d.dtype == 'real' and k not in b)
        prec_aa = torch.cat([
            torch.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in a],
                      dim=-1) for k1, i1 in slices if k1 in a
        ],
                            dim=-2)
        prec_ab = torch.cat([
            torch.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b],
                      dim=-1) for k1, i1 in slices if k1 in a
        ],
                            dim=-2)
        prec_bb = torch.cat([
            torch.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b],
                      dim=-1) for k1, i1 in slices if k1 in b
        ],
                            dim=-2)
        info_a = torch.cat([info_vec[..., i] for k, i in slices if k in a],
                           dim=-1)
        info_b = torch.cat([info_vec[..., i] for k, i in slices if k in b],
                           dim=-1)
        value_b = torch.cat([values[k] for k, i in slices if k in b], dim=-1)
        info_vec = info_a - _mv(prec_ab, value_b)
        log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
        precision = prec_aa.expand(info_vec.shape + (-1, ))
        inputs = int_inputs.copy()
        for k, d in self.inputs.items():
            if k not in real_subs:
                inputs[k] = d
        return Gaussian(info_vec, precision, inputs) + Tensor(
            log_scale, int_inputs)