Exemplo n.º 1
0
    def condition_causal(self,
                         their_tensors,
                         absent_tensors,
                         movie_inds,
                         num_samples=1000):
        """Estimate the difference in summed ``y`` between two conditionings.

        The module-level ``model`` (defined outside this snippet) is
        conditioned on ``x = their_tensors`` and on ``x = absent_tensors``.
        Each conditioned model is called ``num_samples`` times with
        ``self.step2_params``; for every call the ``'y'`` entry of the result
        is indexed with ``movie_inds`` and summed to a Python float.

        Returns:
            Mean of the "their" totals minus mean of the "absent" totals.
        """

        def sampled_totals(conditioned):
            # One model call per draw; keep only the summed 'y' at movie_inds.
            return [
                torch.sum(
                    conditioned(self.step2_params)['y'][movie_inds]).item()
                for _ in range(num_samples)
            ]

        present_model = pyro.condition(model, data={"x": their_tensors})
        missing_model = pyro.condition(model, data={"x": absent_tensors})

        # Draw order matches the two separate loops: "their" first, then "absent".
        present_totals = sampled_totals(present_model)
        missing_totals = sampled_totals(missing_model)

        return np.mean(present_totals) - np.mean(missing_totals)
Exemplo n.º 2
0
    def loss_fn(design, num_particles, evaluation=False, **kwargs):
        """Build loss terms from a model trace and a nested guide trace.

        ``num_particles`` is unpacked as ``(N, M)``: the design is expanded
        with ``N`` for the first model run, then the sampled observations and
        the design are expanded again with ``M`` for the guide run.
        ``model``, ``guide``, ``observation_labels``, ``target_labels``,
        ``lexpand`` and ``_safe_mean_terms`` come from the surrounding scope
        (not visible here).  Extra keyword arguments are accepted and unused.

        Returns:
            Whatever ``_safe_mean_terms`` returns for the final terms.
        """
        outer_n, inner_m = num_particles

        def total_log_prob(tr, labels):
            # Sum of the per-site log-probabilities recorded in ``tr``.
            return sum(tr.nodes[name]["log_prob"] for name in labels)

        # Sample from p(y, theta | d)
        outer_design = lexpand(design, outer_n)
        prior_trace = poutine.trace(model).get_trace(outer_design)
        observed = {
            name: lexpand(prior_trace.nodes[name]["value"], inner_m)
            for name in observation_labels
        }

        # Sample M times from q(theta | y, d) for each y
        inner_design = lexpand(outer_design, inner_m)
        q_trace = poutine.trace(pyro.condition(guide, data=observed)).get_trace(
            observed, inner_design, observation_labels, target_labels)
        joint_values = {
            name: q_trace.nodes[name]["value"] for name in target_labels
        }
        joint_values.update(observed)
        q_trace.compute_log_prob()

        # Re-run the sampled values through the model to compute the joint
        p_trace = poutine.trace(
            pyro.condition(model, data=joint_values)).get_trace(inner_design)
        p_trace.compute_log_prob()

        # Accumulate in the same order and in place, as the original did.
        terms = -total_log_prob(q_trace, target_labels)
        terms += total_log_prob(p_trace, target_labels)
        terms += total_log_prob(p_trace, observation_labels)
        terms = -terms.logsumexp(0) + math.log(inner_m)

        # At eval time, add p(y | theta, d) terms
        if evaluation:
            prior_trace.compute_log_prob()
            terms += total_log_prob(prior_trace, observation_labels)

        return _safe_mean_terms(terms)
