Example #1
def _multinomial_infer_value_domain(cls, **kwargs):
    instance = cls.dist_class(**{k: _dummy_tensor(domain) for k, domain in kwargs.items()}, validate_args=False)
    return reals(*instance.event_shape)
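The `_dummy_tensor` helper is not shown in this snippet; below is a hypothetical stand-in, purely an assumption for illustration (funsor's actual helper may differ), whose only job is to produce a tensor of the right shape so that `event_shape` can be read off the instantiated distribution.

import torch

def _dummy_tensor(domain):
    # Hypothetical: fill with a value that satisfies most constraints
    # (positive reals, small counts); only the shape matters here.
    value = 0.5 if domain.dtype == 'real' else 1
    return torch.full(domain.shape, value) if domain.shape else torch.tensor(value)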
Example #2
    def __call__(self):

        # calls pyro.param so that params are exposed and constraints applied
        # should not create any new torch.Tensors after __init__
        self.initialize_params()

        N_state = self.config["sizes"]["state"]

        # initialize gamma to uniform
        gamma = Tensor(
            torch.zeros((N_state, N_state)),
            OrderedDict([("y_prev", bint(N_state))]),
        )

        N_v = self.config["sizes"]["random"]
        N_c = self.config["sizes"]["group"]
        log_prob = []

        plate_g = Tensor(torch.zeros(N_c), OrderedDict([("g", bint(N_c))]))

        # group-level random effects
        if self.config["group"]["random"] == "discrete":
            # group-level discrete effect
            e_g = Variable("e_g", bint(N_v))
            e_g_dist = plate_g + dist.Categorical(**self.params["e_g"])(
                value=e_g)

            log_prob.append(e_g_dist)

            eps_g = (plate_g + self.params["eps_g"]["theta"])(e_g=e_g)

        elif self.config["group"]["random"] == "continuous":
            eps_g = Variable("eps_g", reals(N_state))
            eps_g_dist = plate_g + dist.Normal(**self.params["eps_g"])(
                value=eps_g)

            log_prob.append(eps_g_dist)
        else:
            eps_g = to_funsor(0.)

        N_s = self.config["sizes"]["individual"]

        plate_i = Tensor(torch.zeros(N_s), OrderedDict([("i", bint(N_s))]))
        # individual-level random effects
        if self.config["individual"]["random"] == "discrete":
            # individual-level discrete effect
            e_i = Variable("e_i", bint(N_v))
            e_i_dist = plate_g + plate_i + dist.Categorical(
                **self.params["e_i"])(
                    value=e_i) * self.raggedness_masks["individual"](t=0)

            log_prob.append(e_i_dist)

            eps_i = (plate_i + plate_g +
                     self.params["eps_i"]["theta"](e_i=e_i))

        elif self.config["individual"]["random"] == "continuous":
            eps_i = Variable("eps_i", reals(N_state))
            eps_i_dist = plate_g + plate_i + dist.Normal(
                **self.params["eps_i"])(value=eps_i)

            log_prob.append(eps_i_dist)
        else:
            eps_i = to_funsor(0.)

        # add group-level and individual-level random effects to gamma
        gamma = gamma + eps_g + eps_i

        N_state = self.config["sizes"]["state"]

        # we've accounted for all effects, now actually compute gamma_y
        gamma_y = gamma(y_prev="y(t=1)")

        y = Variable("y", bint(N_state))
        y_dist = plate_g + plate_i + dist.Categorical(
            probs=gamma_y.exp() / gamma_y.exp().sum())(value=y)

        # observation 1: step size
        step_dist = plate_g + plate_i + dist.Gamma(
            **{k: v(y_curr=y)
               for k, v in self.params["step"].items()})(
                   value=self.observations["step"])

        # step size zero-inflation
        if self.config["zeroinflation"]:
            step_zi = dist.Categorical(
                probs=self.params["zi_step"]["zi_param"](y_curr=y))(
                    value="zi_step")
            step_zi_dist = plate_g + plate_i + dist.Delta(
                self.config["MISSING"])(value=self.observations["step"])
            step_dist = (step_zi + Stack("zi_step",
                                         (step_dist, step_zi_dist))).reduce(
                                             ops.logaddexp, "zi_step")

        # observation 2: step angle
        angle_dist = plate_g + plate_i + dist.VonMises(
            **{k: v(y_curr=y)
               for k, v in self.params["angle"].items()})(
                   value=self.observations["angle"])

        # observation 3: dive activity
        omega_dist = plate_g + plate_i + dist.Beta(
            **{k: v(y_curr=y)
               for k, v in self.params["omega"].items()})(
                   value=self.observations["omega"])

        # dive activity zero-inflation
        if self.config["zeroinflation"]:
            omega_zi = dist.Categorical(
                probs=self.params["zi_omega"]["zi_param"](y_curr=y))(
                    value="zi_omega")
            omega_zi_dist = plate_g + plate_i + dist.Delta(
                self.config["MISSING"])(value=self.observations["omega"])
            omega_dist = (omega_zi +
                          Stack("zi_omega",
                                (omega_dist, omega_zi_dist))).reduce(
                                    ops.logaddexp, "zi_omega")

        # finally, construct the term for parallel scan reduction
        hmm_factor = step_dist + angle_dist + omega_dist
        hmm_factor = hmm_factor * self.raggedness_masks["individual"]
        hmm_factor = hmm_factor * self.raggedness_masks["timestep"]
        # copy masking behavior of pyro.infer.TraceEnum_ELBO._compute_model_factors
        hmm_factor = hmm_factor + y_dist
        log_prob.insert(0, hmm_factor)

        return log_prob
Example #3
    def forward(self, observations, add_bias=True):
        obs_dim = 2 * self.num_sensors
        bias_scale = self.log_bias_scale.exp()
        obs_noise = self.log_obs_noise.exp()
        trans_noise = self.log_trans_noise.exp()

        # bias distribution
        bias = Variable('bias', reals(obs_dim))
        assert not torch.isnan(bias_scale), "bias_scale was nan"
        bias_dist = dist_to_funsor(
            dist.MultivariateNormal(
                torch.zeros(obs_dim),
                scale_tril=bias_scale *
                torch.eye(2 * self.num_sensors)))(value=bias)

        init_dist = torch.distributions.MultivariateNormal(torch.zeros(4),
                                                           scale_tril=100. *
                                                           torch.eye(4))
        self.init = dist_to_funsor(init_dist)(value="state")

        # hidden states
        prev = Variable("prev", reals(4))
        curr = Variable("curr", reals(4))
        self.trans_dist = f_dist.MultivariateNormal(
            loc=prev @ NCV_TRANSITION_MATRIX,
            scale_tril=trans_noise * NCV_PROCESS_NOISE.cholesky(),
            value=curr)

        state = Variable('state', reals(4))
        obs = Variable("obs", reals(obs_dim))
        observation_matrix = Tensor(
            torch.eye(4,
                      2).unsqueeze(-1).expand(-1, -1,
                                              self.num_sensors).reshape(4, -1))
        assert observation_matrix.output.shape == (
            4, obs_dim), observation_matrix.output.shape
        obs_loc = state @ observation_matrix
        if add_bias:
            obs_loc += bias
        self.observation_dist = f_dist.MultivariateNormal(
            loc=obs_loc, scale_tril=obs_noise * torch.eye(obs_dim), value=obs)

        logp = bias_dist
        curr = "state_init"
        logp += self.init(state=curr)
        for t, x in enumerate(observations):
            prev, curr = curr, f"state_{t}"
            logp += self.trans_dist(prev=prev, curr=curr)
            logp += self.observation_dist(state=curr, obs=x)
            # marginalize out previous state
            logp = logp.reduce(ops.logaddexp, prev)
        # marginalize out bias variable
        logp = logp.reduce(ops.logaddexp, "bias")

        # save posterior over the final state
        assert set(logp.inputs) == {f'state_{len(observations) - 1}'}
        posterior = funsor_to_mvn(logp, ndims=0)

        # marginalize out remaining variables
        logp = logp.reduce(ops.logaddexp)
        assert isinstance(logp, Tensor) and logp.shape == (), logp.pretty()
        return logp.data, posterior
Example #4
        eval(expr)


@pytest.mark.parametrize("expr", [
    "Variable('x', reals()).log()",
    "Number(1) / Variable('x', reals())",
    "Variable('x', reals()) ** Number(2)",
    "Stack('t', (Number(1), Variable('x', reals()))).reduce(ops.logaddexp, 't')",
])
def test_eager_or_die_error(expr):
    with interpretation(eager_or_die):
        with pytest.raises(NotImplementedError):
            eval(expr)


@pytest.mark.parametrize('domain', [bint(3), reals()])
def test_variable(domain):
    x = Variable('x', domain)
    check_funsor(x, {'x': domain}, domain)
    assert Variable('x', domain) is x
    assert x('x') is x
    y = Variable('y', domain)
    assert x('y') is y
    assert x(x='y') is y
    assert x(x=y) is y
    x4 = Variable('x', bint(4))
    assert x4 is not x
    assert x4('x') is x4
    assert x(y=x4) is x

    xp1 = x + 1.
Example #5
            if sampled_vars:
                assert dict(y.inputs) == dict(expected_inputs), sampled_vars
            else:
                assert y is x


@pytest.mark.parametrize('sample_inputs', [
    (),
    (('s', bint(3)), ),
    (('s', bint(3)), ('t', bint(4))),
],
                         ids=id_from_inputs)
@pytest.mark.parametrize('batch_inputs', [
    (),
    (('b', bint(2)), ),
    (('c', reals()), ),
    (('b', bint(2)), ('c', reals())),
],
                         ids=id_from_inputs)
@pytest.mark.parametrize('event_inputs', [
    (('e', reals()), ),
    (('e', reals()), ('f', reals(2))),
],
                         ids=id_from_inputs)
def test_gaussian_shape(sample_inputs, batch_inputs, event_inputs):
    be_inputs = OrderedDict(batch_inputs + event_inputs)
    expected_inputs = OrderedDict(sample_inputs + batch_inputs + event_inputs)
    sample_inputs = OrderedDict(sample_inputs)
    batch_inputs = OrderedDict(batch_inputs)
    event_inputs = OrderedDict(event_inputs)
    x = random_gaussian(be_inputs)
Example #6
def test_normal_defaults():
    loc = Variable('loc', reals())
    scale = Variable('scale', reals())
    value = Variable('value', reals())
    assert dist.Normal(loc, scale) is dist.Normal(loc, scale, value)
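For context, a hedged sketch of how such a lazily constructed Normal funsor is typically evaluated: substitute a concrete Tensor for the free 'value' variable, mirroring the pattern used in the density tests further below. The import paths are assumed from the funsor version these snippets target.

import torch
import funsor.distributions as dist
from funsor.torch import Tensor

loc = Tensor(torch.tensor(0.))
scale = Tensor(torch.tensor(1.))
d = dist.Normal(loc, scale)                    # 'value' defaults to Variable('value', reals())
log_prob = d(value=Tensor(torch.tensor(0.5)))  # substitute a point to get the log-density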
Example #7
def _rand(batch_shape, *event_shape):
    inputs = OrderedDict(zip("abcdef", map(bint, reversed(batch_shape))))
    return random_tensor(inputs, reals(*event_shape))
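A small usage sketch (an assumption, not from the original file) to spell out the convention this helper implies: because of `reversed(batch_shape)`, batch dimensions are named 'a', 'b', ... starting from the rightmost one, and `event_shape` becomes the real-valued output.

x = _rand((2, 3), 4)
# expected: x.inputs == OrderedDict([('a', bint(3)), ('b', bint(2))])
# and       x.output == reals(4)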
Example #8
def _fill_defaults(v, log_density=0, value='value'):
    v = to_funsor(v)
    log_density = to_funsor(log_density, reals())
    value = to_funsor(value, v.output)
    return v, log_density, value
Example #9
def _fill_defaults(loc, scale, value='value'):
    loc = to_funsor(loc, reals())
    scale = to_funsor(scale, reals())
    value = to_funsor(value, reals())
    return loc, scale, value
Example #10
def _fill_defaults(concentration1, concentration0, value='value'):
    concentration1 = to_funsor(concentration1, reals())
    concentration0 = to_funsor(concentration0, reals())
    value = to_funsor(value, reals())
    return concentration1, concentration0, value
Example #11
def _fill_defaults(total_count, probs, value='value'):
    total_count = to_funsor(total_count, reals())
    probs = to_funsor(probs)
    assert probs.dtype == "real"
    value = to_funsor(value, reals())
    return total_count, probs, value
Example #12
def _fill_defaults(logits, value='value'):
    logits = to_funsor(logits)
    assert logits.dtype == "real"
    value = to_funsor(value, reals())
    return logits, value
Example #13
def _fill_defaults(probs, value='value'):
    probs = to_funsor(probs)
    assert probs.dtype == "real"
    value = to_funsor(value, reals())
    return probs, value
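All of the `_fill_defaults` variants above follow one pattern: coerce raw Python values into funsors with an explicit output domain, and let the observed `value` default to the string 'value', which `to_funsor` turns into a free Variable of that domain. A minimal sketch of that coercion, with import paths assumed:

from funsor.domains import reals
from funsor.terms import Variable, to_funsor

loc = to_funsor(0., reals())         # a plain float becomes a constant funsor
value = to_funsor('value', reals())  # a string becomes a free Variable named 'value'
# value should now behave like Variable('value', reals())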
Example #14
    factors1 = partial_sum_product(sum_op, prod_op, factors, vars1, plates)
    factors2 = partial_sum_product(sum_op, prod_op, factors1, vars2, plates)
    actual = reduce(prod_op, factors2)

    expected = sum_product(sum_op, prod_op, factors, vars1 | vars2, plates)
    assert_close(actual, expected)


@pytest.mark.parametrize('num_steps', [None] + list(range(1, 13)))
@pytest.mark.parametrize('sum_op,prod_op,state_domain', [
    (ops.add, ops.mul, bint(2)),
    (ops.add, ops.mul, bint(3)),
    (ops.logaddexp, ops.add, bint(2)),
    (ops.logaddexp, ops.add, bint(3)),
    (ops.logaddexp, ops.add, reals()),
    (ops.logaddexp, ops.add, reals(2)),
],
                         ids=str)
@pytest.mark.parametrize('batch_inputs', [
    {},
    {
        "foo": bint(5)
    },
    {
        "foo": bint(2),
        "bar": bint(4)
    },
],
                         ids=lambda d: ",".join(d.keys()))
@pytest.mark.parametrize('impl', [
Example #15
def test_beta_density(batch_shape, eager):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

    @funsor.torch.function(reals(), reals(), reals(), reals())
    def beta(concentration1, concentration0, value):
        return torch.distributions.Beta(concentration1,
                                        concentration0).log_prob(value)

    check_funsor(beta, {
        'concentration1': reals(),
        'concentration0': reals(),
        'value': reals()
    }, reals())

    concentration1 = Tensor(torch.randn(batch_shape).exp(), inputs)
    concentration0 = Tensor(torch.randn(batch_shape).exp(), inputs)
    value = Tensor(torch.rand(batch_shape), inputs)
    expected = beta(concentration1, concentration0, value)
    check_funsor(expected, inputs, reals())

    d = Variable('value', reals())
    actual = dist.Beta(concentration1, concentration0, value) if eager else \
        dist.Beta(concentration1, concentration0, d)(value=value)
    check_funsor(actual, inputs, reals())
    assert_close(actual, expected)
Example #16
    def _eager_subs_real(self, subs, remaining_subs):
        # Broadcast all component tensors.
        subs = OrderedDict(subs)
        int_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        tensors = [
            Tensor(self.info_vec, int_inputs),
            Tensor(self.precision, int_inputs)
        ]
        tensors.extend(subs.values())
        int_inputs, tensors = align_tensors(*tensors)
        batch_dim = tensors[0].dim() - 1
        batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
        (info_vec, precision), values = tensors[:2], tensors[2:]
        offsets, event_size = _compute_offsets(self.inputs)
        slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
                  for k, offset in offsets.items()]

        # Expand all substituted values.
        values = OrderedDict(zip(subs, values))
        for k, value in values.items():
            value = value.reshape(value.shape[:batch_dim] + (-1, ))
            if not torch._C._get_tracing_state():
                assert value.size(-1) == self.inputs[k].num_elements
            values[k] = value.expand(batch_shape + value.shape[-1:])

        # Try to perform a complete substitution of all real variables, resulting in a Tensor.
        if all(k in subs for k, d in self.inputs.items() if d.dtype == 'real'):
            # Form the concatenated value.
            value = BlockVector(batch_shape + (event_size, ))
            for k, i in slices:
                if k in values:
                    value[..., i] = values[k]
            value = value.as_tensor()

            # Evaluate the non-normalized log density.
            result = _vv(value, info_vec - 0.5 * _mv(precision, value))

            result = Tensor(result, int_inputs)
            assert result.output == reals()
            return Subs(result, remaining_subs) if remaining_subs else result

        # Perform a partial substitution of a subset of real variables, resulting in a Joint.
        # We split real inputs into two sets: a for the preserved and b for the substituted.
        b = frozenset(k for k, v in subs.items())
        a = frozenset(k for k, d in self.inputs.items()
                      if d.dtype == 'real' and k not in b)
        prec_aa = torch.cat([
            torch.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in a],
                      dim=-1) for k1, i1 in slices if k1 in a
        ],
                            dim=-2)
        prec_ab = torch.cat([
            torch.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b],
                      dim=-1) for k1, i1 in slices if k1 in a
        ],
                            dim=-2)
        prec_bb = torch.cat([
            torch.cat([precision[..., i1, i2] for k2, i2 in slices if k2 in b],
                      dim=-1) for k1, i1 in slices if k1 in b
        ],
                            dim=-2)
        info_a = torch.cat([info_vec[..., i] for k, i in slices if k in a],
                           dim=-1)
        info_b = torch.cat([info_vec[..., i] for k, i in slices if k in b],
                           dim=-1)
        value_b = torch.cat([values[k] for k, i in slices if k in b], dim=-1)
        info_vec = info_a - _mv(prec_ab, value_b)
        log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
        precision = prec_aa.expand(info_vec.shape + (-1, ))
        inputs = int_inputs.copy()
        for k, d in self.inputs.items():
            if k not in subs:
                inputs[k] = d
        result = Gaussian(info_vec, precision, inputs) + Tensor(
            log_scale, int_inputs)
        return Subs(result, remaining_subs) if remaining_subs else result
Example #17
def test_delta_delta():
    v = Variable('v', reals(2))
    point = Tensor(torch.randn(2))
    log_density = Tensor(torch.tensor(0.5))
    d = dist.Delta(point, log_density, v)
    assert d is Delta('v', point, log_density)
Example #18
    def _eager_subs_affine(self, subs, remaining_subs):
        # Extract an affine representation.
        affine = OrderedDict()
        for k, v in subs:
            const, coeffs = extract_affine(v)
            if (isinstance(const, Tensor) and all(
                    isinstance(coeff, Tensor)
                    for coeff, _ in coeffs.values())):
                affine[k] = const, coeffs
            else:
                remaining_subs += (k, v),
        if not affine:
            return reflect(Subs, self, remaining_subs)

        # Align integer dimensions.
        old_int_inputs = OrderedDict(
            (k, v) for k, v in self.inputs.items() if v.dtype != 'real')
        tensors = [
            Tensor(self.info_vec, old_int_inputs),
            Tensor(self.precision, old_int_inputs)
        ]
        for const, coeffs in affine.values():
            tensors.append(const)
            tensors.extend(coeff for coeff, _ in coeffs.values())
        new_int_inputs, tensors = align_tensors(*tensors, expand=True)
        tensors = (Tensor(x, new_int_inputs) for x in tensors)
        old_info_vec = next(tensors).data
        old_precision = next(tensors).data
        for old_k, (const, coeffs) in affine.items():
            const = next(tensors)
            for new_k, (coeff, eqn) in coeffs.items():
                coeff = next(tensors)
                coeffs[new_k] = coeff, eqn
            affine[old_k] = const, coeffs
        batch_shape = old_info_vec.data.shape[:-1]

        # Align real dimensions.
        old_real_inputs = OrderedDict(
            (k, v) for k, v in self.inputs.items() if v.dtype == 'real')
        new_real_inputs = old_real_inputs.copy()
        for old_k, (const, coeffs) in affine.items():
            del new_real_inputs[old_k]
            for new_k, (coeff, eqn) in coeffs.items():
                new_shape = coeff.shape[:len(eqn.split('->')[0].split(',')[1])]
                new_real_inputs[new_k] = reals(*new_shape)
        old_offsets, old_dim = _compute_offsets(old_real_inputs)
        new_offsets, new_dim = _compute_offsets(new_real_inputs)
        new_inputs = new_int_inputs.copy()
        new_inputs.update(new_real_inputs)

        # Construct a blockwise affine representation of the substitution.
        subs_vector = BlockVector(batch_shape + (old_dim, ))
        subs_matrix = BlockMatrix(batch_shape + (new_dim, old_dim))
        for old_k, old_offset in old_offsets.items():
            old_size = old_real_inputs[old_k].num_elements
            old_slice = slice(old_offset, old_offset + old_size)
            if old_k in new_real_inputs:
                new_offset = new_offsets[old_k]
                new_slice = slice(new_offset, new_offset + old_size)
                subs_matrix[..., new_slice, old_slice] = \
                    torch.eye(old_size).expand(batch_shape + (-1, -1))
                continue
            const, coeffs = affine[old_k]
            old_shape = old_real_inputs[old_k].shape
            assert const.data.shape == batch_shape + old_shape
            subs_vector[..., old_slice] = const.data.reshape(batch_shape +
                                                             (old_size, ))
            for new_k, new_offset in new_offsets.items():
                if new_k in coeffs:
                    coeff, eqn = coeffs[new_k]
                    new_size = new_real_inputs[new_k].num_elements
                    new_slice = slice(new_offset, new_offset + new_size)
                    assert coeff.shape == new_real_inputs[
                        new_k].shape + old_shape
                    subs_matrix[..., new_slice, old_slice] = \
                        coeff.data.reshape(batch_shape + (new_size, old_size))
        subs_vector = subs_vector.as_tensor()
        subs_matrix = subs_matrix.as_tensor()
        subs_matrix_t = subs_matrix.transpose(-1, -2)

        # Construct the new funsor. Suppose the old Gaussian funsor g has density
        #   g(x) = < x | i - 1/2 P x>
        # Now define a new funsor f by substituting x = A y + B:
        #   f(y) = g(A y + B)
        #        = < A y + B | i - 1/2 P (A y + B) >
        #        = < y | At (i - P B) - 1/2 At P A y > + < B | i - 1/2 P B >
        #        =: < y | i' - 1/2 P' y > + C
        # where  P' = At P A  and  i' = At (i - P B)  parametrize a new Gaussian
        # and  C = < B | i - 1/2 P B >  parametrizes a new Tensor.
        precision = subs_matrix @ old_precision @ subs_matrix_t
        info_vec = _mv(subs_matrix,
                       old_info_vec - _mv(old_precision, subs_vector))
        const = _vv(subs_vector,
                    old_info_vec - 0.5 * _mv(old_precision, subs_vector))
        result = Gaussian(info_vec, precision, new_inputs) + Tensor(
            const, new_int_inputs)
        return Subs(result, remaining_subs) if remaining_subs else result
Example #19
def test_mvn_defaults():
    loc = Variable('loc', reals(3))
    scale_tril = Variable('scale', reals(3, 3))
    value = Variable('value', reals(3))
    assert dist.MultivariateNormal(loc, scale_tril) is dist.MultivariateNormal(
        loc, scale_tril, value)
Example #20
def test_eager_subs_variable():
    v = Variable('v', reals(3))
    point = Tensor(torch.randn(3))
    d = Delta('foo', v)
    assert d(v=point) is Delta('foo', point)
Example #21
    def unscaled_sample(self, sampled_vars, sample_inputs):
        assert self.output == reals()
        sampled_vars = sampled_vars.intersection(self.inputs)
        if not sampled_vars:
            return self

        # Partition inputs into sample_inputs + batch_inputs + event_inputs.
        sample_inputs = OrderedDict(
            (k, d) for k, d in sample_inputs.items() if k not in self.inputs)
        sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
        batch_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if k not in sampled_vars)
        event_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if k in sampled_vars)
        be_inputs = batch_inputs.copy()
        be_inputs.update(event_inputs)
        sb_inputs = sample_inputs.copy()
        sb_inputs.update(batch_inputs)

        # Sample all variables in a single Categorical call.
        logits = align_tensor(be_inputs, self)
        batch_shape = logits.shape[:len(batch_inputs)]
        flat_logits = logits.reshape(batch_shape + (-1, ))
        sample_shape = tuple(d.dtype for d in sample_inputs.values())
        flat_sample = torch.distributions.Categorical(
            logits=flat_logits).sample(sample_shape)
        assert flat_sample.shape == sample_shape + batch_shape
        results = []
        mod_sample = flat_sample
        for name, domain in reversed(list(event_inputs.items())):
            size = domain.dtype
            point = Tensor(mod_sample % size, sb_inputs, size)
            mod_sample = mod_sample // size  # integer floor division
            results.append(Delta(name, point))

        # Account for the log normalizer factor.
        # Derivation: Let f be a nonnormalized distribution (a funsor), and
        #   consider operations in linear space (source code is in log space).
        #   Let x0 ~ f/|f| be a monte carlo sample from a normalized f/|f|.
        #                              f(x0) / |f|      # dice numerator
        #   Let g = delta(x=x0) |f| -----------------
        #                           detach(f(x0)/|f|)   # dice denominator
        #                       |detach(f)| f(x0)
        #         = delta(x=x0) -----------------  be a dice approximation of f.
        #                         detach(f(x0))
        #   Then g is an unbiased estimator of f in value and all derivatives.
        #   In the special case f = detach(f), we can simplify to
        #       g = delta(x=x0) |f|.
        if flat_logits.requires_grad:
            # Apply a dice factor to preserve differentiability.
            index = [
                torch.arange(n).reshape((n, ) + (1, ) *
                                        (flat_logits.dim() - i - 2))
                for i, n in enumerate(flat_logits.shape[:-1])
            ]
            index.append(flat_sample)
            log_prob = flat_logits[index]
            assert log_prob.shape == flat_sample.shape
            results.append(
                Tensor(
                    flat_logits.detach().logsumexp(-1) +
                    (log_prob - log_prob.detach()), sb_inputs))
        else:
            # This is the special case f = detach(f).
            results.append(Tensor(flat_logits.logsumexp(-1), batch_inputs))

        return reduce(ops.add, results)
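The decoding loop in `unscaled_sample` turns the single flat Categorical sample back into one index per event input by repeated modulo and floor division, walking the event inputs in reverse of the row-major order used by `reshape(batch_shape + (-1,))`. A plain-PyTorch round-trip check of just that arithmetic (independent of funsor):

import torch

sizes = (3, 4)                    # sizes of two event inputs
i1, i2 = torch.tensor(2), torch.tensor(1)
flat = i1 * sizes[1] + i2         # row-major flattening: 2 * 4 + 1 == 9
decoded = []
mod = flat
for size in reversed(sizes):      # last event input is decoded first
    decoded.append(mod % size)
    mod = mod // size
assert decoded[0] == i2 and decoded[1] == i1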
Example #22
def test_eager_subs_ground(log_density):
    point1 = Tensor(torch.randn(3))
    point2 = Tensor(torch.randn(3))
    d = Delta('foo', point1, log_density)
    check_funsor(d(foo=point1), {}, reals(), torch.tensor(float(log_density)))
    check_funsor(d(foo=point2), {}, reals(), torch.tensor(float('-inf')))
Example #23
def test_to_data_error():
    with pytest.raises(ValueError):
        to_data(Variable('x', reals()))
    with pytest.raises(ValueError):
        to_data(Variable('y', bint(12)))
Example #24
def test_transform_log(shape):
    point = Tensor(torch.randn(shape))
    x = Variable('x', reals(*shape))
    actual = Delta('y', point)(y=ops.log(x))
    expected = Delta('x', point.exp(), -point.sum())
    assert_close(actual, expected)
Example #25
    x = Variable('x', reals())
    assert isinstance(x, Variable)

    y = Variable('y', reals())
    assert isinstance(y, Variable)

    result = eval(expr)
    assert isinstance(result, expected_type)
    assert result.is_affine


SUBS_TESTS = [
    ("(t * x)(i=1)", Contraction, {
        "j": bint(3),
        "x": reals()
    }),
    ("(t * x)(i=1, x=y)", Contraction, {
        "j": bint(3),
        "y": reals()
    }),
    ("(t * x + n)(x=y)", Contraction, {
        "y": reals(),
        "i": bint(2),
        "j": bint(3)
    }),
    ("(x + y)(y=z)", Contraction, {
        "x": reals(),
        "z": reals()
    }),
    ("(-x)(x=y+z)", Contraction, {
Example #26
def test_binomial_density(batch_shape, eager):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
    max_count = 10

    @funsor.torch.function(reals(), reals(), reals(), reals())
    def binomial(total_count, probs, value):
        return torch.distributions.Binomial(total_count, probs).log_prob(value)

    check_funsor(binomial, {
        'total_count': reals(),
        'probs': reals(),
        'value': reals()
    }, reals())

    value_data = random_tensor(inputs, bint(max_count)).data.float()
    total_count_data = value_data + random_tensor(
        inputs, bint(max_count)).data.float()
    value = Tensor(value_data, inputs)
    total_count = Tensor(total_count_data, inputs)
    probs = Tensor(torch.rand(batch_shape), inputs)
    expected = binomial(total_count, probs, value)
    check_funsor(expected, inputs, reals())

    m = Variable('value', reals())
    actual = dist.Binomial(total_count, probs, value) if eager else \
        dist.Binomial(total_count, probs, m)(value=value)
    check_funsor(actual, inputs, reals())
    assert_close(actual, expected)
Example #27
def test_categorical_defaults():
    probs = Variable('probs', reals(3))
    value = Variable('value', bint(3))
    assert dist.Categorical(probs) is dist.Categorical(probs, value)
Example #28

@pytest.mark.parametrize('int_inputs', [
    {},
    {
        'i': bint(2)
    },
    {
        'i': bint(2),
        'j': bint(3)
    },
],
                         ids=id_from_inputs)
@pytest.mark.parametrize('real_inputs', [
    {
        'x': reals()
    },
    {
        'x': reals(4)
    },
    {
        'x': reals(2, 3)
    },
    {
        'x': reals(),
        'y': reals()
    },
    {
        'x': reals(2),
        'y': reals(3)
    },
Example #29
def test_affine_subs():
    # This was recorded from test_pyro_convert.
    x = Subs(
        Gaussian(
            torch.tensor([
                1.3027106523513794, 1.4167094230651855, -0.9750942587852478,
                0.5321089029312134, -0.9039931297302246
            ],
                         dtype=torch.float32),  # noqa
            torch.tensor([[
                1.0199567079544067, 0.9840421676635742, -0.473368763923645,
                0.34206756949424744, -0.7562517523765564
            ],
                          [
                              0.9840421676635742, 1.511502742767334,
                              -1.7593903541564941, 0.6647964119911194,
                              -0.5119513273239136
                          ],
                          [
                              -0.4733688533306122, -1.7593903541564941,
                              3.2386727333068848, -0.9345928430557251,
                              -0.1534711718559265
                          ],
                          [
                              0.34206756949424744, 0.6647964119911194,
                              -0.9345928430557251, 0.3141004145145416,
                              -0.12399007380008698
                          ],
                          [
                              -0.7562517523765564, -0.5119513273239136,
                              -0.1534711718559265, -0.12399007380008698,
                              0.6450173854827881
                          ]],
                         dtype=torch.float32),  # noqa
            (
                (
                    'state_1_b6',
                    reals(3, ),
                ),
                (
                    'obs_b2',
                    reals(2, ),
                ),
            )),
        (
            (
                'obs_b2',
                Contraction(
                    ops.nullop,
                    ops.add,
                    frozenset(),
                    (
                        Variable('bias_b5', reals(2, )),
                        Tensor(
                            torch.tensor(
                                [-2.1787893772125244, 0.5684312582015991],
                                dtype=torch.float32),  # noqa
                            (),
                            'real'),
                    )),
            ), ))
    assert isinstance(x, (Gaussian, Contraction)), x.pretty()