def test_counterfactual_query(intervene, observe, flip):
    """Check how ``poutine.do`` and ``poutine.condition`` compose on a chain model.

    The model is the chain x -> y -> z -> w. Depending on the flags, the model is
    wrapped with ``poutine.do`` (``interventions``), ``poutine.condition``
    (``observations``), or both. By default ``condition`` is the outer handler;
    when ``flip`` is set and both handlers are requested, ``do`` is the outer one.
    With at most one handler requested, ``flip`` has no nesting to reverse, so the
    model is wrapped the same way as when ``flip`` is False.
    """
    # x -> y -> z -> w
    sites = ["x", "y", "z", "w"]
    observations = {"x": 1., "y": None, "z": 1., "w": 1.}
    interventions = {"x": None, "y": 0., "z": 2., "w": 1.}

    def model():
        x = _item(pyro.sample("x", dist.Normal(0, 1)))
        y = _item(pyro.sample("y", dist.Normal(x, 1)))
        z = _item(pyro.sample("z", dist.Normal(y, 1)))
        w = _item(pyro.sample("w", dist.Normal(z, 1)))
        return dict(x=x, y=y, z=z, w=w)

    if flip and intervene and observe:
        # do() outermost: intervene on the already-conditioned model.
        model = poutine.do(poutine.condition(model, data=observations),
                           data=interventions)
    else:
        # condition() outermost. Previously, flip=True with only one handler left
        # the model unwrapped, so the case-1/case-2 assertions below could not hold.
        if intervene:
            model = poutine.do(model, data=interventions)
        if observe:
            model = poutine.condition(model, data=observations)

    tr = poutine.trace(model).get_trace()
    # Values the model body actually computed with (post-intervention), as
    # opposed to the values recorded at each trace site.
    actual_values = tr.nodes["_RETURN"]["value"]
    for name in sites:
        node = tr.nodes[name]
        # case 1: purely observational query like poutine.condition
        if not intervene and observe:
            if observations[name] is not None:
                assert node['is_observed']
                assert_equal(observations[name], actual_values[name])
                assert_equal(observations[name], node['value'])
            if interventions[name] != observations[name]:
                assert_not_equal(interventions[name], actual_values[name])
        # case 2: purely interventional query like old poutine.do
        elif intervene and not observe:
            assert not node['is_observed']
            if interventions[name] is not None:
                assert_equal(interventions[name], actual_values[name])
            assert_not_equal(observations[name], node['value'])
            assert_not_equal(interventions[name], node['value'])
        # case 3: counterfactual query mixing intervention and observation
        elif intervene and observe:
            if observations[name] is not None:
                assert node['is_observed']
                assert_equal(observations[name], node['value'])
            if interventions[name] is not None:
                assert_equal(interventions[name], actual_values[name])
            if interventions[name] != observations[name]:
                assert_not_equal(interventions[name], node['value'])
def tl_reg(self, x, t, y, batch_size=None):
    """Compute a targeted-regularization loss over ``x``.

    Adapted from https://github.com/claudiashi57/dragonnet.

    For each minibatch, propensities ``pred_t`` and outcomes ``pred_y`` (under
    ``do(t=t)``) are predicted from a single guide sample. The outcome prediction
    is then perturbed by ``self.model.epsilon * h``, with
    ``h = t / pred_t - (1 - t) / (1 - pred_t)``. The perturbed prediction is
    scored against ``y``:

    * ``"bernoulli"``: binary cross-entropy of ``sigmoid(logit_(pred_y) + eps * h)``;
    * ``"normal"``: sum of squared errors of ``pred_y + eps * h``.

    :param ~torch.Tensor x: 2-D covariates of shape ``(N, self.feature_dim)``.
    :param ~torch.Tensor t: Treatments, one per row of ``x``.
    :param ~torch.Tensor y: Outcomes, one per row of ``x``.
    :param int batch_size: Minibatch size; ``None`` processes ``x`` in one batch.
    :return: The loss summed over all rows of ``x``.
    :raises ValueError: If ``self.config["outcome_dist"]`` is not supported, or
        ``x`` is empty while ``batch_size`` is given.
    """
    if not torch._C._get_tracing_state():
        assert x.dim() == 2 and x.size(-1) == self.feature_dim

    outcome_dist = self.config["outcome_dist"]
    if outcome_dist not in ('bernoulli', 'normal'):
        raise ValueError("unsupported outcome_dist: {!r}".format(outcome_dist))

    # Slice x, t and y together so every minibatch sees matching rows. The old
    # code batched only x while using the full t and y.
    # NOTE(review): assumes t and y are indexed along dim 0 like x -- confirm.
    if batch_size is None:
        batches = [(x, t, y)]
    else:
        batches = [(x[i:i + batch_size], t[i:i + batch_size], y[i:i + batch_size])
                   for i in range(0, len(x), batch_size)]

    t_reg = None
    for x_b, t_b, y_b in batches:
        with pyro.plate("num_particles", 1, dim=-2):
            with poutine.trace() as tr, poutine.block(hide=["y", "t"]):
                self.guide(x_b)
            pred_t = poutine.replay(self.model.t_mean, tr.trace)(x_b)
            with poutine.do(data=dict(t=t_b)):
                pred_y = poutine.replay(self.model.y_mean, tr.trace)(x_b)
        pred_t = pred_t.mean(0)  # probabilities
        pred_y = pred_y.mean(0)  # continuous outcome or probabilities
        # Detached propensities: this loss does not backpropagate into pred_t via h.
        h = t_b / pred_t.detach() - (1 - t_b) / (1 - pred_t.detach())
        if outcome_dist == 'bernoulli':
            y_pert = torch.sigmoid(logit_(p=pred_y) + self.model.epsilon * h)
            batch_reg = torch.sum(-y_b * torch.log(y_pert)
                                  - (1 - y_b) * torch.log(1 - y_pert))
        else:  # 'normal'
            y_pert = pred_y + self.model.epsilon * h
            batch_reg = torch.sum((y_b - y_pert) ** 2)
        # Accumulate across minibatches. Previously only the last batch was returned.
        t_reg = batch_reg if t_reg is None else t_reg + batch_reg

    if t_reg is None:
        raise ValueError("x must contain at least one row")
    return t_reg