Exemplo n.º 3
0
    def loss_fn(design, num_particles, evaluation=False, **kwargs):
        """Build loss terms from a marginal guide and a likelihood guide.

        The model is traced once on the expanded design; its observation
        values condition both ``marginal_guide`` and ``likelihood_guide``,
        and its target values are passed to the likelihood guide.
        ``model``, the two guides, ``observation_labels``, ``target_labels``,
        ``lexpand`` and ``_safe_mean_terms`` come from the surrounding scope
        (not visible here).  Extra keyword arguments are accepted and unused.

        Returns:
            Whatever ``_safe_mean_terms`` returns for the final terms.
        """

        def observed_log_prob(tr):
            # Sum of log-probabilities at the observation sites of ``tr``.
            return sum(tr.nodes[name]["log_prob"] for name in observation_labels)

        batched_design = lexpand(design, num_particles)

        # Sample from p(y | d)
        prior_trace = poutine.trace(model).get_trace(batched_design)
        observed = {
            name: prior_trace.nodes[name]["value"]
            for name in observation_labels
        }
        latents = {
            name: prior_trace.nodes[name]["value"] for name in target_labels
        }

        # Run through q(y | d)
        q_marginal = poutine.trace(
            pyro.condition(marginal_guide, data=observed)).get_trace(
                batched_design, observation_labels, target_labels)
        q_marginal.compute_log_prob()

        # Run through q(y | theta, d)
        q_likelihood = poutine.trace(
            pyro.condition(likelihood_guide, data=observed)).get_trace(
                latents, batched_design, observation_labels, target_labels)
        q_likelihood.compute_log_prob()

        terms = -observed_log_prob(q_marginal)

        # At evaluation time, use the right estimator, q(y | theta, d) - q(y | d)
        # At training time, use -q(y | theta, d) - q(y | d) so gradients go the same way
        if not evaluation:
            terms -= observed_log_prob(q_likelihood)
        else:
            terms += observed_log_prob(q_likelihood)

        return _safe_mean_terms(terms)
Exemplo n.º 4
0
 def svi_model(self, x, thickness, intensity):
     """Run ``self.model`` with ``x``, ``thickness`` and ``intensity`` observed.

     The conditioned model is called inside a ``pyro.plate`` named
     ``'observations'`` whose size is ``x.shape[0]``.  Nothing is returned.
     """
     observed_sites = {
         'x': x,
         'thickness': thickness,
         'intensity': intensity
     }
     conditioned = pyro.condition(self.model, data=observed_sites)
     with pyro.plate('observations', x.shape[0]):
         conditioned()
Exemplo n.º 5
0
 def svi_model(self, x, age, sex, ventricle_volume, brain_volume):
     """Run ``self.model`` with all five arguments fixed as observed sites.

     The conditioned model is called inside a ``pyro.plate`` named
     ``'observations'`` whose size is ``x.shape[0]``.  Nothing is returned.
     """
     observed_sites = {
         'x': x,
         'sex': sex,
         'age': age,
         'ventricle_volume': ventricle_volume,
         'brain_volume': brain_volume
     }
     conditioned = pyro.condition(self.model, data=observed_sites)
     with pyro.plate('observations', x.shape[0]):
         conditioned()
def svi_test():
    """Fit ``lawn_guide`` to the conditioned ``lawn`` model with SVI.

    ``lawn`` is conditioned on ``my_lawn = 1`` and ``neighbor_lawn = 0``,
    then 1000 Adam steps are taken.  Every 100th step the current values of
    the three named parameters are printed.  ``lawn``, ``lawn_guide``,
    ``Adam``, ``SVI`` and ``Trace_ELBO`` are defined outside this snippet.
    """
    # Values passed through to the model and guide on every step.
    prior_probs = (
        torch.tensor(.3),  # rain
        torch.tensor(.6),  # my sprinkler
        torch.tensor(.2),  # neighbor sprinkler
    )
    observed_lawn = pyro.condition(lawn,
                                   data={
                                       "my_lawn": torch.tensor([1.]),
                                       "neighbor_lawn": torch.tensor([0.])
                                   })

    # Optimizer and inference algorithm.
    svi = SVI(observed_lawn,
              lawn_guide,
              Adam({"lr": 0.005, "betas": (0.90, 0.999)}),
              loss=Trace_ELBO())

    tracked_params = ('rain_prob', 'my_sprinkler_prob',
                      'neighbor_sprinkler_prob')
    for step in range(1000):
        svi.step(*prior_probs)
        if step % 100 != 0:
            continue
        print("step: ", step)
        for name in tracked_params:
            print(name, ": ", pyro.param(name).item())
def importance_empirical_test():
    """Print empirical log-probabilities after importance sampling ``lawn``.

    ``lawn`` (defined outside this snippet) is conditioned on ``wet = 1``,
    100 importance samples are drawn, and the empirical marginal over the
    ``"rain"`` and ``"sprinkler"`` sites is queried at ``[1.]`` and ``[0.]``.
    """
    observed_lawn = pyro.condition(lawn, data={"wet": torch.tensor([1.])})
    sampler = pyro.infer.Importance(observed_lawn, num_samples=100)
    marginal = pyro.infer.EmpiricalMarginal(sampler.run(),
                                            sites=["rain", "sprinkler"])
    for outcome in (1., 0.):
        print(marginal.log_prob(torch.tensor([outcome])))
