Example No. 1 - counterfactual queries mixing poutine.do and poutine.condition
import torch

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

# assert_equal / assert_not_equal are tensor-aware assertion helpers from the
# source test suite; the test itself is parametrized over
# (intervene, observe, flip).


def _item(x):
    # small helper assumed by the test: unwrap a scalar tensor to a Python number
    return x.item() if isinstance(x, torch.Tensor) else x


def test_counterfactual_query(intervene, observe, flip):
    # x -> y -> z -> w

    sites = ["x", "y", "z", "w"]
    observations = {"x": 1., "y": None, "z": 1., "w": 1.}
    interventions = {"x": None, "y": 0., "z": 2., "w": 1.}

    def model():
        x = _item(pyro.sample("x", dist.Normal(0, 1)))
        y = _item(pyro.sample("y", dist.Normal(x, 1)))
        z = _item(pyro.sample("z", dist.Normal(y, 1)))
        w = _item(pyro.sample("w", dist.Normal(z, 1)))
        return dict(x=x, y=y, z=z, w=w)

    if not flip:
        if intervene:
            model = poutine.do(model, data=interventions)
        if observe:
            model = poutine.condition(model, data=observations)
    elif flip and intervene and observe:
        model = poutine.do(poutine.condition(model, data=observations),
                           data=interventions)

    tr = poutine.trace(model).get_trace()
    actual_values = tr.nodes["_RETURN"]["value"]
    for name in sites:
        # case 1: purely observational query like poutine.condition
        if not intervene and observe:
            if observations[name] is not None:
                assert tr.nodes[name]['is_observed']
                assert_equal(observations[name], actual_values[name])
                assert_equal(observations[name], tr.nodes[name]['value'])
            if interventions[name] != observations[name]:
                assert_not_equal(interventions[name], actual_values[name])
        # case 2: purely interventional query like old poutine.do
        elif intervene and not observe:
            assert not tr.nodes[name]['is_observed']
            if interventions[name] is not None:
                assert_equal(interventions[name], actual_values[name])
            assert_not_equal(observations[name], tr.nodes[name]['value'])
            assert_not_equal(interventions[name], tr.nodes[name]['value'])
        # case 3: counterfactual query mixing intervention and observation
        elif intervene and observe:
            if observations[name] is not None:
                assert tr.nodes[name]['is_observed']
                assert_equal(observations[name], tr.nodes[name]['value'])
            if interventions[name] is not None:
                assert_equal(interventions[name], actual_values[name])
            if interventions[name] != observations[name]:
                assert_not_equal(interventions[name], tr.nodes[name]['value'])
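To make the three regimes concrete, here is a minimal sketch on a two-node chain (the model chain and the site names "a" and "b" are invented for illustration):

import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

def chain():
    a = pyro.sample("a", dist.Normal(0., 1.))
    return pyro.sample("b", dist.Normal(a, 1.))

# observational: condition fixes a = 2. and scores it as an observation
observational = poutine.condition(chain, data={"a": torch.tensor(2.)})
# interventional: do feeds a = 2. downstream without scoring the site
interventional = poutine.do(chain, data={"a": torch.tensor(2.)})
# counterfactual-style composition, as in the flip branch of the test above
counterfactual = poutine.do(poutine.condition(chain, data={"b": torch.tensor(1.)}),
                            data={"a": torch.tensor(2.)})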
Example No. 2 - targeted regularization (tl_reg), adapted from Dragonnet
    def tl_reg(self, x, t, y, batch_size=None):
        # targeted regularization term, adapted from https://github.com/claudiashi57/dragonnet

        if not torch._C._get_tracing_state():
            assert x.dim() == 2 and x.size(-1) == self.feature_dim
        dataloader = [x] if batch_size is None else DataLoader(
            x, batch_size=batch_size)
        # NOTE: if batch_size is given, only the last minibatch's predictions
        # survive this loop, so the regularizer covers that batch alone
        for x in dataloader:
            # x = self.whiten(x)
            with pyro.plate("num_particles", 1, dim=-2):
                with poutine.trace() as tr, poutine.block(hide=["y", "t"]):
                    self.guide(x)
                pred_t = poutine.replay(self.model.t_mean, tr.trace)(x)
                # intervene on t so the outcome head is evaluated at the
                # observed treatment
                with poutine.do(data=dict(t=t)):
                    pred_y = poutine.replay(self.model.y_mean, tr.trace)(x)

        pred_t = pred_t.mean(0)  # propensity scores P(t = 1 | x)
        pred_y = pred_y.mean(0)  # continuous outcome or probabilities
        # h = t / pred_t - (1 - t) / (1 - pred_t)
        h = t / pred_t.detach() - (1 - t) / (1 - pred_t.detach())

        if self.config["outcome_dist"] == 'bernoulli':
            y_pert = torch.sigmoid(logit_(p=pred_y) + self.model.epsilon * h)
            t_reg = torch.sum(-y * torch.log(y_pert) -
                              (1 - y) * torch.log(1 - y_pert))
        elif self.config["outcome_dist"] == 'normal':
            y_pert = pred_y + self.model.epsilon * h
            t_reg = torch.sum((y - y_pert)**2)
        return t_reg
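For intuition, a toy evaluation of the normal-outcome branch with invented values: the "clever covariate" h shifts each prediction by the single learned scalar epsilon, and t_reg penalizes the residual of the perturbed prediction.

import torch

t = torch.tensor([1., 0.])         # treatment indicators
pred_t = torch.tensor([0.8, 0.4])  # propensity estimates g(x)
pred_y = torch.tensor([2.0, 1.0])  # outcome-head predictions
y = torch.tensor([2.5, 0.5])       # observed outcomes
epsilon = torch.tensor(0.1)        # stands in for self.model.epsilon

h = t / pred_t - (1 - t) / (1 - pred_t)  # tensor([1.2500, -1.6667])
y_pert = pred_y + epsilon * h            # tensor([2.1250, 0.8333])
t_reg = torch.sum((y - y_pert) ** 2)     # ~0.2517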
Example No. 3 - individual treatment effect via paired do(t=0) / do(t=1) interventions
    def ite(self, x, ym, ys, num_samples=None, batch_size=None):
        r"""
        Computes Individual Treatment Effect for a batch of data ``x``.

        .. math::

            ITE(x) = \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=1) \bigr]
                   - \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=0) \bigr]

        This has complexity ``O(len(x) * num_samples ** 2)``.

        :param ~torch.Tensor x: A batch of data.
        :param ym: Mean used to un-normalize the outcome predictions.
        :param ys: Scale used to un-normalize the outcome predictions.
        :param int num_samples: The number of Monte Carlo samples.
            Defaults to ``self.num_samples`` which defaults to ``100``.
        :param int batch_size: Batch size. Defaults to ``len(x)``.
        :return: A tuple of a ``len(x)``-sized tensor of estimated effects
            and the average treatment effect of the first minibatch.
        :rtype: tuple
        """
        if num_samples is None:
            num_samples = self.num_samples
        if not torch._C._get_tracing_state():
            assert x.dim() == 2 and x.size(-1) == self.feature_dim

        dataloader = [x] if batch_size is None else DataLoader(
            x, batch_size=batch_size)
        print("Evaluating {} minibatches".format(len(dataloader)))
        result_ite = []
        result_ate = []
        for x in dataloader:
            # x = self.whiten(x)
            with pyro.plate("num_particles", num_samples, dim=-2):
                with poutine.trace() as tr, poutine.block(hide=["y", "t"]):
                    self.guide(x)
                with poutine.do(data=dict(t=torch.zeros(()))):
                    y0 = poutine.replay(self.model.y_mean,
                                        tr.trace)(x) * ys + ym
                with poutine.do(data=dict(t=torch.ones(()))):
                    y1 = poutine.replay(self.model.y_mean,
                                        tr.trace)(x) * ys + ym
            ite = (y1 - y0).mean(0)
            ate = ite.mean()
            if not torch._C._get_tracing_state():
                print("batch ate = {:0.6g}".format(ate))
            result_ite.append(ite)
            result_ate.append(ate)
        return torch.cat(result_ite), result_ate[0]  # ATE of the first minibatch only
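A hypothetical call, assuming the outcome was standardized before training with mean y_mean and scale y_std (both names invented here):

# x_test: a (num_rows, feature_dim) tensor of covariates
est_ite, first_ate = model.ite(x_test, ym=y_mean, ys=y_std,
                               num_samples=200, batch_size=512)
# est_ite has shape (num_rows,); first_ate is the ATE of the first
# minibatch only, per the return statement above.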
Example No. 4 - policy risk and eATT (pol_att)
    def pol_att(self, x, y, t, e):
        """Compute policy risk (via policy_val) and an eATT statistic for a batch x."""

        num_samples = self.num_samples
        if not torch._C._get_tracing_state():
            assert x.dim() == 2 and x.size(-1) == self.feature_dim

        dataloader = [x]
        print("Evaluating {} minibatches".format(len(dataloader)))
        result_pol = []
        result_eatt = []
        for x in dataloader:
            # x = self.whiten(x)
            with pyro.plate("num_particles", num_samples, dim=-2):
                with poutine.trace() as tr, poutine.block(hide=["y", "t"]):
                    self.guide(x)
                with poutine.do(data=dict(t=torch.zeros(()))):
                    y0 = poutine.replay(self.model.y_mean, tr.trace)(x)
                with poutine.do(data=dict(t=torch.ones(()))):
                    y1 = poutine.replay(self.model.y_mean, tr.trace)(x)

            ite = (y1 - y0).mean(0)
            ite[t > 0] = -ite[t > 0]  # flip sign for treated units
            # (t + e) > 1 selects units with t == 1 and e == 1
            eatt = torch.abs(torch.mean(ite[(t + e) > 1]))
            pols = []
            for s in range(num_samples):
                pols.append(policy_val(ypred1=y1[s], ypred0=y0[s], y=y, t=t))

            pol = torch.stack(pols).mean(0)

            if not torch._C._get_tracing_state():
                print("batch eATT = {:0.6g}".format(eatt))
                print("batch RPOL = {:0.6g}".format(pol))
            result_pol.append(pol)
            result_eatt.append(eatt)

        return torch.stack(result_pol), torch.stack(result_eatt)
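policy_val is defined elsewhere in the source repository. As a rough sketch of what such a helper typically computes, following the policy-risk metric of Shalit et al. (2017), and explicitly an assumption rather than the repo's actual implementation:

import torch

def policy_val(ypred1, ypred0, y, t):
    # Sketch only: the policy treats a unit iff its predicted effect is
    # positive; its value is estimated from units whose observed treatment
    # agrees with the policy, and the risk is one minus that value.
    policy = (ypred1 - ypred0) > 0
    treated = t > 0
    treat_agree = policy & treated
    control_agree = ~policy & ~treated
    value_t = y[treat_agree].mean() * policy.float().mean() \
        if treat_agree.any() else torch.tensor(0.)
    value_c = y[control_agree].mean() * (~policy).float().mean() \
        if control_agree.any() else torch.tensor(0.)
    return 1. - (value_t + value_c)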
Example No. 5 - interventions propagate to downstream sites
    def test_do_propagation(self):
        pyro.clear_param_store()

        def model():
            z = pyro.sample("z", Normal(10.0 * torch.ones(1), 0.0001 * torch.ones(1)))
            latent_prob = torch.exp(z) / (torch.exp(z) + torch.ones(1))
            flip = pyro.sample("flip", Bernoulli(latent_prob))
            return flip

        sample_from_model = model()
        z_data = {"z": -10.0 * torch.ones(1)}
        # under the model flip = 1 with high probability; intervene on z
        # upstream ("surgery") so that flip = 0 propagates downstream
        sample_from_do_model = poutine.trace(poutine.do(model, data=z_data))()

        assert eq(sample_from_model, torch.ones(1))
        assert eq(sample_from_do_model, torch.zeros(1))
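The do handler also works as a context manager; an equivalent way to draw the intervened sample, reusing model and z_data from the test:

with poutine.do(data=z_data):
    flip_under_do = model()  # z is forced to -10, so flip ~ Bernoulli(~0)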
Example No. 6 - the same propagation test, legacy Pyro API
    def test_do_propagation(self):
        # same test as Example No. 5, but written against the legacy Pyro
        # API: ng_ones/ng_zeros built Variables with requires_grad=False
        pyro.clear_param_store()

        def model():
            z = pyro.sample("z", Normal(10.0 * ng_ones(1), 0.0001 * ng_ones(1)))
            latent_prob = torch.exp(z) / (torch.exp(z) + ng_ones(1))
            flip = pyro.sample("flip", Bernoulli(latent_prob))
            return flip

        sample_from_model = model()
        z_data = {"z": -10.0 * ng_ones(1)}
        # under the model flip = 1 with high probability; intervene on z
        # upstream ("surgery") so that flip = 0 propagates downstream
        sample_from_do_model = poutine.trace(poutine.do(model, data=z_data))()

        assert eq(sample_from_model, ng_ones(1))
        assert eq(sample_from_do_model, ng_zeros(1))
Example No. 7 - do + condition + autoguide smoke test
def test_plate_duplication_smoke():
    # smoke test: do + condition + an autoguide should complete one SVI step
    # without raising duplicate-plate errors
    def model(N):

        with pyro.plate("x_plate", N):
            z1 = pyro.sample(
                "z1", dist.MultivariateNormal(torch.zeros(2), torch.eye(2)))
            z2 = pyro.sample(
                "z2", dist.MultivariateNormal(torch.zeros(2), torch.eye(2)))
            return pyro.sample("x",
                               dist.MultivariateNormal(z1 + z2, torch.eye(2)))

    fix_z1 = torch.tensor([[-6.1258, -6.1524], [-4.1513, -4.3080]])

    obs_x = torch.tensor([[-6.1258, -6.1524], [-4.1513, -4.3080]])

    do_model = poutine.do(model, data={"z1": fix_z1})
    do_model = poutine.condition(do_model, data={"x": obs_x})
    do_auto = pyro.infer.autoguide.AutoMultivariateNormal(do_model)
    optim = pyro.optim.Adam({"lr": 0.05})

    svi = pyro.infer.SVI(do_model, do_auto, optim, pyro.infer.Trace_ELBO())
    svi.step(len(obs_x))
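A quick way to sanity-check the surgered model outside of SVI (a hypothetical addition, not part of the smoke test):

tr = poutine.trace(do_model).get_trace(len(obs_x))
assert tr.nodes["x"]["is_observed"]                 # x is conditioned on obs_x
assert tr.nodes["x"]["value"].shape == obs_x.shape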
Example No. 8 - old poutine.do removes intervened sites from the trace
    def test_do(self):
        data = {"latent2": torch.randn(2)}
        tr3 = poutine.trace(poutine.do(self.model, data=data)).get_trace()
        assert "latent2" not in tr3