def ite(self, x, ym, ys, num_samples=None, batch_size=None):
    r"""
    Computes Individual Treatment Effect for a batch of data ``x``.

    .. math::

        ITE(x) = \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=1) \bigr]
                 - \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=0) \bigr]

    Outcome predictions are mapped back to the original scale as
    ``y * ys + ym`` before differencing.

    :param ~torch.Tensor x: A 2-D batch of data of shape ``(N, self.feature_dim)``.
    :param ym: Offset added to the model's outcome predictions (presumably the
        outcome mean used for standardization -- confirm against caller).
    :param ys: Scale applied to the model's outcome predictions (presumably the
        outcome standard deviation -- confirm against caller).
    :param int num_samples: The number of monte carlo samples.
        Defaults to ``self.num_samples``.
    :param int batch_size: Batch size. Defaults to ``len(x)``.
    :return: A tuple ``(ite, ate)``. ``ite`` holds the per-unit effects
        concatenated over all minibatches. ``ate`` is their mean, i.e. the
        average effect over all of ``x`` rather than over the first minibatch
        only.
    :rtype: tuple
    """
    if num_samples is None:
        num_samples = self.num_samples
    if not torch._C._get_tracing_state():
        assert x.dim() == 2 and x.size(-1) == self.feature_dim

    dataloader = [x] if batch_size is None else DataLoader(x, batch_size=batch_size)
    print("Evaluating {} minibatches".format(len(dataloader)))
    result_ite = []
    for x in dataloader:
        with pyro.plate("num_particles", num_samples, dim=-2):
            # Draw latents from the guide once, then replay them under both
            # treatment interventions so y0 and y1 share the same latent samples.
            with poutine.trace() as tr, poutine.block(hide=["y", "t"]):
                self.guide(x)
            with poutine.do(data=dict(t=torch.zeros(()))):
                y0 = poutine.replay(self.model.y_mean, tr.trace)(x) * ys + ym
            with poutine.do(data=dict(t=torch.ones(()))):
                y1 = poutine.replay(self.model.y_mean, tr.trace)(x) * ys + ym
        # Average over dim 0 (presumably the num_particles plate dim -- confirm).
        ite = (y1 - y0).mean(0)
        if not torch._C._get_tracing_state():
            print("batch ate = {:0.6g}".format(ite.mean()))
        result_ite.append(ite)

    all_ite = torch.cat(result_ite)
    # The ATE is over every unit in x. The previous code returned only the
    # first minibatch's mean.
    return all_ite, all_ite.mean()
