Esempio n. 1
0
def eager_integrate(delta, integrand, reduced_vars):
    """Integrate ``integrand`` against the log-measure ``delta``.

    Every entry of ``delta.terms`` whose name is in ``reduced_vars`` is
    substituted (as its point) into both the integrand and the measure.
    The variables of ``reduced_vars`` that are not in ``delta.fresh`` are
    handed on to a new ``Integrate``.

    :return: ``None`` when ``reduced_vars`` shares nothing with
        ``delta.fresh``; otherwise the new ``Integrate`` term.
    """
    if not (reduced_vars & delta.fresh):
        return None

    # Collect (name, point) pairs for the terms that are being integrated out.
    point_subs = []
    for name, (point, _log_density) in delta.terms:
        if name in reduced_vars:
            point_subs.append((name, point))
    point_subs = tuple(point_subs)

    substituted_integrand = Subs(integrand, point_subs)
    substituted_measure = Subs(delta, point_subs)
    remaining_vars = reduced_vars - delta.fresh
    return Integrate(substituted_measure, substituted_integrand, remaining_vars)
Esempio n. 2
0
def adjoint_subs_gaussianmixture_discrete(adj_redop, adj_binop, out_adj, arg,
                                          subs):
    """Adjoint of ``Subs`` applied to ``arg`` (whose ``terms`` are indexed as
    ``terms[0]`` and ``terms[1]``, the latter carrying ``info_vec`` and
    ``precision``).

    :param adj_redop: reduction op forwarded to ``adjoint_ops``.
    :param adj_binop: binary op used to combine adjoints; also selects the
        fill value ``ops.UNITS[adj_binop]``.
    :param out_adj: adjoint of the substitution's output.
    :param arg: the term that was substituted into.
    :param subs: iterable of ``(name, value)`` substitution pairs.
    :return: ``{arg: in_adj}``.
    :raises NotImplementedError: if any substituted value has dtype
        ``'real'`` and is not a ``Variable``.
    """
    for _, value in subs:
        if value.dtype == 'real' and not isinstance(value, Variable):
            raise NotImplementedError(
                "TODO implement adjoint for substitution into Gaussian real variable"
            )

    # Variables are renamings, which are inverted on out_adj; every other
    # value is treated as a slice (advanced indexing) to be inverted below.
    inverse_renames = tuple(
        (value.name, key) for key, value in subs if isinstance(value, Variable))
    out_adj = Subs(out_adj, inverse_renames)
    slices = tuple(
        (key, value) for key, value in subs if not isinstance(value, Variable))

    discrete_term = arg.terms[0]
    gaussian_term = arg.terms[1]

    arg_int_inputs = OrderedDict()
    for name, domain in arg.inputs.items():
        if domain.dtype != 'real':
            arg_int_inputs[name] = domain

    # Tensor filled with the unit of adj_binop, shaped like the batch part
    # of info_vec, then sliced the same way the argument was.
    batch_shape = gaussian_term.info_vec.shape[:-1]
    unit_data = gaussian_term.info_vec.new_full(batch_shape,
                                                ops.UNITS[adj_binop])
    zeros_like_out = Subs(Tensor(unit_data, arg_int_inputs), slices)
    out_adj = adj_binop(out_adj, zeros_like_out)

    in_adj_discrete = adjoint_ops(Subs, adj_redop, adj_binop, out_adj,
                                  discrete_term, subs)[discrete_term]

    def _slice_adjoint(data):
        # Invert the slicing for one Gaussian parameter tensor and return
        # the first adjoint value produced.
        adjoints = adjoint_ops(Subs, adj_redop, adj_binop, zeros_like_out,
                               Tensor(data, arg_int_inputs), slices)
        return list(adjoints.values())[0]

    # invert the slicing for the Gaussian term even though the message does
    # not affect the values
    in_adj_info_vec = _slice_adjoint(gaussian_term.info_vec)
    in_adj_precision = _slice_adjoint(gaussian_term.precision)

    assert isinstance(in_adj_info_vec, Tensor)
    assert isinstance(in_adj_precision, Tensor)

    in_adj_gaussian = Gaussian(in_adj_info_vec.data, in_adj_precision.data,
                               arg.inputs.copy())

    return {arg: in_adj_gaussian + in_adj_discrete}
Esempio n. 3
0
def eager_add(op, joint, delta):
    """Combine ``joint`` with ``delta`` into a new ``Joint``.

    If ``delta.name`` is an input of ``joint``, the delta's point is first
    substituted into ``joint``; should that substitution yield something
    other than a ``Joint``, the sum ``joint + delta`` is returned instead.
    The points of ``joint``'s own deltas are then substituted into ``delta``
    wherever ``delta`` depends on them.
    """
    # Update with a degenerate distribution, typically a monte carlo sample.
    if delta.name in joint.inputs:
        own_point = ((delta.name, delta.point), )
        joint = Subs(joint, own_point)
        if not isinstance(joint, Joint):
            return joint + delta

    # delta is rebound inside the loop, so each membership test sees the
    # inputs of the most recently substituted delta.
    for existing in joint.deltas:
        if existing.name in delta.inputs:
            delta = Subs(delta, ((existing.name, existing.point), ))

    return Joint(joint.deltas + (delta, ), joint.discrete, joint.gaussian)
Esempio n. 4
0
 def _eager_subs_int(self, subs, remaining_subs):
     """Apply ``subs`` to the non-real inputs, i.e. slice into a batch.

     ``info_vec`` and ``precision`` are each wrapped in a ``Tensor`` over the
     non-real inputs and passed through ``Subs``; the real inputs are then
     re-attached unchanged. ``remaining_subs``, if non-empty, is applied to
     the result with a further ``Subs``.
     """
     # Split the inputs by dtype, keeping their original order in each group.
     int_inputs = OrderedDict()
     real_inputs = OrderedDict()
     for name, domain in self.inputs.items():
         if domain.dtype != 'real':
             int_inputs[name] = domain
         if domain.dtype == 'real':
             real_inputs[name] = domain

     sliced_info_vec, sliced_precision = [
         Subs(Tensor(data, int_inputs), subs)
         for data in (self.info_vec, self.precision)
     ]

     inputs = sliced_info_vec.inputs.copy()
     inputs.update(real_inputs)
     int_result = Gaussian(sliced_info_vec.data, sliced_precision.data, inputs)
     if remaining_subs:
         return Subs(int_result, remaining_subs)
     return int_result
Esempio n. 5
0
def eager_add(op, delta, other):
    """Wrap ``delta`` and ``other`` into a ``Joint``.

    When ``other`` depends on ``delta.name``, the delta's point is
    substituted into it first. A ``Number`` or ``Tensor`` goes into the
    ``discrete`` slot of the ``Joint``; anything else goes into the
    ``gaussian`` slot.
    """
    if delta.name in other.inputs:
        other = Subs(other, ((delta.name, delta.point), ))
        assert isinstance(other, (Number, Tensor, Gaussian))

    slot = 'discrete' if isinstance(other, (Number, Tensor)) else 'gaussian'
    return Joint((delta, ), **{slot: other})