Exemplo n.º 8
0
 def predict_cond(self, batch, par=None, **kwargs):
     """Run ``self.predict`` with sample sites fixed to the given parameters.

     Args:
         batch: Data passed to ``self.predict`` and to the parameter source.
         par: Either the string ``"real"``, in which case
             ``self.get_real_par(batch)`` supplies the site values, or a
             callable that returns them when called with ``batch``.  Any
             other value (including the default ``None``) is called as a
             function and therefore raises ``TypeError``.
         **kwargs: Forwarded to ``self.predict``.

     Returns:
         The result of the conditioned ``self.predict(batch, **kwargs)``.
     """
     # Compare by value: ``par is "real"`` tests object identity, which only
     # works by accident of string interning and warns on Python >= 3.8.
     if isinstance(par, str) and par == "real":
         par = self.get_real_par(batch)
     else:
         par = par(batch)
     return pyro.condition(lambda batch: self.predict(batch, **kwargs),
                           data=par)(batch)
Exemplo n.º 9
0
    def likelihood(self, batch, par=None):
        """Run ``self.forward`` in ``"likelihood"`` mode with sites fixed.

        Args:
            batch: Data passed to ``self.forward``.
            par: Site values to condition on; when ``None`` they are taken
                from ``self.get_real_par(batch)``.

        Returns:
            The result of the conditioned forward call.
        """

        def run_likelihood(inputs):
            return self.forward(inputs, mode="likelihood")

        site_values = self.get_real_par(batch) if par is None else par
        conditioned = pyro.condition(run_likelihood, data=site_values)
        return conditioned(batch)
Exemplo n.º 10
0
    def step2_train(self, x_data, y_data, params, num_samples=1000):
        """Run the second SVI stage with both ``x`` and ``y`` observed.

        The parameter store is cleared, the module-level ``model`` is
        conditioned on ``x_data`` and ``y_data``, and ``self.step2_iters``
        SVI steps are taken with ``step2_guide`` and ``self.step2_opt``.
        Progress is printed every 100 steps.

        Args:
            x_data: Values for site ``"x"``.
            y_data: Values for site ``"y"``.
            params: Mapping passed to every ``svi.step`` call.
            num_samples: Forwarded to the ``SVI`` constructor.

        Returns:
            A copy of ``params`` overlaid with the detached values of every
            entry in the Pyro parameter store.
        """
        print("Training Bayesian regression parameters...")
        pyro.clear_param_store()

        # Regression model with inputs and targets both fixed.
        observed = {"x": x_data, "y": y_data}
        svi = SVI(pyro.condition(model, data=observed),
                  step2_guide,
                  self.step2_opt,
                  loss=Trace_ELBO(),
                  num_samples=num_samples)

        for step in range(self.step2_iters):
            loss = svi.step(params)
            if step % 100:
                continue
            print("[iteration %04d] loss: %.4f" %
                  (step + 1, loss / len(x_data)))

        # Start from the caller's values, then overwrite with learned ones.
        updated_params = dict(params.items())
        for name, value in pyro.get_param_store().items():
            print("Updating value of hypermeter: {}".format(name))
            updated_params[name] = value.detach()

        print("Training complete.")
        return updated_params
Exemplo n.º 11
0
    def step1_train(self, x_data, params):
        """Run the first SVI stage with only ``x`` observed.

        The module-level ``model`` is conditioned on ``x_data`` and
        ``self.step1_iters`` SVI steps are taken with ``step1_guide`` and
        ``self.step1_opt``.  The parameter store is cleared just before the
        first step, and progress is printed every 100 steps.

        Args:
            x_data: Values for site ``"x"``.
            params: Mapping passed to every ``svi.step`` call.

        Returns:
            A copy of ``params`` overlaid with the detached values of every
            entry in the Pyro parameter store.
        """
        svi = SVI(pyro.condition(model, data={"x": x_data}),
                  step1_guide,
                  self.step1_opt,
                  loss=Trace_ELBO())

        print("\n Training Z marginal and W parameter marginal...")

        # Gradient steps, starting from an empty parameter store.
        pyro.get_param_store().clear()
        for step in range(self.step1_iters):
            loss = svi.step(params)
            if step % 100:
                continue
            print("[iteration %04d] loss: %.4f" %
                  (step + 1, loss / len(x_data)))

        # Collect the learned variational parameters on top of the inputs.
        updated_params = dict(params.items())
        for name, value in pyro.get_param_store().items():
            print("Updating value of hypermeter{}".format(name))
            updated_params[name] = value.detach()

        return updated_params
