def test_sequential_sum_product_bias_2(num_steps, num_sensors, dim):
    """Eliminate ``time`` from a Gaussian chain whose observations use a per-step sensor bias.

    ``bias`` holds one ``dim``-vector per sensor (shape ``(num_sensors, dim)``).
    At each time step the observation factor is evaluated against the bias of
    the single sensor active at that step. The test checks that the combined
    factor has inputs ``{time, bias, x_prev, x_curr}`` and that
    ``sequential_sum_product`` removes only ``time``.
    """
    time = Variable("time", bint(num_steps))
    bias = Variable("bias", reals(num_sensors, dim))
    bias_dist = random_gaussian(OrderedDict([
        ("bias", reals(num_sensors, dim)),
    ]))
    trans = random_gaussian(OrderedDict([
        ("time", bint(num_steps)),
        ("x_prev", reals(dim)),
        ("x_curr", reals(dim)),
    ]))
    obs = random_gaussian(OrderedDict([
        ("time", bint(num_steps)),
        ("x_curr", reals(dim)),
        ("bias", reals(dim)),
    ]))

    # Each time step only a single sensor observes x, and each sensor has a
    # different bias. Cycle through all ``num_sensors`` sensors (previously a
    # hard-coded 2, which ignored the parameter) and size the index dtype to
    # match the number of bias rows it selects from.
    sensor_id = Tensor(torch.arange(num_steps) % num_sensors,
                       OrderedDict(time=bint(num_steps)), dtype=num_sensors)

    # NOTE(review): presumably eager_or_die fails if any term cannot be
    # evaluated eagerly — confirm against funsor.interpreter.
    with interpretation(eager_or_die):
        factor = trans + obs(bias=bias[sensor_id]) + bias_dist
    assert set(factor.inputs) == {"time", "bias", "x_prev", "x_curr"}

    result = sequential_sum_product(ops.logaddexp, ops.add, factor, time,
                                    {"x_prev": "x_curr"})
    assert set(result.inputs) == {"bias", "x_prev", "x_curr"}
def log_prob(self, value):
    """Compute the log probability of ``value``, marginalizing out the hidden ``state``.

    The sample tensor is converted to a funsor with a ``time`` input, and the
    emission factor is added to the transition factor. ``state`` is then
    eliminated along ``time`` with ``sequential_sum_product``. Finally the
    initial factor is added and the remaining state variables are summed out
    with ``logaddexp``.
    """
    if self._validate_args:
        self._validate_sample(value)

    # Rank handed to funsor_to_tensor: the larger of the distribution's batch
    # rank and the sample's non-event rank.
    out_ndims = max(len(self.batch_shape), value.dim() - self.event_dim)
    time_var = Variable("time", Bint[self.event_shape[0]])
    value_funsor = tensor_to_funsor(value, ("time", ),
                                    event_output=self.event_dim - 1,
                                    dtype=self.dtype)

    # Compare with pyro.distributions.hmm.DiscreteHMM.log_prob().
    emission = self._obs(value=value_funsor)
    step_factor = self._trans + emission
    chain = sequential_sum_product(ops.logaddexp, ops.add, step_factor,
                                   time_var, {"state": "state(time=1)"})
    with_init = self._init + chain.reduce(ops.logaddexp, "state(time=1)")
    marginal = with_init.reduce(ops.logaddexp, "state")

    return funsor_to_tensor(marginal, ndims=out_ndims)
def test_mixed_sequential_sum_product(duration, num_segments):
    """Check that mixed_sequential_sum_product matches sequential_sum_product.

    A random discrete transition factor over ``time`` and a single
    ``Px -> x`` step pair (each ``bint(2)``) is eliminated both ways. The two
    results are then compared with ``assert_close``.
    """
    sum_op = ops.logaddexp
    prod_op = ops.add
    time_var = Variable("time", bint(duration))
    step = {"Px": "x"}

    # Input order: time first, then every "previous" name, then every "current" name.
    trans_inputs = OrderedDict()
    trans_inputs[time_var.name] = bint(duration)
    for prev_name in step:
        trans_inputs[prev_name] = bint(2)
    for curr_name in step.values():
        trans_inputs[curr_name] = bint(2)
    trans = random_tensor(trans_inputs)

    expected = sequential_sum_product(sum_op, prod_op, trans, time_var, step)
    actual = mixed_sequential_sum_product(
        sum_op, prod_op, trans, time_var, step, num_segments=num_segments)
    assert_close(actual, expected)
def test_sequential_sum_product_bias_1(num_steps, dim):
    """Eliminate ``time`` from a Gaussian chain whose observations share one bias.

    A single ``bias`` vector of size ``dim`` enters every time step's
    observation factor. The test checks that the combined factor has inputs
    ``{time, bias, x_prev, x_curr}`` and that ``sequential_sum_product``
    removes only ``time``.
    """
    time_var = Variable("time", bint(num_steps))

    bias_spec = OrderedDict([("bias", reals(dim))])
    trans_spec = OrderedDict([
        ("time", bint(num_steps)),
        ("x_prev", reals(dim)),
        ("x_curr", reals(dim)),
    ])
    obs_spec = OrderedDict([
        ("time", bint(num_steps)),
        ("x_curr", reals(dim)),
        ("bias", reals(dim)),
    ])

    # Draw in the order bias, transition, observation.
    bias_dist = random_gaussian(bias_spec)
    trans = random_gaussian(trans_spec)
    obs = random_gaussian(obs_spec)

    factor = trans + obs + bias_dist
    assert set(factor.inputs) == {"time", "bias", "x_prev", "x_curr"}

    marginal = sequential_sum_product(ops.logaddexp, ops.add, factor, time_var,
                                      {"x_prev": "x_curr"})
    assert set(marginal.inputs) == {"bias", "x_prev", "x_curr"}