Esempio n. 6
0
def distribute_subs_contraction(arg, subs):
    """Push a substitution down into the terms of a contraction.

    Each term of ``arg`` receives only those ``(name, value)`` pairs of
    ``subs`` whose name is one of its inputs; a term that matches none of
    them is passed through untouched. The result is a new ``Contraction``
    with the same ``red_op``, ``bin_op`` and ``reduced_vars``.
    """
    new_terms = []
    for term in arg.terms:
        relevant = tuple(
            (name, sub) for name, sub in subs if name in term.inputs)
        if relevant:
            new_terms.append(Subs(term, relevant))
        else:
            new_terms.append(term)
    return Contraction(arg.red_op, arg.bin_op, arg.reduced_vars, *new_terms)
Esempio n. 7
0
 def _eager_subs_var(self, subs, remaining_subs):
     """Rename inputs according to ``subs`` (pairs of old name, variable).

     :raises ValueError: if the renaming makes two inputs share a name.
     :return: a ``Gaussian`` over the renamed inputs, wrapped in
         ``Subs(..., remaining_subs)`` when ``remaining_subs`` is non-empty.
     """
     # Perform variable substitution, i.e. renaming of inputs.
     new_names = {}
     for old_name, variable in subs:
         new_names[old_name] = variable.name

     inputs = OrderedDict()
     for name, domain in self.inputs.items():
         inputs[new_names.get(name, name)] = domain

     # Two inputs mapped to the same name collapse into a single key.
     if len(inputs) != len(self.inputs):
         raise ValueError("Variable substitution name conflict")

     var_result = Gaussian(self.info_vec, self.precision, inputs)
     if remaining_subs:
         return Subs(var_result, remaining_subs)
     return var_result
def test_affine_subs():
    """Substituting ``bias_b5 + constant`` for ``obs_b2`` in a recorded
    Gaussian must yield a ``Gaussian`` or a ``Contraction``."""
    # This was recorded from test_pyro_convert.
    info_vec = torch.tensor(
        [1.3027106523513794, 1.4167094230651855, -0.9750942587852478, 0.5321089029312134, -0.9039931297302246],  # noqa
        dtype=torch.float32)
    precision = torch.tensor(
        [
            [1.0199567079544067, 0.9840421676635742, -0.473368763923645, 0.34206756949424744, -0.7562517523765564],  # noqa
            [0.9840421676635742, 1.511502742767334, -1.7593903541564941, 0.6647964119911194, -0.5119513273239136],  # noqa
            [-0.4733688533306122, -1.7593903541564941, 3.2386727333068848, -0.9345928430557251, -0.1534711718559265],  # noqa
            [0.34206756949424744, 0.6647964119911194, -0.9345928430557251, 0.3141004145145416, -0.12399007380008698],  # noqa
            [-0.7562517523765564, -0.5119513273239136, -0.1534711718559265, -0.12399007380008698, 0.6450173854827881],  # noqa
        ],
        dtype=torch.float32)
    gaussian_inputs = (
        ('state_1_b6', reals(3)),
        ('obs_b2', reals(2)),
    )
    gaussian = Gaussian(info_vec, precision, gaussian_inputs)

    bias = Variable('bias_b5', reals(2))
    offset = Tensor(
        torch.tensor([-2.1787893772125244, 0.5684312582015991],
                     dtype=torch.float32),
        (),
        'real')
    shifted_bias = Contraction(ops.nullop, ops.add, frozenset(),
                               (bias, offset))

    x = Subs(gaussian, (('obs_b2', shifted_bias), ))
    assert isinstance(x, (Gaussian, Contraction)), x.pretty()
Esempio n. 9
0
def eager_add(op, joint, other):
    """Fold ``other`` into the gaussian slot of ``joint``.

    Points of ``joint``'s deltas are substituted into ``other`` wherever it
    depends on them, and ``joint.gaussian`` is added in unless it is
    ``Number(0)``. If the combined term is a ``Gaussian`` it becomes the
    gaussian slot of a new ``Joint``; otherwise it is added to a ``Joint``
    built from the deltas and discrete part alone.
    """
    # Update with a delayed gaussian random variable.
    pinned = []
    for d in joint.deltas:
        if d.name in other.inputs:
            pinned.append((d.name, d.point))
    if pinned:
        other = Subs(other, tuple(pinned))

    # NOTE(review): identity comparison against a freshly built Number(0);
    # presumably relies on Number instances being interned -- confirm.
    if joint.gaussian is not Number(0):
        other = joint.gaussian + other

    if isinstance(other, Gaussian):
        return Joint(joint.deltas, joint.discrete, other)
    return Joint(joint.deltas, joint.discrete) + other
Esempio n. 10
0
def _simplify_integrate(fn, joint, integrand, reduced_vars):
    """Integrate out the deltas of ``joint`` by substitution, else defer.

    If none of ``joint``'s deltas is named in ``reduced_vars``, the call is
    forwarded unchanged to ``fn``. Otherwise the matching deltas are removed
    from the measure, their points are substituted into ``integrand``, and
    an ``Integrate`` over the remaining variables is returned.
    """
    pinned = tuple(
        (d.name, d.point) for d in joint.deltas if d.name in reduced_vars)
    if not pinned:
        return fn(joint, integrand, reduced_vars)

    kept_deltas = tuple(
        d for d in joint.deltas if d.name not in reduced_vars)
    log_measure = Joint(kept_deltas, joint.discrete, joint.gaussian)
    integrand = Subs(integrand, pinned)
    remaining_vars = reduced_vars - frozenset(name for name, _ in pinned)
    return Integrate(log_measure, integrand, remaining_vars)
Esempio n. 11
0
def eager_markov_product(sum_op, prod_op, trans, time, step, step_names):
    """Evaluate a Markov product and apply ``step_names`` via ``Subs``.

    The product is computed by the first applicable rule:

    * ``step`` is truthy: ``sequential_sum_product`` with ``dict(step)``;
    * ``time.name`` is an input of ``trans``: reduce ``trans`` over it
      with ``prod_op``;
    * ``prod_op`` is ``ops.add``: ``trans * time.size``;
    * ``prod_op`` is ``ops.mul``: ``trans ** time.size``.

    :raises NotImplementedError: if no rule applies.
    """
    def _product():
        if step:
            return sequential_sum_product(sum_op, prod_op, trans, time,
                                          dict(step))
        if time.name in trans.inputs:
            return trans.reduce(prod_op, time.name)
        if prod_op is ops.add:
            return trans * time.size
        if prod_op is ops.mul:
            return trans ** time.size
        raise NotImplementedError('https://github.com/pyro-ppl/funsor/issues/233')

    result = _product()
    return Subs(result, step_names)