Exemplo n.º 12
0
 def test_undo_uncondition(self):
     """Conditioning an unconditioned model fixes ``"obs"`` to the given value."""
     stripped_model = poutine.uncondition(self.model)
     observed = {"obs": torch.ones(2)}
     restored_model = pyro.condition(stripped_model, observed)
     restored_trace = poutine.trace(restored_model).get_trace()
     recorded = restored_trace.nodes["obs"]["value"]
     assert_equal(recorded, torch.ones(2))
Exemplo n.º 13
0
    def get_logprobs(self, **obs):
        """Return mean log-probability and nats-per-dimension per observed site.

        ``obs`` must contain exactly the keys ``x``, ``thickness`` and
        ``intensity`` (checked with ``assert``).  ``self.pyro_model.sample``
        is conditioned on them and traced with ``obs['x'].shape[0]`` as its
        argument.

        For each observed sample site, the mean of its ``log_prob`` is
        stored, along with the negated mean divided by the number of value
        dimensions not covered by the ``log_prob`` shape (1 if none).

        Returns:
            Tuple ``(log_probs, nats_per_dim)`` of dicts keyed by site name.

        Raises:
            ValueError: If ``self.hparams.validate`` is set and a
                nats-per-dimension value is NaN.
        """
        assert set(obs) == {'x', 'thickness', 'intensity'}

        conditioned = pyro.condition(self.pyro_model.sample, data=obs)
        batch_size = obs['x'].shape[0]
        model_trace = pyro.poutine.trace(conditioned).get_trace(batch_size)
        model_trace.compute_log_prob()

        log_probs = {}
        nats_per_dim = {}
        for name, site in model_trace.nodes.items():
            # Only observed sample sites carry a log_prob of interest.
            if site["type"] != "sample" or not site["is_observed"]:
                continue

            mean_log_prob = site["log_prob"].mean()
            log_probs[name] = mean_log_prob

            # Trailing value dims beyond the log_prob shape count as data dims.
            lp_rank = len(site["log_prob"].shape)
            value_shape = site["value"].shape
            dims = np.prod(value_shape[lp_rank:]) if lp_rank < len(value_shape) else 1.
            nats_per_dim[name] = -mean_log_prob / dims

            if not self.hparams.validate:
                continue
            print(
                f'at site {name} with dim {dims} and nats: {nats_per_dim[name]} and logprob: {log_probs[name]}'
            )
            if torch.any(torch.isnan(nats_per_dim[name])):
                raise ValueError('got nan')

        return log_probs, nats_per_dim
Exemplo n.º 14
0
    def _get_pyro_model(
        self,
        posterior: bool = True,
        num_observation: Optional[int] = None,
        observation: Optional[torch.Tensor] = None,
    ) -> Callable:
        """Return a Pyro model function, conditioned when an observation is given.

        At most one of `num_observation` and `observation` may be passed
        (checked with ``assert``).

        Args:
            posterior: Used as the mask applied to the prior. If False the
                prior is masked out, giving a model useful for log
                likelihoods rather than log posterior probabilities.
            num_observation: Observation number; the observation is looked up
                with `self.get_observation`.
            observation: An observation passed directly instead of a number.

        Returns:
            The unconditioned model function when no observation is
            available, otherwise that function conditioned on site `"data"`.
        """
        both_given = num_observation is not None and observation is not None
        assert not both_given

        if num_observation is not None:
            observation = self.get_observation(num_observation=num_observation)

        prior = self.get_prior()
        simulator = self.get_simulator()

        def model_fn():
            # Mask the prior with the `posterior` flag, then simulate from it.
            masked_prior = pyro.poutine.mask(prior, torch.tensor(posterior))
            parameters = masked_prior()
            return simulator(parameters)

        if observation is None:
            return model_fn

        observed = {"data": self.unflatten_data(observation)}
        return pyro.condition(model_fn, observed)