def pol_att(self, x, y, t, e):
    """Report policy value (RPOL) and eATT for ``x`` from posterior samples.

    One guide sample set of size ``self.num_samples`` is drawn. It is replayed
    through ``self.model.y_mean`` under ``do(t=0)`` and ``do(t=1)``. From those
    predictions:

    * eATT is ``|mean(effect)|`` over units with ``t + e > 1``. ``effect`` is the
      sample-averaged ``y1 - y0`` with its sign flipped wherever ``t > 0``.
    * RPOL is ``policy_val`` evaluated per sample and averaged over samples.

    ``x`` is always processed as a single batch.

    :return: ``(stacked RPOL values, stacked eATT values)``, one entry per batch.
    """
    n_particles = self.num_samples
    if not torch._C._get_tracing_state():
        assert x.dim() == 2 and x.size(-1) == self.feature_dim

    batches = [x]
    print("Evaluating {} minibatches".format(len(batches)))
    policy_values, eatt_values = [], []
    for batch in batches:
        with pyro.plate("num_particles", n_particles, dim=-2):
            with poutine.trace() as guide_tr, poutine.block(hide=["y", "t"]):
                self.guide(batch)
            with poutine.do(data=dict(t=torch.zeros(()))):
                y_ctrl = poutine.replay(self.model.y_mean, guide_tr.trace)(batch)
            with poutine.do(data=dict(t=torch.ones(()))):
                y_treat = poutine.replay(self.model.y_mean, guide_tr.trace)(batch)

        effect = (y_treat - y_ctrl).mean(0)
        # NOTE(review): the sign of the predicted effect is flipped for units with
        # t > 0 before averaging over t + e > 1 -- confirm this matches the
        # intended ATT-error definition.
        treated = t > 0
        effect[treated] = -effect[treated]
        eatt = torch.abs(torch.mean(effect[(t + e) > 1]))

        per_sample = [
            policy_val(ypred1=y_treat[s], ypred0=y_ctrl[s], y=y, t=t)
            for s in range(n_particles)
        ]
        pol = torch.stack(per_sample).mean(0)

        if not torch._C._get_tracing_state():
            print("batch eATT = {:0.6g}".format(eatt))
            print("batch RPOL = {:0.6g}".format(pol))
        policy_values.append(pol)
        eatt_values.append(eatt)
    return torch.stack(policy_values), torch.stack(eatt_values)
def test_do_propagation(self):
    """An intervention on an upstream latent should change a downstream sample.

    ``z`` is drawn from Normal(10, 1e-4), so ``flip ~ Bernoulli(sigmoid(z))`` is 1
    with high probability under the plain model. Forcing ``z = -10`` with
    ``poutine.do`` should make ``flip`` 0 with high probability.
    """
    pyro.clear_param_store()

    def model():
        z = pyro.sample("z", Normal(10.0 * torch.ones(1), 0.0001 * torch.ones(1)))
        # torch.sigmoid(z) equals exp(z) / (exp(z) + 1) but does not overflow
        # to inf/inf = nan for large z.
        latent_prob = torch.sigmoid(z)
        flip = pyro.sample("flip", Bernoulli(latent_prob))
        return flip

    sample_from_model = model()
    z_data = {"z": -10.0 * torch.ones(1)}
    # under model flip = 1 with high probability; so do indirect DO surgery to make flip = 0
    sample_from_do_model = poutine.trace(poutine.do(model, data=z_data))()

    assert eq(sample_from_model, torch.ones(1))
    assert eq(sample_from_do_model, torch.zeros(1))
def test_do_propagation(self):
    """Intervening on ``z`` should flip the downstream ``flip`` sample from 1 to 0.

    ``z`` is nearly fixed at 10, so ``flip`` is 1 with high probability under the
    plain model. The do-model pins ``z`` to -10, which should drive ``flip`` to 0.
    """
    pyro.clear_param_store()

    def coin_model():
        loc = 10.0 * ng_ones(1)
        scale = 0.0001 * ng_ones(1)
        z = pyro.sample("z", Normal(loc, scale))
        exp_z = torch.exp(z)
        heads_prob = exp_z / (exp_z + ng_ones(1))
        return pyro.sample("flip", Bernoulli(heads_prob))

    plain_draw = coin_model()
    forced_z = {"z": -10.0 * ng_ones(1)}
    intervened_draw = poutine.trace(poutine.do(coin_model, data=forced_z))()

    assert eq(plain_draw, ng_ones(1))
    assert eq(intervened_draw, ng_zeros(1))
def test_plate_duplication_smoke():
    """Smoke test: take one SVI step on a plated model that is intervened on and conditioned.

    ``z1`` is fixed with ``poutine.do`` and ``x`` is observed with
    ``poutine.condition``. The test only checks that building an
    ``AutoMultivariateNormal`` guide and stepping ``Trace_ELBO`` does not raise.
    """
    def model(N):
        prior_loc, identity = torch.zeros(2), torch.eye(2)
        with pyro.plate("x_plate", N):
            z1 = pyro.sample("z1", dist.MultivariateNormal(prior_loc, identity))
            z2 = pyro.sample("z2", dist.MultivariateNormal(prior_loc, identity))
            return pyro.sample("x", dist.MultivariateNormal(z1 + z2, identity))

    rows = [[-6.1258, -6.1524], [-4.1513, -4.3080]]
    fixed_z1 = torch.tensor(rows)
    observed_x = torch.tensor(rows)

    intervened = poutine.do(model, data={"z1": fixed_z1})
    conditioned = poutine.condition(intervened, data={"x": observed_x})
    guide = pyro.infer.autoguide.AutoMultivariateNormal(conditioned)
    svi = pyro.infer.SVI(conditioned, guide, pyro.optim.Adam({"lr": 0.05}),
                         pyro.infer.Trace_ELBO())
    svi.step(len(observed_x))
def test_do(self):
    """Check that the intervened site ``latent2`` is absent from the do-model's trace."""
    intervention = {"latent2": torch.randn(2)}
    do_model = poutine.do(self.model, data=intervention)
    trace = poutine.trace(do_model).get_trace()
    assert "latent2" not in trace