Esempio n. 12
0
 def eager_subs(self, subs):
     """Eagerly apply the renaming part of ``subs``.

     Pairs whose value is a ``Variable`` are treated as renamings and folded
     into ``step_names`` of a new ``MarkovProduct``; the remaining pairs are
     applied to that product through ``Subs``.

     :return: ``None`` when ``subs`` contains no ``Variable`` values.
     """
     assert isinstance(subs, tuple)

     # Split subs into renamings (old name -> new name) and everything else.
     rename = {}
     lazy = []
     for key, value in subs:
         if isinstance(value, Variable):
             rename[key] = value.name
         else:
             lazy.append((key, value))
     if not rename:
         return None

     step_names = frozenset(
         (source, rename.get(target, target))
         for source, target in self.step_names.items())
     result = MarkovProduct(self.sum_op, self.prod_op, self.trans,
                            self.time, self.step, step_names)
     if lazy:
         result = Subs(result, tuple(lazy))
     return result
Esempio n. 13
0
def adjoint_subs_tensor(adj_redop, adj_binop, out_adj, arg, subs):
    """Adjoint of ``Subs`` applied to a tensor ``arg``.

    Renamings (``Variable`` values in ``subs``) are inverted on ``out_adj``;
    the other substitutions are treated as slices and inverted with
    ``_scatter`` into a tensor filled with ``ops.UNITS[adj_binop]``.

    :return: ``{arg: arg_adj}``.
    """
    assert all(isinstance(value, Funsor) for _, value in subs)

    def _unit_tensor():
        # NOTE(review): a fresh buffer is allocated on every call, as in the
        # two separate constructions this replaces; _scatter is defined
        # elsewhere and may write into its argument -- confirm before sharing.
        return Tensor(torch.full_like(arg.data, ops.UNITS[adj_binop]),
                      arg.inputs.copy(), arg.output.dtype)

    # invert renaming
    inverse_renames = tuple(
        (value.name, key) for key, value in subs if isinstance(value, Variable))
    out_adj = Subs(out_adj, inverse_renames)

    # inverting advanced indexing
    slices = tuple(
        (key, value) for key, value in subs if not isinstance(value, Variable))

    # TODO avoid reifying these zero/one tensors by using symbolic constants
    # ones for things that weren't sliced away
    ones_like_out = Subs(_unit_tensor(), slices)
    arg_adj = adj_binop(out_adj, ones_like_out)

    # ones for things that were sliced away
    ones_like_arg = _unit_tensor()
    return {arg: _scatter(arg_adj, ones_like_arg, slices)}
Esempio n. 14
0
def adjoint_subs_gaussianmixture_gaussianmixture(adj_redop, adj_binop, out_adj, arg, subs):
    """Adjoint of ``Subs`` where both ``arg`` and ``out_adj`` expose
    ``terms[0]`` and ``terms[1]``, the latter carrying Gaussian parameters.

    :param adj_redop: reduction op forwarded to ``adjoint_ops``.
    :param adj_binop: binary op forwarded to ``adjoint_ops``.
    :param out_adj: adjoint of the substitution's output.
    :param arg: the term that was substituted into.
    :param subs: sequence of ``(name, value)`` substitution pairs.
    :return: ``{arg: in_adj}``.
    :raises NotImplementedError: if any substituted value has dtype
        ``'real'`` and is not a ``Variable``.
    """
    for _, value in subs:
        if value.dtype == 'real' and not isinstance(value, Variable):
            raise NotImplementedError("TODO implement adjoint for substitution into Gaussian real variable")

    def _non_real(inputs):
        # Keep only the inputs whose dtype is not 'real', preserving order.
        return OrderedDict(
            (name, domain) for name, domain in inputs.items()
            if domain.dtype != 'real')

    def _slice_adjoint(out_data, arg_data):
        # Invert the slicing for one parameter tensor and return the first
        # adjoint value produced.
        adjoints = adjoint_ops(Subs, adj_redop, adj_binop,  # ops.add, ops.mul,
                               Tensor(out_data, out_adj_int_inputs),
                               Tensor(arg_data, arg_int_inputs),
                               slices)
        return list(adjoints.values())[0]

    # invert renaming
    renames = tuple(
        (value.name, key) for key, value in subs if isinstance(value, Variable))
    out_adj = Subs(out_adj, renames)

    # inverting advanced indexing
    slices = tuple(
        (key, value) for key, value in subs if not isinstance(value, Variable))

    assert len(slices + renames) == len(subs)

    arg_discrete = arg.terms[0]
    in_adj_discrete = adjoint_ops(Subs, adj_redop, adj_binop,
                                  out_adj.terms[0], arg_discrete, subs)[arg_discrete]

    arg_int_inputs = _non_real(arg.inputs)
    out_adj_int_inputs = _non_real(out_adj.inputs)

    # Align the output adjoint's Gaussian to its own non-real inputs followed
    # by the real inputs of arg.
    align_inputs = _non_real(out_adj.terms[1].inputs)
    align_inputs.update(
        (name, domain) for name, domain in arg.inputs.items()
        if domain.dtype == 'real')
    out_adj_info_vec, out_adj_precision = align_gaussian(align_inputs, out_adj.terms[1])

    arg_gaussian = arg.terms[1]
    in_adj_info_vec = _slice_adjoint(out_adj_info_vec, arg_gaussian.info_vec)
    in_adj_precision = _slice_adjoint(out_adj_precision, arg_gaussian.precision)

    assert isinstance(in_adj_info_vec, Tensor)
    assert isinstance(in_adj_precision, Tensor)

    in_adj_gaussian = Gaussian(in_adj_info_vec.data, in_adj_precision.data, arg.inputs.copy())

    return {arg: in_adj_gaussian + in_adj_discrete}