Exemplo n.º 15
0
    def loss_fn(design, num_particles, **kwargs):
        """Compute a loss from the critic ``T`` on two model traces.

        ``T`` is evaluated on a plain model trace and on a trace of the model
        conditioned on that trace's observation values.  ``T``, ``model``,
        ``observation_labels``, ``target_labels``, ``lexpand`` and
        ``ewma_log`` come from the surrounding scope (not visible here).
        Extra keyword arguments are accepted and unused.

        Returns:
            Tuple ``(agg_loss, loss)`` where ``agg_loss`` is ``-loss.sum()``.
        """
        # Registration is best effort: an AssertionError from pyro.module is
        # deliberately ignored.
        try:
            pyro.module("T", T)
        except AssertionError:
            pass

        batched_design = lexpand(design, num_particles)

        # Unshuffled data
        joint_trace = poutine.trace(model).get_trace(batched_design)
        observed = {
            name: joint_trace.nodes[name]["value"]
            for name in observation_labels
        }

        # Shuffled data
        # Not actually shuffling, resimulate for safety
        resampled_trace = poutine.trace(
            pyro.condition(model, data=observed)).get_trace(batched_design)

        critic_joint = T(batched_design, joint_trace, observation_labels,
                         target_labels)
        critic_independent = T(batched_design, resampled_trace,
                               observation_labels, target_labels)

        joint_expectation = critic_joint.sum(0)/num_particles

        # Shift by the per-column maximum before exponentiating.
        shifted = critic_independent - math.log(num_particles)
        peak = torch.max(shifted, dim=0)[0]
        independent_expectation = peak + ewma_log(
            (shifted - peak).exp().sum(dim=0), peak)

        loss = joint_expectation - independent_expectation
        # Switch sign, sum over batch dimensions for scalar loss
        return -loss.sum(), loss
Exemplo n.º 16
0
def model_imp(num_iter, simulate):
    """Draw one ``"pl_mv"`` value from an importance-sampled marginal.

    ``simulate.model_next_move_1`` is conditioned on ``alive = 1`` and run
    through ``pyro.infer.Importance`` with ``num_iter`` samples.

    Returns:
        The result of calling the empirical marginal over ``"pl_mv"``.
    """
    observed_alive = pyro.condition(simulate.model_next_move_1,
                                    data={"alive": 1})
    sampler = pyro.infer.Importance(observed_alive, num_samples=num_iter)
    posterior = sampler.run()
    move_marginal = pyro.infer.EmpiricalMarginal(posterior, sites=["pl_mv"])
    return move_marginal()
Exemplo n.º 17
0
    def loss_fn(design, num_particles, evaluation=False, **kwargs):
        """Compute a loss from guide log-probabilities at the target sites.

        The model is traced on the expanded design; the guide is conditioned
        on the sampled target values and traced with the sampled
        observations.  ``model``, ``guide``, ``observation_labels``,
        ``target_labels``, ``analytic_entropy``, ``mean_field_entropy``,
        ``lexpand`` and ``_safe_mean_terms`` come from the surrounding scope
        (not visible here).  Extra keyword arguments are accepted and unused.

        Returns:
            Tuple ``(agg_loss, loss)``.
        """
        batched_design = lexpand(design, num_particles)

        # Sample from p(y, theta | d)
        prior_trace = poutine.trace(model).get_trace(batched_design)
        observed = {
            name: prior_trace.nodes[name]["value"]
            for name in observation_labels
        }
        latents = {
            name: prior_trace.nodes[name]["value"] for name in target_labels
        }

        # Run through q(theta | y, d)
        guide_args = [observed, batched_design, observation_labels,
                      target_labels]
        q_trace = poutine.trace(
            pyro.condition(guide, data=latents)).get_trace(*guide_args)
        q_trace.compute_log_prob()

        if evaluation and analytic_entropy:
            # Entropy of the unconditioned guide, averaged over particles.
            entropy = mean_field_entropy(guide, guide_args,
                                         whitelist=target_labels)
            loss = entropy.sum(0) / num_particles
            agg_loss = loss.sum()
        else:
            terms = -sum(q_trace.nodes[name]["log_prob"]
                         for name in target_labels)
            agg_loss, loss = _safe_mean_terms(terms)

        return agg_loss, loss