Esempio n. 15
0
    def _eager_subs_affine(self, subs, remaining_subs):
        """Substitute affine expressions for real inputs of this Gaussian.

        Each ``(name, value)`` pair of ``subs`` is passed to
        ``extract_affine``. Pairs whose constant and coefficients are all
        ``Tensor`` are handled here; all other pairs are appended to
        ``remaining_subs``.

        :param subs: iterable of ``(name, value)`` pairs.
        :param tuple remaining_subs: pairs to apply afterwards via ``Subs``.
        :return: ``reflect(Subs, self, remaining_subs)`` when no pair is
            affine; otherwise a ``Gaussian`` plus a ``Tensor`` constant,
            wrapped in ``Subs(..., remaining_subs)`` when ``remaining_subs``
            is non-empty.
        """
        # Extract an affine representation.
        affine = OrderedDict()
        for k, v in subs:
            const, coeffs = extract_affine(v)
            if (isinstance(const, Tensor) and all(
                    isinstance(coeff, Tensor)
                    for coeff, _ in coeffs.values())):
                affine[k] = const, coeffs
            else:
                # Not representable with Tensor constants/coefficients:
                # defer this pair to the trailing Subs.
                remaining_subs += (k, v),
        if not affine:
            return reflect(Subs, self, remaining_subs)

        # Align integer dimensions.
        old_int_inputs = OrderedDict(
            (k, v) for k, v in self.inputs.items() if v.dtype != 'real')
        tensors = [
            Tensor(self.info_vec, old_int_inputs),
            Tensor(self.precision, old_int_inputs)
        ]
        for const, coeffs in affine.values():
            tensors.append(const)
            tensors.extend(coeff for coeff, _ in coeffs.values())
        new_int_inputs, tensors = align_tensors(*tensors, expand=True)
        # Lazily re-wrap the aligned data. The next() calls below consume
        # this generator in exactly the order the list was built above:
        # info_vec, precision, then (const, *coeffs) for each affine entry.
        tensors = (Tensor(x, new_int_inputs) for x in tensors)
        old_info_vec = next(tensors).data
        old_precision = next(tensors).data
        for old_k, (const, coeffs) in affine.items():
            const = next(tensors)
            for new_k, (coeff, eqn) in coeffs.items():
                coeff = next(tensors)
                # Only existing keys are reassigned, so the dict size does
                # not change while it is being iterated.
                coeffs[new_k] = coeff, eqn
            affine[old_k] = const, coeffs
        batch_shape = old_info_vec.shape[:-1]

        # Align real dimensions.
        old_real_inputs = OrderedDict(
            (k, v) for k, v in self.inputs.items() if v.dtype == 'real')
        new_real_inputs = old_real_inputs.copy()
        for old_k, (const, coeffs) in affine.items():
            # The substituted input is dropped and replaced by the inputs
            # of its affine expression.
            del new_real_inputs[old_k]
            for new_k, (coeff, eqn) in coeffs.items():
                # Take as many leading dims of coeff as there are characters
                # in the second comma-separated operand left of '->' in eqn.
                new_shape = coeff.shape[:len(eqn.split('->')[0].split(',')[1])]
                new_real_inputs[new_k] = Reals[new_shape]
        old_offsets, old_dim = _compute_offsets(old_real_inputs)
        new_offsets, new_dim = _compute_offsets(new_real_inputs)
        new_inputs = new_int_inputs.copy()
        new_inputs.update(new_real_inputs)

        # Construct a blockwise affine representation of the substitution.
        subs_vector = BlockVector(batch_shape + (old_dim, ))
        subs_matrix = BlockMatrix(batch_shape + (new_dim, old_dim))
        for old_k, old_offset in old_offsets.items():
            old_size = old_real_inputs[old_k].num_elements
            old_slice = slice(old_offset, old_offset + old_size)
            if old_k in new_real_inputs:
                # old_k is still present among the new real inputs: copy it
                # through unchanged (presumably an identity block from
                # ops.new_eye -- confirm).
                new_offset = new_offsets[old_k]
                new_slice = slice(new_offset, new_offset + old_size)
                subs_matrix[..., new_slice, old_slice] = \
                    ops.new_eye(self.info_vec, batch_shape + (old_size,))
                continue
            const, coeffs = affine[old_k]
            old_shape = old_real_inputs[old_k].shape
            assert const.data.shape == batch_shape + old_shape
            subs_vector[..., old_slice] = const.data.reshape(batch_shape +
                                                             (old_size, ))
            for new_k, new_offset in new_offsets.items():
                if new_k in coeffs:
                    coeff, eqn = coeffs[new_k]
                    new_size = new_real_inputs[new_k].num_elements
                    new_slice = slice(new_offset, new_offset + new_size)
                    assert coeff.shape == new_real_inputs[
                        new_k].shape + old_shape
                    subs_matrix[..., new_slice, old_slice] = \
                        coeff.data.reshape(batch_shape + (new_size, old_size))
        subs_vector = subs_vector.as_tensor()
        subs_matrix = subs_matrix.as_tensor()
        subs_matrix_t = ops.transpose(subs_matrix, -1, -2)

        # Construct the new funsor. Suppose the old Gaussian funsor g has density
        #   g(x) = < x | i - 1/2 P x>
        # Now define a new funsor f by substituting x = A y + B:
        #   f(y) = g(A y + B)
        #        = < A y + B | i - 1/2 P (A y + B) >
        #        = < y | At (i - P B) - 1/2 At P A y > + < B | i - 1/2 P B >
        #        =: < y | i' - 1/2 P' y > + C
        # where  P' = At P A  and  i' = At (i - P B)  parametrize a new Gaussian
        # and  C = < B | i - 1/2 P B >  parametrize a new Tensor.
        # subs_matrix has shape (..., new_dim, old_dim), so it plays the role
        # of At in the formulas above, and subs_vector plays the role of B.
        precision = subs_matrix @ old_precision @ subs_matrix_t
        info_vec = _mv(subs_matrix,
                       old_info_vec - _mv(old_precision, subs_vector))
        const = _vv(subs_vector,
                    old_info_vec - 0.5 * _mv(old_precision, subs_vector))
        result = Gaussian(info_vec, precision, new_inputs) + Tensor(
            const, new_int_inputs)
        return Subs(result, remaining_subs) if remaining_subs else result