Exemplo n.º 18
0
    def model(self, observations=None):
        """Build a morphism from the wiring diagram and run it under a plate.

        Args:
            observations: One of
                * a dict, from which ``self._observation_name`` is read as
                  the data and which is used as the conditioning data as-is;
                * any other non-``None`` value, treated as the data itself
                  and reshaped with ``view(-1, *self._data_space)`` for the
                  observation site;
                * ``None``, in which case a zero tensor of shape
                  ``(1, *self._data_space)`` stands in as data and nothing
                  is observed.

        Returns:
            Tuple ``(morphism, output)``: the unconditioned morphism and the
            result of calling its conditioned version inside the ``'data'``
            plate.
        """
        if isinstance(observations, dict):
            data = observations[self._observation_name]
        elif observations is not None:
            data = observations
            observations = {
                self._observation_name:
                observations.view(-1, *self._data_space)
            }
        else:
            data = torch.zeros(1, *self._data_space)
            observations = {}
        data = data.view(data.shape[0], *self._data_space)
        for module in self._category.children():
            if isinstance(module, BaseModel):
                module.set_batching(data)

        min_depth = VAE_MIN_DEPTH if len(list(self.wiring_diagram)) == 1 else 0
        morphism = self._category(self.wiring_diagram, min_depth=min_depth)

        # ``observations`` is always a dict by this point (possibly empty),
        # so the former ``is not None`` check could never take its ``else``
        # branch.  Condition unconditionally, as the live branch always did.
        score_morphism = pyro.condition(morphism, data=observations)
        with pyro.plate('data', len(data)):
            with name_pop(name_stack=self._random_variable_names):
                output = score_morphism()
        return morphism, output
Exemplo n.º 19
0
def bvi(model, guide, claims, learning_rate=1e-5, num_samples=1):
    """Set up black-box mean-field variational inference on simpleLCA.

    The model is conditioned on the observations derived from ``claims`` via
    ``make_observation_mapper``, the Pyro parameter store is cleared, and an
    ``SVI`` object using Adam and ``TraceGraph_ELBO`` is built.  No
    optimisation steps are taken here; the caller drives the returned object.

    Args:
        model: The simpleLCA model to condition.  ``None`` falls back to the
            module-level ``lca_model``.
        guide: The variational guide.  ``None`` falls back to the
            module-level ``lca_guide``.
        claims: Input handed to ``make_observation_mapper`` to produce the
            conditioning data.
        learning_rate: Adam learning rate.
        num_samples: Forwarded to the ``SVI`` constructor.

    Returns:
        The configured ``pyro.infer.SVI`` instance.
    """
    # The previous body ignored both arguments and always used the
    # module-level ``lca_model`` / ``lca_guide``.  Honour what the caller
    # passed, keeping the module-level pair as the fallback.
    if model is None:
        model = lca_model
    if guide is None:
        guide = lca_guide

    data = make_observation_mapper(claims)
    conditioned_lca = pyro.condition(model, data=data)
    pyro.clear_param_store()  # is it needed?
    svi = pyro.infer.SVI(model=conditioned_lca,
                         guide=guide,
                         optim=pyro.optim.Adam({
                             "lr": learning_rate,
                             "betas": (0.90, 0.999)
                         }),
                         loss=pyro.infer.TraceGraph_ELBO(),
                         num_samples=num_samples)
    return svi
Exemplo n.º 20
0
 def update_noise_importance(self, observed_steady_state, initial_noise):
     """Return an empirical marginal for every key of ``initial_noise``.

     ``self.noisy_model`` is conditioned on ``observed_steady_state`` and
     passed with ``initial_noise`` to ``self.infer``; the result is wrapped
     in one ``EmpiricalMarginal`` per key.

     Returns:
         Dict mapping each key of ``initial_noise`` to its marginal.
     """
     conditioned = condition(self.noisy_model, observed_steady_state)
     posterior = self.infer(conditioned, initial_noise)

     marginals = {}
     for site in initial_noise.keys():
         marginals[site] = EmpiricalMarginal(posterior, sites=site)
     return marginals
Exemplo n.º 21
0
def get_conditioned_model(yaml_section, model, device="cpu"):
    """Condition ``model`` on the values described by ``yaml_section``.

    Args:
        yaml_section: Mapping of site name to raw value, or ``None``.
        model: The model to condition.
        device: Forwarded to ``yaml_params2._parse_val`` for every value.

    Returns:
        ``model`` unchanged when ``yaml_section`` is ``None``; otherwise the
        model conditioned on the parsed values.
    """
    if yaml_section is None:
        return model
    parsed = {
        site: yaml_params2._parse_val(site, raw, device=device)
        for site, raw in yaml_section.items()
    }
    return pyro.condition(model, parsed)