Esempio n. 16
0
    def _eager_subs_real(self, subs, remaining_subs):
        """Substitute values for real inputs of this Gaussian.

        If every real input is named in ``subs``, the quadratic form is
        evaluated and returned as a ``Tensor``. Otherwise the result is a
        ``Gaussian`` over the real inputs that were not substituted, plus a
        ``Tensor`` holding the part that depends only on the substituted
        values.

        :param subs: ``(name, value)`` pairs; converted to an ``OrderedDict``.
        :param remaining_subs: applied to the result via ``Subs`` when
            non-empty.
        """
        # Broadcast all component tensors.
        subs = OrderedDict(subs)
        int_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        tensors = [
            Tensor(self.info_vec, int_inputs),
            Tensor(self.precision, int_inputs)
        ]
        # The substituted values are aligned together with the Gaussian's
        # own parameters.
        tensors.extend(subs.values())
        int_inputs, tensors = align_tensors(*tensors)
        # tensors[0] is the aligned info_vec; all but its last dim is batch.
        batch_dim = len(tensors[0].shape) - 1
        batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
        (info_vec, precision), values = tensors[:2], tensors[2:]
        offsets, event_size = _compute_offsets(self.inputs)
        # One (name, slice) pair per entry of offsets, locating that input
        # along the event dimension.
        slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
                  for k, offset in offsets.items()]

        # Expand all substituted values.
        values = OrderedDict(zip(subs, values))
        for k, value in values.items():
            # Flatten everything after the batch dims into a single dim.
            value = value.reshape(value.shape[:batch_dim] + (-1, ))
            # The size check is skipped when get_tracing_state() is truthy.
            if not get_tracing_state():
                assert value.shape[-1] == self.inputs[k].num_elements
            values[k] = ops.expand(value, batch_shape + value.shape[-1:])

        # Try to perform a complete substitution of all real variables, resulting in a Tensor.
        if all(k in subs for k, d in self.inputs.items() if d.dtype == 'real'):
            # Form the concatenated value.
            value = BlockVector(batch_shape + (event_size, ))
            for k, i in slices:
                if k in values:
                    value[..., i] = values[k]
            value = value.as_tensor()

            # Evaluate the non-normalized log density.
            result = _vv(value, info_vec - 0.5 * _mv(precision, value))

            result = Tensor(result, int_inputs)
            assert result.output == Real
            return Subs(result, remaining_subs) if remaining_subs else result

        # Perform a partial substitution of a subset of real variables,
        # resulting in a Gaussian plus a Tensor.
        # We split real inputs into two sets: a for the preserved and b for the substituted.
        b = frozenset(k for k, v in subs.items())
        a = frozenset(k for k, d in self.inputs.items()
                      if d.dtype == 'real' and k not in b)
        # Assemble the (a, a), (a, b) and (b, b) blocks of the precision by
        # concatenating per-input sub-blocks along rows (-2) and columns (-1).
        prec_aa = ops.cat(
            -2, *[
                ops.cat(
                    -1,
                    *[precision[..., i1, i2] for k2, i2 in slices if k2 in a])
                for k1, i1 in slices if k1 in a
            ])
        prec_ab = ops.cat(
            -2, *[
                ops.cat(
                    -1,
                    *[precision[..., i1, i2] for k2, i2 in slices if k2 in b])
                for k1, i1 in slices if k1 in a
            ])
        prec_bb = ops.cat(
            -2, *[
                ops.cat(
                    -1,
                    *[precision[..., i1, i2] for k2, i2 in slices if k2 in b])
                for k1, i1 in slices if k1 in b
            ])
        info_a = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in a])
        info_b = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in b])
        value_b = ops.cat(-1, *[values[k] for k, i in slices if k in b])
        # New information vector for the preserved inputs, and the term that
        # depends only on the substituted values.
        info_vec = info_a - _mv(prec_ab, value_b)
        log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
        # Expand prec_aa to the batch shape of the new info_vec.
        precision = ops.expand(prec_aa, info_vec.shape + info_vec.shape[-1:])
        inputs = int_inputs.copy()
        for k, d in self.inputs.items():
            if k not in subs:
                inputs[k] = d
        result = Gaussian(info_vec, precision, inputs) + Tensor(
            log_scale, int_inputs)
        return Subs(result, remaining_subs) if remaining_subs else result