Exemplo n.º 22
0
    def condition(self, condition_data: dict):
        """Return ``self.model`` conditioned on ``condition_data``.

        Args:
            condition_data: Data handed to ``pyro.condition``.

        Returns:
            The conditioned Pyro model.
        """
        return pyro.condition(self.model, condition_data)
Exemplo n.º 23
0
def condCausal(their_tensors, absent_tensors, movie_inds, num_samples=1000):
    """Estimate the difference in summed ``y`` between two conditionings.

    The module-level ``model`` is conditioned on ``x = their_tensors`` and on
    ``x = absent_tensors``.  Each conditioned model is called ``num_samples``
    times with the module-level ``p2``; for every call the ``'y'`` entry of
    the result is indexed with ``movie_inds`` and summed to a Python float.

    Args:
        their_tensors: Value for site ``"x"`` in the first conditioning.
        absent_tensors: Value for site ``"x"`` in the second conditioning.
        movie_inds: Index applied to the ``'y'`` output before summing.
        num_samples: Number of draws per conditioning.  Defaults to 1000,
            the count that used to be hard-coded.

    Returns:
        Mean of the "their" totals minus mean of the "absent" totals.
    """
    their_cond = pyro.condition(model, data={"x": their_tensors})
    absent_cond = pyro.condition(model, data={"x": absent_tensors})

    # "Their" draws come first, then "absent" draws, as before.
    their_y = [
        torch.sum(their_cond(p2)['y'][movie_inds]).item()
        for _ in range(num_samples)
    ]
    absent_y = [
        torch.sum(absent_cond(p2)['y'][movie_inds]).item()
        for _ in range(num_samples)
    ]

    their_mean = np.mean(their_y)
    absent_mean = np.mean(absent_y)
    causal_effect_noconf = their_mean - absent_mean

    return causal_effect_noconf
    def anwsers_given_traits(self, traits: np.ndarray):
        """Call ``self.model`` 1000 times over with ``'trait'`` held fixed.

        ``traits`` is reshaped to ``(1, self.t)``, converted to a float32
        tensor and expanded to 1000 rows; ``self.model`` is conditioned on it
        and called with ``n_samples=1_000``.

        Returns:
            Whatever the conditioned model returns.
        """
        # No inference needed: the input variables are conditioned on directly.
        single_row = torch.tensor(traits.reshape(1, self.t),
                                  dtype=torch.float32)
        repeated = single_row.expand(1_000, -1)

        with_traits = pyro.condition(self.model, data={'trait': repeated})
        return with_traits(n_samples=1_000)