Esempio n. 17
0
def test_bart(analytic_kl):
    global call_count
    call_count = 0

    with interpretation(reflect):
        q = Independent(
            Independent(
                Contraction(
                    ops.nullop,
                    ops.add,
                    frozenset(),
                    (
                        Tensor(
                            torch.tensor(
                                [[
                                    -0.6077086925506592, -1.1546266078948975,
                                    -0.7021151781082153, -0.5303535461425781,
                                    -0.6365622282028198, -1.2423288822174072,
                                    -0.9941254258155823, -0.6287292242050171
                                ],
                                 [
                                     -0.6987162828445435, -1.0875964164733887,
                                     -0.7337473630905151, -0.4713417589664459,
                                     -0.6674002408981323, -1.2478348016738892,
                                     -0.8939017057418823, -0.5238542556762695
                                 ]],
                                dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                            ),
                            'real'),
                        Gaussian(
                            torch.tensor([
                                [[-0.3536059558391571], [-0.21779225766658783],
                                 [0.2840439975261688], [0.4531521499156952],
                                 [-0.1220812276005745], [-0.05519985035061836],
                                 [0.10932210087776184], [0.6656699776649475]],
                                [[-0.39107921719551086], [
                                    -0.20241987705230713
                                ], [0.2170514464378357], [0.4500560462474823],
                                 [0.27945515513420105], [-0.0490039587020874],
                                 [-0.06399798393249512], [0.846565842628479]]
                            ],
                                         dtype=torch.float32),  # noqa
                            torch.tensor([
                                [[[1.984686255455017]], [[0.6699360013008118]],
                                 [[1.6215802431106567]], [[2.372016668319702]],
                                 [[1.77385413646698]], [[0.526767373085022]],
                                 [[0.8722561597824097]], [[2.1879124641418457]]
                                 ],
                                [[[1.6996612548828125]], [[
                                    0.7535632252693176
                                ]], [[1.4946647882461548]],
                                 [[2.642792224884033]], [[1.7301604747772217]],
                                 [[0.5203893780708313]], [[1.055436372756958]],
                                 [[2.8370864391326904]]]
                            ],
                                         dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                                (
                                    'value_b1',
                                    reals(),
                                ),
                            )),
                    )),
                'gate_rate_b3',
                '_event_1_b2',
                'value_b1'),
            'gate_rate_t',
            'time_b4',
            'gate_rate_b3')
        p_prior = Contraction(
            ops.logaddexp,
            ops.add,
            frozenset({'state(time=1)_b11', 'state_b10'}),
            (
                MarkovProduct(
                    ops.logaddexp,
                    ops.add,
                    Contraction(
                        ops.nullop,
                        ops.add,
                        frozenset(),
                        (
                            Tensor(
                                torch.tensor(2.7672932147979736,
                                             dtype=torch.float32), (), 'real'),
                            Gaussian(
                                torch.tensor([-0.0, -0.0, 0.0, 0.0],
                                             dtype=torch.float32),
                                torch.tensor([[
                                    98.01002502441406, 0.0, -99.0000228881836,
                                    -0.0
                                ],
                                              [
                                                  0.0, 98.01002502441406, -0.0,
                                                  -99.0000228881836
                                              ],
                                              [
                                                  -99.0000228881836, -0.0,
                                                  100.0000228881836, 0.0
                                              ],
                                              [
                                                  -0.0, -99.0000228881836, 0.0,
                                                  100.0000228881836
                                              ]],
                                             dtype=torch.float32),  # noqa
                                (
                                    (
                                        'state_b7',
                                        reals(2, ),
                                    ),
                                    (
                                        'state(time=1)_b8',
                                        reals(2, ),
                                    ),
                                )),
                            Subs(
                                AffineNormal(
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.03488487750291824,
                                                0.07356668263673782,
                                                0.19946961104869843,
                                                0.5386509299278259,
                                                -0.708323061466217,
                                                0.24411526322364807,
                                                -0.20855577290058136,
                                                -0.2421337217092514
                                            ],
                                             [
                                                 0.41762110590934753,
                                                 0.5272183418273926,
                                                 -0.49835553765296936,
                                                 -0.0363837406039238,
                                                 -0.0005282597267068923,
                                                 0.2704298794269562,
                                                 -0.155222088098526,
                                                 -0.44802337884902954
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        (),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                -0.003566693514585495,
                                                -0.2848514914512634,
                                                0.037103548645973206,
                                                0.12648648023605347,
                                                -0.18501518666744232,
                                                -0.20899859070777893,
                                                0.04121830314397812,
                                                0.0054807960987091064
                                            ],
                                             [
                                                 0.0021788496524095535,
                                                 -0.18700894713401794,
                                                 0.08187370002269745,
                                                 0.13554862141609192,
                                                 -0.10477752983570099,
                                                 -0.20848378539085388,
                                                 -0.01393645629286766,
                                                 0.011670656502246857
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.5974780917167664,
                                                0.864071786403656,
                                                1.0236268043518066,
                                                0.7147538065910339,
                                                0.7423890233039856,
                                                0.9462157487869263,
                                                1.2132389545440674,
                                                1.0596832036972046
                                            ],
                                             [
                                                 0.5787821412086487,
                                                 0.9178534150123596,
                                                 0.9074794054031372,
                                                 0.6600189208984375,
                                                 0.8473222255706787,
                                                 0.8426999449729919,
                                                 1.194266438484192,
                                                 1.0471148490905762
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Variable('state(time=1)_b8', reals(2, )),
                                    Variable('gate_rate_b6', reals(8, ))),
                                ((
                                    'gate_rate_b6',
                                    Binary(
                                        ops.GetitemOp(0),
                                        Variable('gate_rate_t', reals(2, 8)),
                                        Variable('time_b9', bint(2))),
                                ), )),
                        )),
                    Variable('time_b9', bint(2)),
                    frozenset({('state_b7', 'state(time=1)_b8')}),
                    frozenset({('state(time=1)_b8', 'state(time=1)_b11'),
                               ('state_b7', 'state_b10')})),  # noqa
                Subs(
                    dist.MultivariateNormal(
                        Tensor(torch.tensor([0.0, 0.0], dtype=torch.float32),
                               (), 'real'),
                        Tensor(
                            torch.tensor([[10.0, 0.0], [0.0, 10.0]],
                                         dtype=torch.float32),
                            (), 'real'), Variable('value_b5', reals(2, ))), ((
                                'value_b5',
                                Variable('state_b10', reals(2, )),
                            ), )),
            ))
        p_likelihood = Contraction(
            ops.add,
            ops.nullop,
            frozenset({'time_b17', 'destin_b16', 'origin_b15'}),
            (
                Contraction(
                    ops.logaddexp,
                    ops.add,
                    frozenset({'gated_b14'}),
                    (
                        dist.Categorical(
                            Binary(
                                ops.GetitemOp(0),
                                Binary(
                                    ops.GetitemOp(0),
                                    Subs(
                                        Function(
                                            unpack_gate_rate_0, reals(2, 2, 2),
                                            (Variable('gate_rate_b12',
                                                      reals(8, )), )),
                                        ((
                                            'gate_rate_b12',
                                            Binary(
                                                ops.GetitemOp(0),
                                                Variable(
                                                    'gate_rate_t', reals(2,
                                                                         8)),
                                                Variable('time_b17', bint(2))),
                                        ), )), Variable('origin_b15',
                                                        bint(2))),
                                Variable('destin_b16', bint(2))),
                            Variable('gated_b14', bint(2))),
                        Stack(
                            'gated_b14',
                            (
                                dist.Poisson(
                                    Binary(
                                        ops.GetitemOp(0),
                                        Binary(
                                            ops.GetitemOp(0),
                                            Subs(
                                                Function(
                                                    unpack_gate_rate_1,
                                                    reals(2, 2), (Variable(
                                                        'gate_rate_b13',
                                                        reals(8, )), )),
                                                ((
                                                    'gate_rate_b13',
                                                    Binary(
                                                        ops.GetitemOp(0),
                                                        Variable(
                                                            'gate_rate_t',
                                                            reals(2, 8)),
                                                        Variable(
                                                            'time_b17',
                                                            bint(2))),
                                                ), )),
                                            Variable('origin_b15', bint(2))),
                                        Variable('destin_b16', bint(2))),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                                dist.Delta(
                                    Tensor(
                                        torch.tensor(0.0, dtype=torch.float32),
                                        (), 'real'),
                                    Tensor(
                                        torch.tensor(0.0, dtype=torch.float32),
                                        (), 'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                            )),
                    )), ))

    if analytic_kl:
        exact_part = funsor.Integrate(q, p_prior - q, "gate_rate_t")
        with interpretation(monte_carlo):
            approx_part = funsor.Integrate(q, p_likelihood, "gate_rate_t")
        elbo = exact_part + approx_part
    else:
        p = p_prior + p_likelihood
        with interpretation(monte_carlo):
            elbo = Integrate(q, p - q, "gate_rate_t")

    assert isinstance(elbo, Tensor), elbo.pretty()
    assert call_count == 1
Esempio n. 18
0
def eager_integrate(delta, integrand, reduced_vars):
    """Integrate out ``delta.name`` by substituting ``delta.point``.

    ``delta.point`` is substituted for ``delta.name`` in ``integrand``, and
    the remaining integral is taken against ``delta.log_density`` over the
    other reduced variables.

    :param delta: object exposing ``name``, ``point`` and ``log_density``.
    :param integrand: term to be integrated.
    :param reduced_vars: set of variable names being integrated out; must
        contain ``delta.name``.
    :return: ``Integrate(delta.log_density, substituted_integrand, rest)``
        where ``rest`` is ``reduced_vars`` without ``delta.name``.
    """
    assert delta.name in reduced_vars
    integrand = Subs(integrand, ((delta.name, delta.point), ))
    log_measure = delta.log_density
    # Build a new set instead of using ``-=``: on a mutable ``set`` the
    # augmented form is an in-place update and would silently modify the
    # caller's ``reduced_vars``.
    remaining_vars = reduced_vars - frozenset([delta.name])
    return Integrate(log_measure, integrand, remaining_vars)
Esempio n. 19
0
def eager_add(op, joint, other):
    """Add ``other`` into ``joint``, first substituting any delta points.

    For every delta in ``joint.deltas`` whose name is an input of ``other``,
    the delta's point is substituted into ``other`` and the addition is
    retried. If no delta applies, ``other`` is folded into the discrete
    part of a new ``Joint``.

    :param op: the binary operation being dispatched on (not used here).
    :param joint: object exposing ``deltas``, ``discrete`` and ``gaussian``.
    :param other: term to add; must expose ``inputs``.
    :return: ``joint + Subs(other, subs)`` when any delta applies, otherwise
        ``Joint(joint.deltas, joint.discrete + other, joint.gaussian)``.
    """
    # Update with a delayed discrete random variable.
    # Membership is tested by name: this assumes ``other.inputs`` is keyed
    # by variable name (the same names used to build ``subs``), so testing
    # the delta object itself would never match.
    subs = tuple((d.name, d.point) for d in joint.deltas
                 if d.name in other.inputs)
    if subs:
        return joint + Subs(other, subs)
    return Joint(joint.deltas, joint.discrete + other, joint.gaussian)
Esempio n. 20
0
    def eager_subs(self, subs):
        """Eagerly substitute ``subs`` into this Gaussian where possible.

        Substitutions are processed in stages; each stage handles one kind
        of value and recurses through ``Subs`` with the remaining pairs:

        1. ``Variable`` values rename inputs.
        2. ``Number``/``Tensor`` values with non-``'real'`` dtype slice the
           batch tensors ``loc`` and ``precision``.
        3. ``Number``/``Tensor`` values with ``'real'`` dtype are plugged
           into the density; this is only implemented when every real input
           is substituted.

        Any other value is treated as lazy and applied afterwards via
        ``Subs``.

        :param tuple subs: ``(name, value)`` pairs. Pairs whose name is not
            in ``self.inputs`` are ignored.
        :return: ``self`` if no pair applies; a reflected ``Subs`` if only
            lazy pairs remain; otherwise the (possibly further substituted)
            result of the first applicable stage.
        :raises ValueError: if renaming maps two inputs to the same name.
        :raises NotImplementedError: if only a subset of the real inputs is
            substituted.
        """
        assert isinstance(subs, tuple)
        subs = tuple((k, materialize(to_funsor(v, self.inputs[k])))
                     for k, v in subs if k in self.inputs)
        if not subs:
            return self

        # Constants and Variables are eagerly substituted;
        # everything else is lazily substituted.
        lazy_subs = tuple((k, v) for k, v in subs
                          if not isinstance(v, (Number, Tensor, Variable)))
        var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
        int_subs = tuple((k, v) for k, v in subs
                         if isinstance(v, (Number, Tensor))
                         if v.dtype != 'real')
        real_subs = tuple((k, v) for k, v in subs
                          if isinstance(v, (Number, Tensor))
                          if v.dtype == 'real')
        if not (var_subs or int_subs or real_subs):
            return reflect(Subs, self, lazy_subs)

        # First perform any variable substitutions.
        if var_subs:
            rename = {k: v.name for k, v in var_subs}
            inputs = OrderedDict(
                (rename.get(k, k), d) for k, d in self.inputs.items())
            # A shorter dict means two inputs were renamed onto one key.
            if len(inputs) != len(self.inputs):
                raise ValueError("Variable substitution name conflict")
            var_result = Gaussian(self.loc, self.precision, inputs)
            return Subs(var_result, int_subs + real_subs + lazy_subs)

        # Next perform any integer substitution, i.e. slicing into a batch.
        if int_subs:
            int_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
            real_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype == 'real')
            tensors = [self.loc, self.precision]
            funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors]
            inputs = funsors[0].inputs.copy()
            inputs.update(real_inputs)
            int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
            return Subs(int_result, real_subs + lazy_subs)

        # Try to perform a complete substitution of all real variables, resulting in a Tensor.
        # Only the real-valued constants are used here; lazy substitutions
        # are deferred to the final ``Subs(result, lazy_subs)`` and must not
        # be mixed into the tensors that get aligned and concatenated below.
        real_subs = OrderedDict(real_subs)
        assert real_subs and not int_subs
        if all(k in real_subs for k, d in self.inputs.items()
               if d.dtype == 'real'):
            # Broadcast all component tensors.
            int_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
            tensors = [
                Tensor(self.loc, int_inputs),
                Tensor(self.precision, int_inputs)
            ]
            tensors.extend(real_subs.values())
            inputs, tensors = align_tensors(*tensors)
            # ``loc`` carries one trailing event dimension after the batch.
            batch_dim = tensors[0].dim() - 1
            batch_shape = broadcast_shape(*(x.shape[:batch_dim]
                                            for x in tensors))
            (loc, precision), values = tensors[:2], tensors[2:]

            # Form the concatenated value.
            offsets, event_size = _compute_offsets(self.inputs)
            value = BlockVector(batch_shape + (event_size, ))
            for k, value_k in zip(real_subs, values):
                offset = offsets[k]
                # Flatten the event dimensions of this value into one axis.
                value_k = value_k.reshape(value_k.shape[:batch_dim] + (-1, ))
                # Skip the size check while a torch trace is being recorded.
                if not torch._C._get_tracing_state():
                    assert value_k.size(-1) == self.inputs[k].num_elements
                value_k = value_k.expand(batch_shape + value_k.shape[-1:])
                value[...,
                      offset:offset + self.inputs[k].num_elements] = value_k
            value = value.as_tensor()

            # Evaluate the non-normalized log density.
            result = -0.5 * _vmv(precision, value - loc)
            result = Tensor(result, inputs)
            assert result.output == reals()
            return Subs(result, lazy_subs)

        # Perform a partial substitution of a subset of real variables, resulting in a Joint.
        # See "The Matrix Cookbook" (November 15, 2012) ss. 8.1.3 eq. 353.
        # http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
        raise NotImplementedError(
            'TODO implement partial substitution of real variables')
Esempio n. 21
0
def normalize_fuse_subs(arg, subs):
    """Fuse two nested substitutions into one: ``a(b)(c) -> a(b(c), c)``.

    :param arg: an inner substitution exposing ``arg`` (the substituted-into
        term) and ``subs`` (either an ``OrderedDict`` or a sequence of
        ``(name, value)`` pairs).
    :param subs: the outer ``(name, value)`` pairs.
    :return: ``Subs(arg.arg, fused)`` where ``fused`` is the outer pairs
        followed by the inner pairs with the outer substitution applied to
        each inner value.
    """
    inner_pairs = arg.subs
    if isinstance(inner_pairs, OrderedDict):
        inner_pairs = tuple(inner_pairs.items())

    # Push the outer substitution into every inner value.
    pushed = []
    for name, value in inner_pairs:
        pushed.append((name, Subs(value, subs)))

    fused = subs + tuple(pushed)
    return Subs(arg.arg, fused)
Esempio n. 22
0
    def eager_subs(self, subs):
        """Eagerly substitute ``subs`` into this Gaussian where possible.

        Substitutions are processed in stages; each stage handles one kind
        of value and recurses through ``Subs`` with the remaining pairs:

        1. ``Variable`` values rename inputs.
        2. ``Number``/``Tensor``/``Slice`` values with non-``'real'`` dtype
           slice the batch tensors ``info_vec`` and ``precision``.
        3. ``Number``/``Tensor`` values with ``'real'`` dtype are plugged
           into the density. If every real input is substituted the result
           is a ``Tensor``; otherwise it is a smaller ``Gaussian`` over the
           remaining real inputs plus a ``Tensor`` log-scale term.

        Any other value is treated as lazy and applied afterwards via
        ``Subs``.

        :param tuple subs: ``(name, value)`` pairs. Pairs whose name is not
            in ``self.inputs`` are ignored.
        :return: ``self`` if no pair applies; a reflected ``Subs`` if only
            lazy pairs remain; otherwise the (possibly further substituted)
            result of the first applicable stage.
        :raises ValueError: if renaming maps two inputs to the same name.
        """
        assert isinstance(subs, tuple)
        subs = tuple(
            (k, v if isinstance(v, (Variable, Slice)) else materialize(v))
            for k, v in subs if k in self.inputs)
        if not subs:
            return self

        # Constants and Variables are eagerly substituted;
        # everything else is lazily substituted.
        lazy_subs = tuple(
            (k, v) for k, v in subs
            if not isinstance(v, (Number, Tensor, Variable, Slice)))
        var_subs = tuple((k, v) for k, v in subs if isinstance(v, Variable))
        int_subs = tuple((k, v) for k, v in subs
                         if isinstance(v, (Number, Tensor, Slice))
                         if v.dtype != 'real')
        real_subs = tuple((k, v) for k, v in subs
                          if isinstance(v, (Number, Tensor))
                          if v.dtype == 'real')
        if not (var_subs or int_subs or real_subs):
            return reflect(Subs, self, lazy_subs)

        # First perform any variable substitutions.
        if var_subs:
            rename = {k: v.name for k, v in var_subs}
            inputs = OrderedDict(
                (rename.get(k, k), d) for k, d in self.inputs.items())
            # A shorter dict means two inputs were renamed onto one key.
            if len(inputs) != len(self.inputs):
                raise ValueError("Variable substitution name conflict")
            var_result = Gaussian(self.info_vec, self.precision, inputs)
            return Subs(var_result, int_subs + real_subs + lazy_subs)

        # Next perform any integer substitution, i.e. slicing into a batch.
        if int_subs:
            int_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
            real_inputs = OrderedDict(
                (k, d) for k, d in self.inputs.items() if d.dtype == 'real')
            tensors = [self.info_vec, self.precision]
            funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors]
            inputs = funsors[0].inputs.copy()
            inputs.update(real_inputs)
            int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
            return Subs(int_result, real_subs + lazy_subs)

        # Broadcast all component tensors.
        # Only the real-valued constants are used here; lazy substitutions
        # are deferred to a trailing ``Subs`` and must not be mixed into the
        # tensors that get aligned and concatenated below.
        real_subs = OrderedDict(real_subs)
        assert real_subs and not int_subs
        int_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        tensors = [
            Tensor(self.info_vec, int_inputs),
            Tensor(self.precision, int_inputs)
        ]
        tensors.extend(real_subs.values())
        # From here on ``int_inputs`` is the aligned input dict returned by
        # ``align_tensors`` for all of the tensors above.
        int_inputs, tensors = align_tensors(*tensors)
        # ``info_vec`` carries one trailing event dimension after the batch.
        batch_dim = tensors[0].dim() - 1
        batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
        (info_vec, precision), values = tensors[:2], tensors[2:]
        offsets, event_size = _compute_offsets(self.inputs)
        # One slice per entry of ``offsets`` into the concatenated event axis.
        slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
                  for k, offset in offsets.items()]

        # Expand all substituted values.
        values = OrderedDict(zip(real_subs, values))
        for k, value in values.items():
            # Flatten the event dimensions of this value into one axis.
            value = value.reshape(value.shape[:batch_dim] + (-1, ))
            # Skip the size check while a torch trace is being recorded.
            if not torch._C._get_tracing_state():
                assert value.size(-1) == self.inputs[k].num_elements
            values[k] = value.expand(batch_shape + value.shape[-1:])

        # Try to perform a complete substitution of all real variables, resulting in a Tensor.
        if all(k in real_subs for k, d in self.inputs.items()
               if d.dtype == 'real'):
            # Form the concatenated value.
            value = BlockVector(batch_shape + (event_size, ))
            for k, i in slices:
                if k in values:
                    value[..., i] = values[k]
            value = value.as_tensor()

            # Evaluate the non-normalized log density.
            result = _vv(value, info_vec - 0.5 * _mv(precision, value))

            result = Tensor(result, int_inputs)
            assert result.output == reals()
            return Subs(result, lazy_subs)

        # Perform a partial substitution of a subset of real variables, resulting in a Joint.
        # We split real inputs into two sets: a for the preserved and b for the substituted.
        b = frozenset(real_subs)
        a = frozenset(k for k, d in self.inputs.items()
                      if d.dtype == 'real' and k not in b)
        # Gather the (a, a), (a, b) and (b, b) blocks of the precision matrix.
        prec_aa = torch.cat([
            torch.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in a],
                      dim=-1) for k1, i1 in slices if k1 in a
        ],
                            dim=-2)
        prec_ab = torch.cat([
            torch.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b],
                      dim=-1) for k1, i1 in slices if k1 in a
        ],
                            dim=-2)
        prec_bb = torch.cat([
            torch.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b],
                      dim=-1) for k1, i1 in slices if k1 in b
        ],
                            dim=-2)
        info_a = torch.cat([info_vec[..., i] for k, i in slices if k in a],
                           dim=-1)
        info_b = torch.cat([info_vec[..., i] for k, i in slices if k in b],
                           dim=-1)
        value_b = torch.cat([values[k] for k, i in slices if k in b], dim=-1)
        # Terms linear in the preserved variables stay in the Gaussian; the
        # part depending only on the substituted values becomes a log scale.
        info_vec = info_a - _mv(prec_ab, value_b)
        log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
        precision = prec_aa.expand(info_vec.shape + (-1, ))
        inputs = int_inputs.copy()
        for k, d in self.inputs.items():
            if k not in real_subs:
                inputs[k] = d
        result = Gaussian(info_vec, precision, inputs) + Tensor(
            log_scale, int_inputs)
        # Apply any deferred lazy substitutions, as the complete-substitution
        # path does, instead of silently dropping them.
        return Subs(result, lazy_subs) if lazy_subs else result