Exemplo n.º 25
0
    def model_cond_sample(self, data_dict):
        """Call ``self.model_sample`` conditioned on the entries of ``data_dict``.

        The entries are copied into a fresh dict before conditioning.

        Returns:
            The result of calling the conditioned model with no arguments.
        """
        fixed_values = {key: data_dict[key] for key in data_dict}
        conditioned = pyro.condition(self.model_sample, data=fixed_values)
        return conditioned()
    def observation(self, action):
        """Infer norm and reward posteriors from an action and update priors.

        The last element of ``action`` is dropped; the rest become
        ``action<i>`` observations on ``self.model``.  Importance sampling
        (800 samples) gives an empirical marginal over ``'norm'`` and each
        ``'reward_agent<k>'`` site.  ``self.n_prior`` and ``self.rew_prior``
        are overwritten in place with the empirical probabilities.

        Returns:
            Tuple ``(inferred_norm, inferred_reward)``.  ``inferred_norm``
            maps each norm index to its probability as a float;
            ``inferred_reward`` maps ``"agent-<k>"`` to the
            probability-weighted elementwise sum of the ``self.reward_dict``
            entries.
        """
        action = action[:-1]
        observed_actions = {
            "action%d" % idx: torch.tensor(float(value))
            for idx, value in enumerate(action)
        }
        print("Action: ", action)

        # Pyro importance sampling
        conditioned = pyro.condition(self.model, data=observed_actions)
        posterior = pyro.infer.Importance(conditioned, num_samples=800).run()
        reward_sites = [
            'reward_agent%d' % rew_no for rew_no in range(self.agent_no)
        ]
        empirical = posterior.marginal(sites=['norm'] + reward_sites).empirical

        # Norm posterior
        inferred_norm = {
            j: float(empirical['norm'].log_prob(j).exp())
            for j in range(len(self.agent_norm))
        }

        # Reward posterior: weight each reward vector by its probability and
        # add the weighted vectors elementwise.
        inferred_reward = {}
        for k in range(self.agent_no):
            expected_reward = []
            for rew in range(len(self.rew_prior['agent-0'])):
                weight = float(empirical['reward_agent%d' %
                                         k].log_prob(rew).exp())
                scaled = [entry * weight for entry in self.reward_dict[rew]]
                if rew == 0:
                    expected_reward = scaled
                else:
                    expected_reward = [
                        acc + term
                        for acc, term in zip(expected_reward, scaled)
                    ]
            inferred_reward["agent-{}".format(k)] = expected_reward

        # Update norm prior (values are stored without float conversion)
        for norm_idx in range(len(self.agent_norm)):
            self.n_prior[norm_idx] = empirical['norm'].log_prob(norm_idx).exp()

        # Update reward prior
        for agent_no in range(self.agent_no):
            agent_key = 'agent-%d' % agent_no
            site = empirical['reward_agent%d' % agent_no]
            for rew in range(len(self.rew_prior['agent-0'])):
                self.rew_prior[agent_key][rew] = site.log_prob(rew).exp()

        print("Inferred norm: ", inferred_norm)
        print("Inferred reward: ", inferred_reward)
        return inferred_norm, inferred_reward
        """print("Observed Action: ")
Exemplo n.º 27
0
    def _get_log_likelihood(self, params):
        """Return the summed log-probability of ``self.model`` given ``params``.

        ``self.model`` is conditioned on ``params`` and traced with
        ``self._model_args`` and ``self._model_kwargs``; the result is the
        trace's ``log_prob_sum()``.
        """
        fixed_model = pyro.condition(self.model, data=params)
        tracer = pyro.poutine.trace(fixed_model)
        model_trace = tracer.get_trace(*self._model_args,
                                       **self._model_kwargs)
        return model_trace.log_prob_sum()
Exemplo n.º 28
0
 def model(self, x: torch.Tensor):
     """Run ``self.generate`` with site ``'obs'`` fixed to ``x``.

     :param x: Tensor in shape (batch_size, x_dim)
     :return: The result of the conditioned ``self.generate(batch_size)``
     """
     pyro.module('decoder', self.decoder)
     n_rows = x.shape[0]

     def generate_batch():
         return self.generate(n_rows)

     conditioned = pyro.condition(generate_batch, data={'obs': x})
     return conditioned()
Exemplo n.º 29
0
 def predict_cond(self, batch, par=None, **kwargs):
     """Run ``self.predict`` with sample sites fixed to the given parameters.

     Args:
         batch: Data passed to ``self.predict`` and to the parameter source.
         par: Either the string ``"real"``, in which case
             ``self.get_real_par(batch)`` supplies the site values, or a
             function that outputs the parameters when called with a batch
             of data.  Any other value (including the default ``None``) is
             called as a function and therefore raises ``TypeError``.
         **kwargs: Forwarded to ``self.predict``.

     Returns:
         The result of the conditioned ``self.predict(batch, **kwargs)``.
     """
     # Compare by value: ``par is "real"`` tests object identity, which only
     # works by accident of string interning and warns on Python >= 3.8.
     if isinstance(par, str) and par == "real":
         par = self.get_real_par(batch)
     else:
         par = par(batch)
     return pyro.condition(lambda batch: self.predict(batch, **kwargs),
                           data=par)(batch)
Exemplo n.º 30
0
    def cond_predict(self, X_test, num_samples=1000):
        """Average the ``'y'`` output of the conditioned model over many draws.

        If ``self.test_params`` is ``None``, ``self.infer_z(X_test)`` is
        called first.  The module-level ``model`` is then conditioned on
        ``x = X_test`` and called ``num_samples`` times with
        ``self.test_params``.

        Returns:
            Tuple ``(mean_prediction, self.test_params)`` where
            ``mean_prediction`` is a NumPy array of length
            ``X_test.shape[0]``.
        """
        if self.test_params is None:
            self.infer_z(X_test)

        conditioned = pyro.condition(model, data={"x": X_test})
        running_total = np.zeros(X_test.shape[0])
        for _ in range(num_samples):
            draw = conditioned(self.test_params)['y']
            # In-place add keeps the accumulator's shape and dtype.
            running_total += draw.detach().numpy()

        return running_total / num_samples, self.test_params