Example #1
 def test_block_full(self):
     model_trace = poutine.trace(poutine.block(self.model)).get_trace()
     guide_trace = poutine.trace(poutine.block(self.guide)).get_trace()
     for name in model_trace.nodes.keys():
         assert model_trace.nodes[name]["type"] in ("args", "return")
     for name in guide_trace.nodes.keys():
         assert guide_trace.nodes[name]["type"] in ("args", "return")
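For context, a minimal self-contained sketch (toy model, not from the original test class) of what this test checks: called with no arguments, poutine.block hides every site, so the blocked trace keeps only the _INPUT ("args") and _RETURN ("return") bookkeeping nodes.

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

def toy_model():
    return pyro.sample("z", dist.Normal(0., 1.))

tr = poutine.trace(poutine.block(toy_model)).get_trace()
# the sample site "z" is hidden; only input/return nodes remain
assert all(site["type"] in ("args", "return") for site in tr.nodes.values())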
Example #2
 def test_block_full_hide_expose(self):
     # passing both hide and expose must raise an AssertionError;
     # pytest.raises checks this without also swallowing a failing assert
     with pytest.raises(AssertionError):
         poutine.block(self.model,
                       hide=self.partial_sample_sites.keys(),
                       expose=self.partial_sample_sites.keys())()
Example #3
 def test_block_full_expose(self):
     model_trace = poutine.trace(poutine.block(self.model,
                                               expose=self.model_sites)).get_trace()
     guide_trace = poutine.trace(poutine.block(self.guide,
                                               expose=self.guide_sites)).get_trace()
     for name in self.model_sites:
         assert name in model_trace
     for name in self.guide_sites:
         assert name in guide_trace
Example #4
def test_guide_list(auto_class):

    def model():
        pyro.sample("x", dist.Normal(0., 1.))
        pyro.sample("y", dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5)))

    guide = AutoGuideList(model)
    guide.add(auto_class(poutine.block(model, expose=["x"]), prefix="auto_x"))
    guide.add(auto_class(poutine.block(model, expose=["y"]), prefix="auto_y"))
    guide()
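This splits the latent space between two sub-guides with complementary poutine.block filters. A hedged sketch of training such a composite guide, reusing the model above and assuming AutoDiagonalNormal stands in for auto_class (guide.add with an explicit prefix is the older API; recent Pyro uses guide.append and derives prefixes automatically):

from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal, AutoGuideList
from pyro.optim import Adam

guide = AutoGuideList(model)
guide.append(AutoDiagonalNormal(poutine.block(model, expose=["x"])))
guide.append(AutoDiagonalNormal(poutine.block(model, expose=["y"])))

svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
for _ in range(100):
    svi.step()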
Example #5
 def test_block_partial_expose(self):
     model_trace = poutine.trace(
         poutine.block(self.model, expose=self.partial_sample_sites.keys())).get_trace()
     guide_trace = poutine.trace(
         poutine.block(self.guide, expose=self.partial_sample_sites.keys())).get_trace()
     for name in self.full_sample_sites.keys():
         if name in self.partial_sample_sites:
             assert name in model_trace
             assert name in guide_trace
         else:
             assert name not in model_trace
             assert name not in guide_trace
Example #6
    def _setup_prototype(self, *args, **kwargs):
        # run the model so we can inspect its structure
        model = config_enumerate(self.model, default="parallel")
        self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(*args, **kwargs)
        self.prototype_trace = prune_subsample_sites(self.prototype_trace)
        if self.master is not None:
            self.master()._check_prototype(self.prototype_trace)

        self._discrete_sites = []
        self._cond_indep_stacks = {}
        self._iaranges = {}
        for name, site in self.prototype_trace.nodes.items():
            if site["type"] != "sample" or site["is_observed"]:
                continue
            if site["infer"].get("enumerate") != "parallel":
                raise NotImplementedError('Expected sample site "{}" to be discrete and '
                                          'configured for parallel enumeration'.format(name))

            # collect discrete sample sites
            fn = site["fn"]
            Dist = type(fn)
            if Dist in (dist.Bernoulli, dist.Categorical, dist.OneHotCategorical):
                params = [("probs", fn.probs.detach().clone(), fn.arg_constraints["probs"])]
            else:
                raise NotImplementedError("{} is not supported".format(Dist.__name__))
            self._discrete_sites.append((site, Dist, params))

            # collect independence contexts
            self._cond_indep_stacks[name] = site["cond_indep_stack"]
            for frame in site["cond_indep_stack"]:
                if frame.vectorized:
                    self._iaranges[frame.name] = frame
                else:
                    raise NotImplementedError("AutoDiscreteParallel does not support pyro.irange")
Example #7
 def _gen_weighted_samples(self, *args, **kwargs):
     for tr, log_w in poutine.block(self.trace_dist._traces)(*args, **kwargs):
         if self.sites == "_RETURN":
             val = tr.nodes["_RETURN"]["value"]
         else:
             val = {name: tr.nodes[name]["value"]
                    for name in self.sites}
         yield (val, log_w)
Example #8
 def test_trace_data(self):
     tr1 = poutine.trace(
         poutine.block(self.model, expose_types=["sample"])).get_trace()
     tr2 = poutine.trace(
         poutine.condition(self.model, data=tr1)).get_trace()
     assert tr2.nodes["latent2"]["type"] == "sample" and \
         tr2.nodes["latent2"]["is_observed"]
     assert tr2.nodes["latent2"]["value"] is tr1.nodes["latent2"]["value"]
Example #9
    def test_block_tutorial_case(self):
        model_trace = poutine.trace(self.model).get_trace()
        guide_trace = poutine.trace(
            poutine.block(self.guide, hide_types=["observe"])).get_trace()

        assert "latent1" in model_trace
        assert "latent1" in guide_trace
        assert "obs" in model_trace
        assert "obs" not in guide_trace
Example #10
File: jit.py Project: lewisKit/pyro
    def __call__(self, *args, **kwargs):

        # if first time
        if self.compiled is None:
            # param capture
            with poutine.block():
                with poutine.trace(param_only=True) as first_param_capture:
                    self.fn(*args, **kwargs)

            self._param_names = list(set(first_param_capture.trace.nodes.keys()))

            weakself = weakref.ref(self)

            @torch.jit.compile(**self._jit_options)
            def compiled(unconstrained_params, *args):
                self = weakself()
                constrained_params = {}
                for name, unconstrained_param in zip(self._param_names, unconstrained_params):
                    constrained_param = pyro.param(name)  # assume param has been initialized
                    assert constrained_param.unconstrained() is unconstrained_param
                    constrained_params[name] = constrained_param

                return poutine.replay(
                    self.fn, params=constrained_params)(*args, **kwargs)

            self.compiled = compiled

        param_list = [pyro.param(name).unconstrained()
                      for name in self._param_names]

        with poutine.block(hide=self._param_names):
            with poutine.trace(param_only=True) as param_capture:
                ret = self.compiled(param_list, *args, **kwargs)

        new_params = filter(lambda name: name not in self._param_names,
                            param_capture.trace.nodes.keys())

        for name in new_params:
            # enforce uniqueness
            if name not in self._param_names:
                self._param_names.append(name)

        return ret
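The param-capture idiom at the core of this wrapper is useful on its own: the outer poutine.block keeps the probe run from leaking effects to enclosing handlers, while the inner poutine.trace(param_only=True) records only the pyro.param calls. A hedged fragment (fn is any callable that invokes pyro.param internally):

with poutine.block():
    with poutine.trace(param_only=True) as capture:
        fn()
param_names = list(capture.trace.nodes.keys())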
Example #11
 def __call__(self, *args, **kwargs):
     traces, logits = [], []
     for tr, logit in poutine.block(self._traces)(*args, **kwargs):
         traces.append(tr)
         logits.append(logit)
     # normalize the importance weights, then sample a trace index
     logits = torch.stack(logits).squeeze()
     logits = logits - logits.logsumexp(dim=-1)
     ix = dist.Categorical(logits=logits).sample()
     return traces[ix.item()]
Example #12
def test_discrete_parallel(continuous_class):
    K = 2
    data = torch.tensor([0., 1., 10., 11., 12.])

    def model(data):
        weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
        locs = pyro.sample('locs', dist.Normal(0, 10).expand_by([K]).independent(1))
        scale = pyro.sample('scale', dist.LogNormal(0, 1))

        with pyro.iarange('data', len(data)):
            weights = weights.expand(torch.Size((len(data),)) + weights.shape)
            assignment = pyro.sample('assignment', dist.Categorical(weights))
            pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

    guide = AutoGuideList(model)
    guide.add(continuous_class(poutine.block(model, hide=["assignment"])))
    guide.add(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))

    elbo = TraceEnum_ELBO(max_iarange_nesting=1)
    loss = elbo.loss_and_grads(model, guide, data)
    assert np.isfinite(loss), loss
Example #13
    def run(self, *args, **kwargs):
        """
        Calls `self._traces` to populate execution traces from a stochastic
        Pyro model.

        :param args: optional args taken by `self._traces`.
        :param kwargs: optional keywords args taken by `self._traces`.
        """
        self._init()
        for tr, logit in poutine.block(self._traces)(*args, **kwargs):
            self.exec_traces.append(tr)
            self.log_weights.append(logit)
        self._categorical = Categorical(logits=torch.tensor(self.log_weights))
        return self
Example #14
 def __init__(self, model, guide=None, num_samples=None):
     """
     Constructor. num_samples defaults to 10; if guide is None, the
     proposal is the prior (the model with observe sites hidden).
     """
     super(Importance, self).__init__()
     if num_samples is None:
         num_samples = 10
         logger.warning("num_samples not provided, defaulting to {}".format(num_samples))
     if guide is None:
         # propose from the prior by making a guide from the model by hiding observes
         guide = poutine.block(model, hide_types=["observe"])
     self.num_samples = num_samples
     self.model = model
     self.guide = guide
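A hedged usage sketch for this constructor together with the run method shown in Example #13, using a hypothetical one-site model:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import EmpiricalMarginal, Importance

def toy_model():
    z = pyro.sample("z", dist.Normal(0., 1.))
    pyro.sample("obs", dist.Normal(z, 1.), obs=torch.tensor(0.5))
    return z

posterior = Importance(toy_model, guide=None, num_samples=100).run()
marginal = EmpiricalMarginal(posterior, sites="z")
print(marginal.mean)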
Example #15
    def _setup_prototype(self, *args, **kwargs):
        # run the model so we can inspect its structure
        self.prototype_trace = poutine.block(poutine.trace(self.model).get_trace)(*args, **kwargs)
        self.prototype_trace = prune_subsample_sites(self.prototype_trace)
        if self.master is not None:
            self.master()._check_prototype(self.prototype_trace)

        self._iaranges = {}
        for name, site in self.prototype_trace.nodes.items():
            if site["type"] != "sample" or site["is_observed"]:
                continue
            for frame in site["cond_indep_stack"]:
                if frame.vectorized:
                    self._iaranges[frame.name] = frame
                else:
                    raise NotImplementedError("AutoGuideList does not support pyro.irange")
Example #16
    def run(self, *args, **kwargs):
        """
        Calls `self._traces` to populate execution traces from a stochastic
        Pyro model.

        :param args: optional args taken by `self._traces`.
        :param kwargs: optional keywords args taken by `self._traces`.
        """
        self._reset()
        with poutine.block():
            for i, vals in enumerate(self._traces(*args, **kwargs)):
                if len(vals) == 2:
                    chain_id = 0
                    tr, logit = vals
                else:
                    tr, logit, chain_id = vals
                    assert chain_id < self.num_chains
                self.exec_traces.append(tr)
                self.log_weights.append(logit)
                self.chain_ids.append(chain_id)
                self._idx_by_chain[chain_id].append(i)
        self._categorical = Categorical(logits=torch.tensor(self.log_weights))
        return self
Example #17
    def predict(self, x, trg_y):
        """Predict labels.

        Args:
          x (torch.utils.data.DataLoader):
          trg_y (np.array): array for storing the predicted labels

        Returns:
          float: loss

        """
        # resize `self._test_wbench` if necessary
        n_instances = x[0].shape[0]
        self._resize_wbench(n_instances)
        self._test_wbench *= 0
        # print("self._test_wbench:", repr(self._test_wbench))
        with poutine.block():
            with torch.no_grad():
                for wbench_i in self._test_wbench:
                    wbench_i[:n_instances] = self.guide(*x)
        mean = np.mean(self._test_wbench, axis=0)
        trg_y[:] = np.argmax(mean[:n_instances], axis=-1)
        return trg_y
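Stripped of the workbench bookkeeping, the core idiom here is wrapping prediction in poutine.block plus torch.no_grad, so repeated guide calls neither notify enclosing Pyro handlers nor build autograd graphs:

with poutine.block(), torch.no_grad():
    prediction = self.guide(*x)  # then average the repeated draws, as above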
Example #18
def main(_argv):
    transition_alphas = torch.tensor([[10., 90.],
                                      [90., 10.]])
    emission_alphas = torch.tensor([[[30., 20., 5.]],
                                    [[5., 10., 100.]]])
    lengths = torch.randint(10, 30, (10000,))
    trace = poutine.trace(model).get_trace(transition_alphas, emission_alphas, lengths)
    obs_sequences = [site['value'] for name, site in
                     trace.nodes.items() if name.startswith("element_")]
    obs_sequences = torch.stack(obs_sequences, dim=-2)
    guide = AutoDelta(poutine.block(model, hide_fn=lambda site: site['name'].startswith('state')),
                      init_loc_fn=init_to_sample)
    svi = SVI(model, guide, Adam(dict(lr=0.1)), JitTraceEnum_ELBO())
    total = 1000
    with tqdm.trange(total) as t:
        for i in t:
            loss = svi.step(0.5 * torch.ones((2, 2), dtype=torch.float),
                            0.3 * torch.ones((2, 1, 3), dtype=torch.float),
                            lengths, obs_sequences)
            t.set_description_str(f"SVI ({i}/{total}): {loss}")
    median = guide.median()
    print("Transition probs: ", median['transition_probs'].detach().numpy())
    print("Emission probs: ", median['emission_probs'].squeeze().detach().numpy())
Example #19
 def posterior_entropy(y_dist, design):
     # Important that y_dist is sampled *within* the function
     y = pyro.sample("conditioning_y", y_dist)
     y_dict = {
         label: y[i, ...]
         for i, label in enumerate(observation_labels)
     }
     conditioned_model = pyro.condition(model, data=y_dict)
     # Here just using SVI to run the MAP optimization
     guide.train()
     SVI(conditioned_model,
         guide=guide,
         loss=loss,
         optim=optim,
         num_steps=num_steps,
         num_samples=1).run(design)
     # Recover the entropy
     with poutine.block():
         final_loss = loss(conditioned_model, guide, design)
         guide.finalize(final_loss, target_labels)
         entropy = mean_field_entropy(guide, [design],
                                      whitelist=target_labels)
     return entropy
Example #20
def relbo(model, guide, *args, **kwargs):

    approximation = kwargs.pop('approximation', None)
    # Run the guide with the arguments passed to SVI.step() and trace the execution,
    # i.e. record all the calls to Pyro primitives like sample() and param().
    guide_trace = trace(guide).get_trace(*args, **kwargs)
    # Now run the model with the same arguments and trace the execution. Because
    # model is being run with replay, whenever we encounter a sample site in the
    # model, instead of sampling from the corresponding distribution in the model,
    # we instead reuse the corresponding sample from the guide. In probabilistic
    # terms, this means our loss is constructed as an expectation w.r.t. the joint
    # distribution defined by the guide.
    model_trace = trace(replay(model, guide_trace)).get_trace(*args, **kwargs)
    approximation_trace = trace(
        replay(block(approximation, expose=["obs"]),
               guide_trace)).get_trace(*args, **kwargs)
    # We will accumulate the various terms of the ELBO in `elbo`.
    elbo = 0.
    # Loop over all the sample sites in the model and add the corresponding
    # log p(z) term to the ELBO. Note that this will also include any observed
    # data, i.e. sample sites with the keyword `obs=...`.
    elbo = elbo + model_trace.log_prob_sum()
    # Loop over all the sample sites in the guide and add the corresponding
    # -log q(z) term to the ELBO.
    elbo = elbo - guide_trace.log_prob_sum()
    elbo = elbo - approximation_trace.log_prob_sum()

    # Return (-elbo) since by convention we do gradient descent on a loss and
    # the ELBO is a lower bound that needs to be maximized.
    if elbo < 10e-8 and PRINT_TRACES:
        print('Guide trace')
        print(guide_trace.log_prob_sum())
        print('Model trace')
        print(model_trace.log_prob_sum())
        print('Approximation trace')
        print(approximation_trace.log_prob_sum())
    return -elbo
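Because relbo has the (model, guide, *args, **kwargs) signature Pyro expects of a loss callable, it can be handed directly to SVI. A hedged sketch of the boosting-style step this residual ELBO is designed for (component_guide and current_approximation are assumed from the surrounding program):

svi = SVI(model, component_guide, Adam({"lr": 0.01}), loss=relbo)
# SVI forwards extra keyword arguments to the loss, where relbo
# pops 'approximation' from kwargs
loss = svi.step(approximation=current_approximation)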
Example #21
    def pol_att(self, x, y, t, e):

        num_samples = self.num_samples
        if not torch._C._get_tracing_state():
            assert x.dim() == 2 and x.size(-1) == self.feature_dim

        dataloader = [x]
        print("Evaluating {} minibatches".format(len(dataloader)))
        result_pol = []
        result_eatt = []
        for x in dataloader:
            # x = self.whiten(x)
            with pyro.plate("num_particles", num_samples, dim=-2):
                with poutine.trace() as tr, poutine.block(hide=["y", "t"]):
                    self.guide(x)
                with poutine.do(data=dict(t=torch.zeros(()))):
                    y0 = poutine.replay(self.model.y_mean, tr.trace)(x)
                with poutine.do(data=dict(t=torch.ones(()))):
                    y1 = poutine.replay(self.model.y_mean, tr.trace)(x)

            ite = (y1 - y0).mean(0)
            ite[t > 0] = -ite[t > 0]
            eatt = torch.abs(torch.mean(ite[(t + e) > 1]))
            pols = []
            for s in range(num_samples):
                pols.append(policy_val(ypred1=y1[s], ypred0=y0[s], y=y, t=t))

            pol = torch.stack(pols).mean(0)

            if not torch._C._get_tracing_state():
                print("batch eATT = {:0.6g}".format(eatt))
                print("batch RPOL = {:0.6g}".format(pol))
            result_pol.append(pol)
            result_eatt.append(eatt)

        return torch.stack(result_pol), torch.stack(result_eatt)
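poutine.do, used above to force the treatment site "t" to 0 and 1, performs a hard intervention: the named sites are clamped to the given values instead of being sampled. A minimal hedged sketch:

def toy_model():
    t = pyro.sample("t", dist.Bernoulli(0.5))
    return pyro.sample("y", dist.Normal(t, 1.))

intervened = poutine.do(toy_model, data={"t": torch.tensor(1.)})
y = intervened()  # "t" is fixed at 1. regardless of the Bernoulli prior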
Example #22
File: elbo.py Project: zyxue/pyro
    def _guess_max_plate_nesting(self, model, guide, *args, **kwargs):
        """
        Guesses max_plate_nesting by running the (model,guide) pair once
        without enumeration. This optimistically assumes static model
        structure.
        """
        # Ignore validation to allow model-enumerated sites absent from the guide.
        with poutine.block():
            guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
            model_trace = poutine.trace(
                poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)
        sites = [site
                 for trace in (model_trace, guide_trace)
                 for site in trace.nodes.values()
                 if site["type"] == "sample"]

        # Validate shapes now, since shape constraints will be weaker once
        # max_plate_nesting is changed from float('inf') to some finite value.
        # Here we know the traces are not enumerated, but later we'll need to
        # allow broadcasting of dims to the left of max_plate_nesting.
        if is_validation_enabled():
            guide_trace.compute_log_prob()
            model_trace.compute_log_prob()
            for site in sites:
                check_site_shape(site, max_plate_nesting=float('inf'))

        dims = [frame.dim
                for site in sites
                for frame in site["cond_indep_stack"]
                if frame.vectorized]
        self.max_plate_nesting = -min(dims) if dims else 0
        if self.vectorize_particles and self.num_particles > 1:
            self.max_plate_nesting += 1
        logging.info('Guessed max_plate_nesting = {}'.format(self.max_plate_nesting))
Example #23
def belief_policy_model_guide(belief,
                              t,
                              discount=1.0,
                              discount_factor=0.95,
                              max_depth=10):
    # prior weights is uniform
    weights = pyro.param("action_weights",
                         torch.ones(len(actions)),
                         constraint=dist.constraints.simplex)
    state = states[pyro.sample("s%d" % t, belief)]
    history = ""  # we can start from empty history
    # This is just to generate the other variables
    with poutine.block(hide=["a%d" % t]):
        # Need to hide 'at' generated by the policy_model;
        # we don't care about it; because that's what we
        # are inferring -- we are calling pyro.sample("a%d" % t ..) next.
        history_policy_model(state,
                             history,
                             t,
                             discount=discount,
                             discount_factor=discount_factor,
                             max_depth=max_depth)
    # We eventually generate actions based on the weights
    action = pyro.sample("a%d" % t, dist.Categorical(weights))
Example #24
def align_samples(samples, model, particle_dim):
    """
    Unsqueeze stacked samples such that their particle dim all aligns.
    This traces ``model`` to determine the ``event_dim`` of each site.
    """
    assert particle_dim < 0

    sample = {name: value[0] for name, value in samples.items()}
    with poutine.block(), poutine.trace() as tr, poutine.condition(
            data=sample):
        model()

    samples = samples.copy()
    for name, value in samples.items():
        event_dim = tr.trace.nodes[name]["fn"].event_dim
        pad = event_dim - particle_dim - value.dim()
        if pad < 0:
            raise ValueError(
                "Cannot align samples, try moving particle_dim left")
        if pad > 0:
            shape = value.shape[:1] + (1, ) * pad + value.shape[1:]
            samples[name] = value.reshape(shape)

    return samples
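A hedged usage sketch for align_samples, assuming a model with a scalar site "a" and a length-3 vector site "b". With particle_dim=-2, each value is padded so the particle dimension lands two slots to the left of that site's event dims:

import torch
import pyro
import pyro.distributions as dist

def toy_model():
    pyro.sample("a", dist.Normal(0., 1.))
    pyro.sample("b", dist.Normal(torch.zeros(3), 1.).to_event(1))

stacked = {"a": torch.randn(10), "b": torch.randn(10, 3)}  # leading dim = particles
aligned = align_samples(stacked, toy_model, particle_dim=-2)
print(aligned["a"].shape, aligned["b"].shape)  # torch.Size([10, 1]) torch.Size([10, 1, 3])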
Example #25
def initialize(seed, model, data):
    global global_guide, svi
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    exposed_params = []
    # set the parameters inferred through the guide based on the kind of data
    if 'gr' in mtype:
        if dtype == 'norm':
            exposed_params = ['weights', 'concentration']
        elif dtype == 'raw':
            exposed_params = ['weights', 'alpha', 'beta']
    elif 'dim' in mtype:
        if dtype == 'norm':
            exposed_params = [
                'topic_weights', 'topic_concentration', 'participant_topics'
            ]
        elif dtype == 'raw':
            exposed_params = [
                'topic_weights', 'topic_a', 'topic_b', 'participant_topics'
            ]

    global_guide = AutoDelta(poutine.block(model, expose=exposed_params))
    svi = SVI(model, global_guide, optim, loss=elbo)
    return svi.loss(model, global_guide, data)
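A hedged sketch of how an initialize function like this is typically driven, following the multi-start pattern in Pyro's GMM tutorial: score many seeds by their initial loss and keep the best one:

loss, seed = min((initialize(seed, model, data), seed) for seed in range(100))
initialize(seed, model, data)  # rebuild global_guide and svi from the winning seed
print("best seed = {}, initial loss = {}".format(seed, loss))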
Example #26
def main(args):
    pyro.set_rng_seed(0)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    # load data
    if args.dataset == "dipper":
        capture_history_file = os.path.dirname(
            os.path.abspath(__file__)) + '/dipper_capture_history.csv'
    elif args.dataset == "vole":
        capture_history_file = os.path.dirname(
            os.path.abspath(__file__)) + '/meadow_voles_capture_history.csv'
    else:
        raise ValueError("Available datasets are \'dipper\' and \'vole\'.")

    capture_history = torch.tensor(
        np.genfromtxt(capture_history_file, delimiter=',')).float()[:, 1:]
    N, T = capture_history.shape
    print(
        "Loaded {} capture history for {} individuals collected over {} time periods."
        .format(args.dataset, N, T))

    if args.dataset == "dipper" and args.model in ["4", "5"]:
        sex_file = os.path.dirname(
            os.path.abspath(__file__)) + '/dipper_sex.csv'
        sex = torch.tensor(np.genfromtxt(sex_file, delimiter=',')).float()[:,
                                                                           1]
        print("Loaded dipper sex data.")
    elif args.dataset == "vole" and args.model in ["4", "5"]:
        raise ValueError(
            "Cannot run model_{} on meadow voles data, since we lack sex "
            "information for these animals.".format(args.model))
    else:
        sex = None

    model = models[args.model]

    # we use poutine.block to only expose the continuous latent variables
    # in the models to AutoDiagonalNormal (all of which begin with 'phi'
    # or 'rho')
    def expose_fn(msg):
        return msg["name"][0:3] in ['phi', 'rho']

    # we use a mean field diagonal normal variational distributions (i.e. guide)
    # for the continuous latent variables.
    guide = AutoDiagonalNormal(poutine.block(model, expose_fn=expose_fn))

    # since we enumerate the discrete random variables,
    # we need to use TraceEnum_ELBO or TraceTMC_ELBO.
    optim = Adam({'lr': args.learning_rate})
    if args.tmc:
        elbo = TraceTMC_ELBO(max_plate_nesting=1)
        tmc_model = poutine.infer_config(model, lambda msg: {
            "num_samples": args.tmc_num_samples,
            "expand": False
        } if msg["infer"].get("enumerate", None) == "parallel" else {}
                                         )  # noqa: E501
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        elbo = TraceEnum_ELBO(max_plate_nesting=1,
                              num_particles=20,
                              vectorize_particles=True)
        svi = SVI(model, guide, optim, elbo)

    losses = []

    print(
        "Beginning training of model_{} with Stochastic Variational Inference."
        .format(args.model))

    for step in range(args.num_steps):
        loss = svi.step(capture_history, sex)
        losses.append(loss)
        if step % 20 == 0 and step > 0 or step == args.num_steps - 1:
            print("[iteration %03d] loss: %.3f" %
                  (step, np.mean(losses[-20:])))

    # evaluate final trained model
    elbo_eval = TraceEnum_ELBO(max_plate_nesting=1,
                               num_particles=2000,
                               vectorize_particles=True)
    svi_eval = SVI(model, guide, optim, elbo_eval)
    print("Final loss: %.4f" % svi_eval.evaluate_loss(capture_history, sex))
Example #27
def auto_guide_list_x(model):
    guide = AutoGuideList(model)
    guide.add(AutoDelta(poutine.block(model, expose=["x"])))
    guide.add(AutoDiagonalNormal(poutine.block(model, hide=["x"])))
    return guide
Example #28
    #   for x in batch:
    #     ret.append(inner(observed[x], start, 0, 0, prefix=str(x))[0])

    #   return ret
    return inner(observed, start, 0, 0)


if __name__ == '__main__':
    observed = ["the dog bit the man".split()]

    nonterminals = ["S", "NP", "VP"]
    preterminals = ["DT", "NN", "VB"]

    model = partial(model,
                    start="S",
                    nonterminals=nonterminals,
                    productions=[("NP", "VP"), ("DT", "NN"), ("VB", "NN")],
                    preterminals=preterminals,
                    terminals=["the", "dog", "man", "bit"])

    expose_params = ["binary_%s" % nonterm for nonterm in nonterminals] \
        + ["unary_%s" % preterm for preterm in preterminals]
    guide = AutoContinuous(poutine.block(model, expose=expose_params))
    elbo = Trace_ELBO()
    optim = Adam({"lr": 0.001})
    svi = SVI(model, guide, optim, elbo)

    for _ in range(30):
        loss = svi.step(observed[0])
        print(loss)
Example #29
    def _create_autoguide(
        self,
        model,
        amortised,
        encoder_kwargs,
        data_transform,
        encoder_mode,
        init_loc_fn=init_to_mean(fallback=init_to_feasible),
        n_cat_list: list = [],
        encoder_instance=None,
        guide_class=AutoNormal,
        guide_kwargs: Optional[dict] = None,
    ):

        if guide_kwargs is None:
            guide_kwargs = dict()

        if not amortised:
            if getattr(model, "discrete_variables", None) is not None:
                model = poutine.block(model, hide=model.discrete_variables)
            if issubclass(guide_class, poutine.messenger.Messenger):
                # messenger guides don't need create_plates function
                _guide = guide_class(
                    model,
                    init_loc_fn=init_loc_fn,
                    **guide_kwargs,
                )
            else:
                _guide = guide_class(
                    model,
                    init_loc_fn=init_loc_fn,
                    **guide_kwargs,
                    create_plates=self.model.create_plates,
                )
        else:
            encoder_kwargs = encoder_kwargs if isinstance(
                encoder_kwargs, dict) else dict()
            n_hidden = encoder_kwargs[
                "n_hidden"] if "n_hidden" in encoder_kwargs.keys() else 200
            if isinstance(data_transform, np.ndarray):
                # add extra info about gene clusters as input to NN
                self.register_buffer(
                    "gene_clusters",
                    torch.tensor(data_transform.astype("float32")))
                n_in = model.n_vars + data_transform.shape[1]
                data_transform = self._data_transform_clusters()
            elif data_transform == "log1p":
                # use simple log1p transform
                data_transform = torch.log1p
                n_in = self.model.n_vars
            elif (isinstance(data_transform, dict)
                  and "var_std" in list(data_transform.keys())
                  and "var_mean" in list(data_transform.keys())):
                # use data transform by scaling
                n_in = model.n_vars
                self.register_buffer(
                    "var_mean",
                    torch.tensor(
                        data_transform["var_mean"].astype("float32").reshape(
                            (1, n_in))),
                )
                self.register_buffer(
                    "var_std",
                    torch.tensor(
                        data_transform["var_std"].astype("float32").reshape(
                            (1, n_in))),
                )
                data_transform = self._data_transform_scale()
            else:
                # use custom data transform
                data_transform = data_transform
                n_in = model.n_vars
            amortised_vars = model.list_obs_plate_vars()
            if len(amortised_vars["input"]) >= 2:
                encoder_kwargs["n_cat_list"] = n_cat_list
            amortised_vars["input_transform"][0] = data_transform
            if getattr(model, "discrete_variables", None) is not None:
                model = poutine.block(model, hide=model.discrete_variables)
            _guide = AutoAmortisedHierarchicalNormalMessenger(
                model,
                amortised_plate_sites=amortised_vars,
                n_in=n_in,
                n_hidden=n_hidden,
                encoder_kwargs=encoder_kwargs,
                encoder_mode=encoder_mode,
                encoder_instance=encoder_instance,
                **guide_kwargs,
            )
        return _guide
Example #30
                                     len(mix_params),
                                     device=x.device))
    beta_scale = pyro.param(
        'beta_resp_scale',
        torch.tril(
            1. * torch.eye(len(mix_params), len(mix_params), device=x.device)),
        constraint=constraints.lower_cholesky)
    pyro.sample(
        "beta_resp",
        dist.MultivariateNormal(beta_loc, scale_tril=beta_scale).to_event(1))


# In[11]:

guide = AutoGuideList(model)
guide.add(AutoDiagonalNormal(poutine.block(model, expose=['theta',
                                                          'L_omega'])))
guide.add(my_local_guide)  # automatically wrapped in an AutoCallable

# # Run variational inference

# In[12]:

# prepare data for running inference
train_x = torch.tensor(alt_attributes, dtype=torch.float)
train_x = train_x.cuda()
train_y = torch.tensor(true_choices, dtype=torch.int)
train_y = train_y.cuda()
alt_av_cuda = torch.from_numpy(alt_availability)
alt_av_cuda = alt_av_cuda.cuda()
alt_av_mat = alt_availability.copy()
alt_av_mat[np.where(alt_av_mat == 0)] = -1e9
Example #31
def trainVI(
    data,
    hidden_dim,
    learning_rate=1e-03,
    n_iters=500,
    trace_embeddings_interval=20,
    writer=None,
    moved_points={},
    sigma_fix=1e-3,
):
    """Train (Deep) PPCA model with VI.
    Using tensorboardX SummaryWriter to log to tensorboard.
    Logging the internal result by setting "trace_embeddings_interval"
    to the expected interval (50, 100, ...).
    Set it to value larger than "n_iters" to disable interval logging.
    """
    pyro.enable_validation(True)
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    model = partial(
        ppca_model,
        hidden_dim=hidden_dim,
        z_dim=2,
        moved_points=moved_points,
        sigma_fix=sigma_fix,
    )
    guide = AutoGuideList(model)
    guide.add(AutoDiagonalNormal(model=poutine.block(model, expose=["sigma"])))
    guide.add(AutoDiagonalNormal(model=poutine.block(model, expose=["Z"]), prefix="qZ"))

    optim = Adam({"lr": learning_rate})
    svi = SVI(model, guide, optim, loss=Trace_ELBO())

    fig_title = f"lr={learning_rate}/hidden-dim={hidden_dim}"
    metric = DRMetric(X=data, Y=None)
    data = torch.tensor(data, dtype=torch.float)

    for n_iter in tqdm(range(n_iters)):
        loss = svi.step(data)
        if writer and n_iter % 10 == 0:
            writer.add_scalar("train_vi/loss", loss, n_iter)

        if writer and n_iter % trace_embeddings_interval == 0:
            z2d_loc = pyro.param("qZ_loc").reshape(-1, 2).data.numpy()

            auc_rnx = metric.update(Y=z2d_loc).auc_rnx()
            writer.add_scalar("metrics/auc_rnx", auc_rnx, n_iter)

            fig = get_fig_plot_z2d(z2d_loc, fig_title + f", auc_rnx={auc_rnx:.3f}")
            writer.add_figure("train_vi/z2d", fig, n_iter)

    # show named rvs
    print("List params and their size: ")
    for p_name, p_val in pyro.get_param_store().items():
        print(p_name, p_val.shape)

    z2d_loc = pyro.param("qZ_loc").reshape(-1, 2).data.numpy()
    z2d_scale = pyro.param("qZ_scale").reshape(-1, 2).data.numpy()
    if writer:
        fig = get_fig_plot_z2d(z2d_loc, fig_title)
        writer.add_figure("train_vi/z2d", fig, n_iters)
    return z2d_loc, z2d_scale
Example #32
File: hmm.py Project: pyro-ppl/pyro
def main(args):
    if args.cuda:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")

    logging.info("Loading data")
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info("-" * 40)
    model = models[args.model]
    logging.info("Training {} on {} sequences".format(
        model.__name__, len(data["train"]["sequences"])))
    sequences = data["train"]["sequences"]
    lengths = data["train"]["sequence_lengths"]

    # find all the notes that are present at least once in the training set
    present_notes = (sequences == 1).sum(0).sum(0) > 0
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        poutine.block(model,
                      expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = poutine.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = poutine.trace(
            poutine.replay(poutine.enum(model, first_available_dim),
                           guide_trace)).get_trace(sequences,
                                                   lengths,
                                                   args=args,
                                                   batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    optim = Adam({"lr": args.learning_rate})
    if args.tmc:
        if args.jit:
            raise NotImplementedError(
                "jit support not yet added for TraceTMC_ELBO")
        elbo = TraceTMC_ELBO(max_plate_nesting=1 if model is model_0 else 2)
        tmc_model = poutine.infer_config(
            model,
            lambda msg: {
                "num_samples": args.tmc_num_samples,
                "expand": False
            } if msg["infer"].get("enumerate", None) == "parallel" else {},
        )  # noqa: E501
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
        elbo = Elbo(
            max_plate_nesting=1 if model is model_0 else 2,
            strict_enumeration_warning=(model is not model_7),
            jit_options={"time_compilation": args.time_compilation},
        )
        svi = SVI(model, guide, optim, elbo)

    # We'll train on small minibatches.
    logging.info("Step\tLoss")
    for step in range(args.num_steps):
        loss = svi.step(sequences,
                        lengths,
                        args=args,
                        batch_size=args.batch_size)
        logging.info("{: >5d}\t{}".format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        logging.debug("time to compile: {} s.".format(
            elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model,
                           guide,
                           sequences,
                           lengths,
                           args,
                           include_prior=False)
    logging.info("training loss = {}".format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info("-" * 40)
    logging.info("Evaluating on {} test sequences".format(
        len(data["test"]["sequences"])))
    sequences = data["test"]["sequences"][..., present_notes]
    lengths = data["test"]["sequence_lengths"]
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit easier and for
    # numerical stability) this test loss may not be directly comparable to numbers
    # reported on this dataset elsewhere.
    test_loss = elbo.loss(model,
                          guide,
                          sequences,
                          lengths,
                          args=args,
                          include_prior=False)
    logging.info("test loss = {}".format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info("{} capacity = {} parameters".format(model.__name__,
                                                      capacity))
Example #33
    def fit_advi_iterative(self,
                           n=3,
                           method='advi',
                           n_type='restart',
                           n_iter=None,
                           learning_rate=None,
                           progressbar=True,
                           num_workers=2,
                           train_proportion=None,
                           stratify_cv=None,
                           l2_weight=False,
                           sample_scaling_weight=0.5,
                           checkpoints=None,
                           checkpoint_dir='./checkpoints',
                           tracking=False):
        r""" Train posterior using ADVI method.
        (maximising likelihood of the data and minimising KL-divergence of posterior to prior)
        :param n: number of independent initialisations
        :param method: to allow for potential use of SVGD or MCMC (currently only ADVI implemented).
        :param n_type: type of repeated initialisation:
                                  'restart' to pick different initial value,
                                  'cv' for molecular cross-validation - splits counts into n datasets,
                                         for now, only n=2 is implemented
                                  'bootstrap' for fitting the model to multiple downsampled datasets.
                                         Run `mod.bootstrap_data()` to generate variants of data
        :param n_iter: number of iterations, supersedes self.n_iter
        :param train_proportion: if not None, which proportion of cells to use for training and which for validation.
        :param checkpoints: int, list of int's or None, number of checkpoints to save while model training or list of
            iterations to save checkpoints on
        :param checkpoint_dir: str, directory to save checkpoints in
        :param tracking: bool, track all latent variables during training - if True makes training 2 times slower
        :return: None
        """

        # initialise parameter store
        self.svi = {}
        self.hist = {}
        self.guide_i = {}
        self.samples = {}
        self.node_samples = {}

        if tracking:
            self.logp_hist = {}

        if n_iter is None:
            n_iter = self.n_iter

        if type(checkpoints) is int:
            if n_iter < checkpoints:
                checkpoints = n_iter
            checkpoints = np.linspace(0, n_iter, checkpoints + 1,
                                      dtype=int)[1:]
            self.checkpoints = list(checkpoints)
        else:
            self.checkpoints = checkpoints

        self.checkpoint_dir = checkpoint_dir

        self.n_type = n_type
        self.l2_weight = l2_weight
        self.sample_scaling_weight = sample_scaling_weight
        self.train_proportion = train_proportion

        if stratify_cv is not None:
            self.stratify_cv = stratify_cv

        if train_proportion is not None:
            self.validation_hist = {}
            self.training_hist = {}
            if tracking:
                self.logp_hist_val = {}
                self.logp_hist_train = {}

        if learning_rate is None:
            learning_rate = self.learning_rate

        if np.isin(n_type, ['bootstrap']):
            if self.X_data_sample is None:
                self.bootstrap_data(n=n)
        elif np.isin(n_type, ['cv']):
            self.generate_cv_data()  # cv data added to self.X_data_sample

        init_names = ['init_' + str(i + 1) for i in np.arange(n)]

        for i, name in enumerate(init_names):
            ################### Initialise parameters & optimiser ###################
            # initialise Variational distribution = guide
            if method == 'advi':
                self.guide_i[name] = AutoGuideList(self.model)
                normal_guide_block = poutine.block(
                    self.model,
                    expose_all=True,
                    hide_all=False,
                    hide=self.point_estim +
                    flatten_iterable(self.custom_guides.keys()))
                self.guide_i[name].append(
                    AutoNormal(normal_guide_block, init_loc_fn=init_to_mean))
                self.guide_i[name].append(
                    AutoDelta(
                        poutine.block(self.model,
                                      hide_all=True,
                                      expose=self.point_estim)))
                for k, v in self.custom_guides.items():
                    self.guide_i[name].append(v)

            elif method == 'custom':
                self.guide_i[name] = self.guide

            # initialise SVI inference method
            self.svi[name] = SVI(
                self.model,
                self.guide_i[name],
                optim.ClippedAdam({
                    'lr': learning_rate,
                    # limit the gradient step from becoming too large
                    'clip_norm': self.total_grad_norm_constraint
                }),
                loss=JitTrace_ELBO())

            pyro.clear_param_store()

            self.set_initial_values()

            # record ELBO Loss history here
            self.hist[name] = []
            if tracking:
                self.logp_hist[name] = defaultdict(list)

            if train_proportion is not None:
                self.validation_hist[name] = []
                if tracking:
                    self.logp_hist_val[name] = defaultdict(list)

            ################### Select data for this iteration ###################
            if np.isin(n_type, ['cv', 'bootstrap']):
                X_data = self.X_data_sample[i].astype(self.data_type)
            else:
                X_data = self.X_data.astype(self.data_type)

            ################### Training / validation split ###################
            # split into training and validation
            if train_proportion is not None:
                idx = np.arange(len(X_data))
                train_idx, val_idx = train_test_split(
                    idx,
                    train_size=train_proportion,
                    shuffle=True,
                    stratify=self.stratify_cv)

                extra_data_val = {
                    k: torch.FloatTensor(v[val_idx]).to(self.device)
                    for k, v in self.extra_data.items()
                }
                extra_data_train = {
                    k: torch.FloatTensor(v[train_idx])
                    for k, v in self.extra_data.items()
                }

                x_data_val = torch.FloatTensor(X_data[val_idx]).to(self.device)
                x_data = torch.FloatTensor(X_data[train_idx])
            else:
                # just convert data to CPU tensors
                x_data = torch.FloatTensor(X_data)
                extra_data_train = {
                    k: torch.FloatTensor(v)
                    for k, v in self.extra_data.items()
                }

            ################### Move data to cuda - FULL data ###################
            # if not minibatch do this:
            if self.minibatch_size is None:
                # move tensors to CUDA
                x_data = x_data.to(self.device)
                for k in extra_data_train.keys():
                    extra_data_train[k] = extra_data_train[k].to(self.device)
                # extra_data_train = {k: v.to(self.device) for k, v in extra_data_train.items()}

            ################### MINIBATCH data ###################
            else:
                # create minibatches
                dataset = MiniBatchDataset(x_data,
                                           extra_data_train,
                                           return_idx=True)
                loader = DataLoader(dataset,
                                    batch_size=self.minibatch_size,
                                    num_workers=0)  # TODO num_workers

            ################### Training the model ###################
            # start training in epochs
            epochs_iterator = tqdm(range(n_iter))
            for epoch in epochs_iterator:

                if self.minibatch_size is None:
                    ################### Training FULL data ###################
                    iter_loss = self.step_train(name, x_data, extra_data_train)

                    self.hist[name].append(iter_loss)
                    # save data for posterior sampling
                    self.x_data = x_data
                    self.extra_data_train = extra_data_train

                    if tracking:
                        guide_tr, model_tr = self.step_trace(
                            name, x_data, extra_data_train)
                        self.logp_hist[name]['guide'].append(
                            guide_tr.log_prob_sum().item())
                        self.logp_hist[name]['model'].append(
                            model_tr.log_prob_sum().item())

                        for k, v in model_tr.nodes.items():
                            if "log_prob_sum" in v:
                                self.logp_hist[name][k].append(
                                    v["log_prob_sum"].item())

                else:
                    ################### Training MINIBATCH data ###################
                    aver_loss = []
                    if tracking:
                        aver_logp_guide = []
                        aver_logp_model = []
                        aver_logp = defaultdict(list)

                    for batch in loader:

                        x_data_batch, extra_data_batch = batch
                        x_data_batch = x_data_batch.to(self.device)
                        extra_data_batch = {
                            k: v.to(self.device)
                            for k, v in extra_data_batch.items()
                        }

                        loss = self.step_train(name, x_data_batch,
                                               extra_data_batch)

                        if tracking:
                            guide_tr, model_tr = self.step_trace(
                                name, x_data_batch, extra_data_batch)
                            aver_logp_guide.append(
                                guide_tr.log_prob_sum().item())
                            aver_logp_model.append(
                                model_tr.log_prob_sum().item())

                            for k, v in model_tr.nodes.items():
                                if "log_prob_sum" in v:
                                    aver_logp[k].append(
                                        v["log_prob_sum"].item())

                        aver_loss.append(loss)

                    iter_loss = np.sum(aver_loss)

                    # save data for posterior sampling
                    self.x_data = x_data_batch
                    self.extra_data_train = extra_data_batch

                    self.hist[name].append(iter_loss)

                    if tracking:
                        iter_logp_guide = np.sum(aver_logp_guide)
                        iter_logp_model = np.sum(aver_logp_model)
                        self.logp_hist[name]['guide'].append(iter_logp_guide)
                        self.logp_hist[name]['model'].append(iter_logp_model)

                        for k, v in aver_logp.items():
                            self.logp_hist[name][k].append(np.sum(v))

                if self.checkpoints is not None:
                    if (epoch + 1) in self.checkpoints:
                        self.save_checkpoint(epoch + 1, prefix=name)

                ################### Evaluating cross-validation loss ###################
                if train_proportion is not None:

                    iter_loss_val = self.step_eval_loss(
                        name, x_data_val, extra_data_val)

                    if tracking:
                        guide_tr, model_tr = self.step_trace(
                            name, x_data_val, extra_data_val)
                        self.logp_hist_val[name]['guide'].append(
                            guide_tr.log_prob_sum().item())
                        self.logp_hist_val[name]['model'].append(
                            model_tr.log_prob_sum().item())

                        for k, v in model_tr.nodes.items():
                            if "log_prob_sum" in v:
                                self.logp_hist_val[name][k].append(
                                    v["log_prob_sum"].item())

                    self.validation_hist[name].append(iter_loss_val)
                    epochs_iterator.set_description(
                        'ELBO Loss: {:.4e}: Val loss: {:.4e}'.format(iter_loss, iter_loss_val))
                else:
                    epochs_iterator.set_description('ELBO Loss: ' +
                                                    '{:.4e}'.format(iter_loss))

                if epoch % 20 == 0:
                    torch.cuda.empty_cache()

            if train_proportion is not None:
                # rescale loss
                self.validation_hist[name] = [
                    i / (1 - train_proportion)
                    for i in self.validation_hist[name]
                ]
                self.hist[name] = [
                    i / train_proportion for i in self.hist[name]
                ]

                # reassign the main loss to be displayed
                self.training_hist[name] = self.hist[name]
                self.hist[name] = self.validation_hist[name]

                if tracking:
                    for k, v in self.logp_hist[name].items():
                        self.logp_hist[name][k] = [
                            i / train_proportion
                            for i in self.logp_hist[name][k]
                        ]
                        self.logp_hist_val[name][k] = [
                            i / (1 - train_proportion)
                            for i in self.logp_hist_val[name][k]
                        ]

                    self.logp_hist_train[name] = self.logp_hist[name]
                    self.logp_hist[name] = self.logp_hist_val[name]

            if self.verbose:
                print(plt.plot(np.log10(self.hist[name][0:])))
Example #34
    def fit_advi_iterative_simple(
        self,
        n: int = 3,
        method='advi',
        n_type='restart',
        n_iter=None,
        learning_rate=None,
        progressbar=True,
    ):
        r""" Find posterior using ADVI (deprecated)
        (maximising likelihood of the data and minimising KL-divergence of posterior to prior)
        :param n: number of independent initialisations
        :param method: which approximation of the posterior (guide) to use?.
            * ``'advi'`` - Univariate normal approximation (pyro.infer.autoguide.AutoDiagonalNormal)
            * ``'custom'`` - Custom guide using conjugate posteriors
        :return: self.svi dictionary with svi pyro objects for each n, and self.elbo dictionary storing training history.
        """

        # Pass data to pyro / pytorch
        self.x_data = torch.tensor(self.X_data.astype(
            self.data_type))  # .double()

        # initialise parameter store
        self.svi = {}
        self.hist = {}
        self.guide_i = {}
        self.samples = {}
        self.node_samples = {}

        self.n_type = n_type

        if n_iter is None:
            n_iter = self.n_iter

        if learning_rate is None:
            learning_rate = self.learning_rate

        if np.isin(n_type, ['bootstrap']):
            if self.X_data_sample is None:
                self.bootstrap_data(n=n)
        elif np.isin(n_type, ['cv']):
            self.generate_cv_data()  # cv data added to self.X_data_sample

        init_names = ['init_' + str(i + 1) for i in np.arange(n)]

        for i, name in enumerate(init_names):

            # initialise Variational distribution = guide
            if method == 'advi':
                self.guide_i[name] = AutoGuideList(self.model)
                self.guide_i[name].append(
                    AutoNormal(poutine.block(self.model,
                                             expose_all=True,
                                             hide_all=False,
                                             hide=self.point_estim),
                               init_loc_fn=init_to_mean))
                self.guide_i[name].append(
                    AutoDelta(
                        poutine.block(self.model,
                                      hide_all=True,
                                      expose=self.point_estim)))
            elif method == 'custom':
                self.guide_i[name] = self.guide

            # initialise SVI inference method
            self.svi[name] = SVI(
                self.model,
                self.guide_i[name],
                optim.ClippedAdam({
                    'lr': learning_rate,
                    # limit the gradient step from becoming too large
                    'clip_norm': self.total_grad_norm_constraint
                }),
                loss=JitTrace_ELBO())

            pyro.clear_param_store()

            # record ELBO Loss history here
            self.hist[name] = []

            # pick dataset depending on the training mode and move to GPU
            if np.isin(n_type, ['cv', 'bootstrap']):
                self.x_data = torch.tensor(self.X_data_sample[i].astype(
                    self.data_type))
            else:
                self.x_data = torch.tensor(self.X_data.astype(self.data_type))

            if self.use_cuda:
                # move tensors and modules to CUDA
                self.x_data = self.x_data.cuda()

            # train for n_iter
            it_iterator = tqdm(range(n_iter))
            for it in it_iterator:

                hist = self.svi[name].step(self.x_data)
                it_iterator.set_description('ELBO Loss: ' +
                                            str(np.round(hist, 3)))
                self.hist[name].append(hist)

                # if it % 50 == 0 & self.verbose:
                # logging.info("Elbo loss: {}".format(hist))
                if it % 500 == 0:
                    torch.cuda.empty_cache()
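The guide built above follows a common Pyro pattern: an AutoGuideList combining an AutoNormal over most latent sites with an AutoDelta over the sites that should get point (MAP) estimates. A minimal self-contained sketch of that split, assuming a hypothetical two-site model where "s" stands in for self.point_estim:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDelta, AutoGuideList, AutoNormal, init_to_mean
from pyro.optim import ClippedAdam

def model(x):
    w = pyro.sample("w", dist.Normal(0., 1.))     # approximated by AutoNormal
    s = pyro.sample("s", dist.LogNormal(0., 1.))  # point-estimated by AutoDelta
    with pyro.plate("data", x.shape[0]):
        pyro.sample("obs", dist.Normal(w, s), obs=x)

point_estim = ["s"]  # hypothetical list of point-estimated site names
guide = AutoGuideList(model)
guide.append(AutoNormal(poutine.block(model, hide=point_estim),
                        init_loc_fn=init_to_mean))
guide.append(AutoDelta(poutine.block(model, expose=point_estim)))

svi = SVI(model, guide, ClippedAdam({'lr': 0.01}), loss=Trace_ELBO())
for _ in range(100):
    svi.step(torch.randn(10))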
Example No. 35
    def step(self, state):
        # hide observation statements from enclosing handlers while
        # re-running the parent transition on a copy of the state
        with poutine.block(hide_types=["observe"]):
            super().step(state.copy())
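poutine.block treats observed sample statements as their own "observe" site type, so hide_types=["observe"] makes them invisible to any handler outside the block while the rest of the model runs normally. A small sketch with a hypothetical model:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def model():
    z = pyro.sample("z", dist.Normal(0., 1.))
    pyro.sample("x", dist.Normal(z, 1.), obs=torch.tensor(0.5))
    return z

# the outer trace sees the latent site but not the blocked observation
tr = poutine.trace(poutine.block(model, hide_types=["observe"])).get_trace()
assert "z" in tr.nodes and "x" not in tr.nodes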
Example No. 36
    def _quantized_model(self):
        """
        Quantized vectorized model used for parallel-scan enumerated inference.
        This method is called only outside particle_plate.
        """
        C = len(self.compartments)
        T = self.duration
        Q = self.num_quant_bins
        R_shape = getattr(self.population, "shape", ())  # Region shape.

        # Sample global parameters and auxiliary variables.
        params = self.global_model()
        auxiliary, non_compartmental = self._sample_auxiliary()

        # Manually enumerate.
        curr, logp = quantize_enumerate(auxiliary, min=0, max=self.population,
                                        num_quant_bins=self.num_quant_bins)
        curr = OrderedDict(zip(self.compartments, curr.unbind(0)))
        logp = OrderedDict(zip(self.compartments, logp.unbind(0)))
        curr.update(non_compartmental)

        # Truncate final value from the right then pad initial value onto the left.
        init = self.initialize(params)
        prev = {}
        for name, value in init.items():
            if name in self.compartments:
                if isinstance(value, torch.Tensor):
                    value = value[..., None]  # Because curr is enumerated on the right.
                prev[name] = cat2(value, curr[name][:-1],
                                  dim=-3 if self.is_regional else -2)
            else:  # non-compartmental
                prev[name] = cat2(init[name], curr[name][:-1], dim=-curr[name].dim())

        # Reshape to support broadcasting, similar to EnumMessenger.
        def enum_reshape(tensor, position):
            assert tensor.size(-1) == Q
            assert tensor.dim() <= self.max_plate_nesting + 2
            tensor = tensor.permute(tensor.dim() - 1, *range(tensor.dim() - 1))
            shape = [Q] + [1] * (position + self.max_plate_nesting - (tensor.dim() - 2))
            shape.extend(tensor.shape[1:])
            return tensor.reshape(shape)

        for e, name in enumerate(self.compartments):
            curr[name] = enum_reshape(curr[name], e)
            logp[name] = enum_reshape(logp[name], e)
            prev[name] = enum_reshape(prev[name], e + C)

        # Enable approximate inference by using aux as a non-enumerated proxy
        # for enumerated compartment values.
        for name in self.approximate:
            aux = auxiliary[self.compartments.index(name)]
            curr[name + "_approx"] = aux
            prev[name + "_approx"] = cat2(init[name], aux[:-1],
                                          dim=-2 if self.is_regional else -1)

        # Record transition factors.
        with poutine.block(), poutine.trace() as tr:
            with self.time_plate:
                t = slice(0, T, 1)  # Used to slice data tensors.
                self._transition_bwd(params, prev, curr, t)
        tr.trace.compute_log_prob()
        for name, site in tr.trace.nodes.items():
            if site["type"] == "sample":
                log_prob = site["log_prob"]
                if log_prob.dim() <= self.max_plate_nesting:  # Not enumerated.
                    pyro.factor("transition_" + name, site["log_prob_sum"])
                    continue
                if self.is_regional and log_prob.shape[-1:] != R_shape:
                    # Poor man's tensor variable elimination.
                    log_prob = log_prob.expand(log_prob.shape[:-1] + R_shape) / R_shape[0]
                logp[name] = log_prob

        # Manually perform variable elimination.
        logp = reduce(operator.add, logp.values())
        logp = logp.reshape(Q ** C, Q ** C, T, -1)  # prev, curr, T, batch
        logp = logp.permute(3, 2, 0, 1).squeeze(0)  # batch, T, prev, curr
        logp = pyro.distributions.hmm._sequential_logmatmulexp(logp)  # batch, prev, curr
        logp = logp.reshape(-1, Q ** C * Q ** C).logsumexp(-1).sum()
        warn_if_nan(logp)
        pyro.factor("transition", logp)

        # Apply final likelihood.
        prev = {name: prev[name + "_approx"] for name in self.approximate}
        curr = {name: curr[name + "_approx"] for name in self.approximate}
        with _disallow_latent_variables(".finalize()"):
            self.finalize(params, prev, curr)

        self._clear_plates()
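The block-and-trace pairing used to record the transition factors above is a general trick: run part of a model invisibly, score it, and re-inject the total log-density with pyro.factor. A minimal sketch of the pattern, with a hypothetical submodel:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def submodel():
    pyro.sample("z", dist.Normal(0., 1.), obs=torch.tensor(0.3))

def model():
    # record submodel sites locally, hidden from enclosing handlers
    with poutine.block(), poutine.trace() as tr:
        submodel()
    tr.trace.compute_log_prob()
    total = sum(site["log_prob_sum"]
                for site in tr.trace.nodes.values()
                if site["type"] == "sample")
    pyro.factor("sub_logp", total)  # fold the score back into the model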
Example No. 37
    def guide_autodelta(self):
        self.guide = AutoDelta(
            poutine.block(self.model, expose=['weights', 'locs', 'scale']))
Example No. 38
    def guide_autodiagnorm(self):
        self.guide = AutoDiagonalNormal(
            poutine.block(self.model, expose=['weights', 'locs', 'scale']))
Example No. 39
    def guide(self, MAP=False, *args, **kwargs):
        return AutoDelta(
            poutine.block(self.model, expose=['pi', 'norm_factor', 'cnv_probs']),
            init_loc_fn=self.init_fn())
Example No. 40
    def _contstructSVI(self, optimizationStrategy, total_steps):
        """
        Constructs the SVI object used to train the model. Takes optimizationStrategy as input: 'Full' trains everything, 'DecoderOnly' blocks the guide's encoder parameters, and any other value blocks the model's decoder parameters (encoder-only training).
        """

        #generate random sample just to build the param_store
        test_data = self.datasetSampler.GenerateRandomTrainingSample(2)
        if self.useCuda:
            test_data = self._sendDataToGPU(test_data)
        self.phenotypeModel.guide(*test_data)
        param_store = pyro.get_param_store()

        if optimizationStrategy == 'Full':
            _model = self.phenotypeModel.model
            _guide = self.phenotypeModel.guide

        else:
            if optimizationStrategy == 'DecoderOnly':
                _model = self.phenotypeModel.model
                _guide = poutine.block(
                    self.phenotypeModel.guide,
                    hide=[
                        x for x in param_store.get_all_param_names()
                        if ('encoder' in x)
                    ])

            else:
                _model = poutine.block(
                    self.phenotypeModel.model,
                    hide=[
                        x for x in param_store.get_all_param_names()
                        if ('decoder' in x)
                    ])
                _guide = self.phenotypeModel.guide

        #initialize the SVI class
        AdamWOptimArgs = {'weight_decay': self.AdamW_Weight_Decay}

        scheduler = OneCycleLR({
            'optimizer': AdamW,
            'optim_args': AdamWOptimArgs,
            'max_lr': self.maxLearningRate,
            'total_steps': total_steps,
            'pct_start': self.OneCycleParams['pctCycleIncrease'],
            'div_factor': self.OneCycleParams['initLRDivisionFactor'],
            'final_div_factor': self.OneCycleParams['finalLRDivisionFactor']
        })
        return {
            'svi': SVI(_model, _guide, scheduler,
                       loss=Trace_ELBO(num_particles=self.numParticles)),
            'scheduler': scheduler
        }
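Pyro's OneCycleLR here wraps the torch scheduler in a PyroLRScheduler, which only advances when stepped explicitly, so a training loop over the returned dictionary would step the scheduler after every SVI update. A hedged sketch, where trainer and batch are hypothetical:

svi_dict = trainer._contstructSVI('Full', total_steps=1000)
for step in range(1000):
    loss = svi_dict['svi'].step(*batch)  # one ELBO gradient step
    svi_dict['scheduler'].step()         # advance the one-cycle LR schedule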
Example No. 41
    def guide_multivariatenormal(self):
        self.guide = AutoMultivariateNormal(
            poutine.block(self.model, expose=['weights', 'locs', 'scale']))
Example No. 42
def main(args):
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    logging.info('Loading data')
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info('-' * 40)
    model = models[args.model]
    logging.info('Training {} on {} sequences'.format(
        model.__name__, len(data['train']['sequences'])))
    sequences = data['train']['sequences']
    lengths = data['train']['sequence_lengths']

    # find all the notes that are present at least once in the training set
    present_notes = ((sequences == 1).sum(0).sum(0) > 0)
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths.clamp_(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(0)
    pyro.clear_param_store()
    pyro.enable_validation(True)

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        poutine.block(model,
                      expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = poutine.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = poutine.trace(
            poutine.replay(poutine.enum(model, first_available_dim),
                           guide_trace)).get_trace(sequences,
                                                   lengths,
                                                   args=args,
                                                   batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
    elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2)
    optim = Adam({'lr': args.learning_rate})
    svi = SVI(model, guide, optim, elbo)

    # We'll train on small minibatches.
    logging.info('Step\tLoss')
    for step in range(args.num_steps):
        loss = svi.step(sequences,
                        lengths,
                        args=args,
                        batch_size=args.batch_size)
        logging.info('{: >5d}\t{}'.format(step, loss / num_observations))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model,
                           guide,
                           sequences,
                           lengths,
                           args,
                           include_prior=False)
    logging.info('training loss = {}'.format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info('-' * 40)
    logging.info('Evaluating on {} test sequences'.format(
        len(data['test']['sequences'])))
    sequences = data['test']['sequences'][..., present_notes]
    lengths = data['test']['sequence_lengths']
    if args.truncate:
        lengths.clamp_(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit easier and for
    # numerical stability) this test loss may not be directly comparable to numbers
    # reported on this dataset elsewhere.
    test_loss = elbo.loss(model,
                          guide,
                          sequences,
                          lengths,
                          args=args,
                          include_prior=False)
    logging.info('test loss = {}'.format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info('{} capacity = {} parameters'.format(model.__name__,
                                                      capacity))
Example No. 43
    tr = poutine.trace(guide).get_trace(torch.tensor(2.))
    exp_msg = r"Error while computing score_parts at site 'test_site':\s*" \
              r"The value argument must be within the support"
    with pytest.raises(ValueError, match=exp_msg):
        tr.compute_score_parts()


def _model(a=torch.tensor(1.), b=torch.tensor(1.)):
    latent = pyro.sample("latent", dist.Beta(a, b))
    return pyro.sample("test_site",
                       dist.Bernoulli(latent),
                       obs=torch.tensor(1))


@pytest.mark.parametrize('wrapper', [
    lambda fn: poutine.block(fn),
    lambda fn: poutine.condition(fn, {'latent': 0.9}),
    lambda fn: poutine.enum(fn, -1),
    lambda fn: poutine.replay(fn,
                              poutine.trace(fn).get_trace()),
])
def test_pickling(wrapper):
    wrapped = wrapper(_model)
    # default protocol cannot serialize torch.Size objects (see https://github.com/pytorch/pytorch/issues/20823)
    deserialized = pickle.loads(
        pickle.dumps(wrapped, protocol=pickle.HIGHEST_PROTOCOL))
    obs = torch.tensor(0.5)
    pyro.set_rng_seed(0)
    actual_trace = poutine.trace(deserialized).get_trace(obs)
    pyro.set_rng_seed(0)
    expected_trace = poutine.trace(wrapped).get_trace(obs)
    # the deserialized poutine should reproduce the original trace
    assert tuple(actual_trace) == tuple(expected_trace)
Example No. 44
def _sample_posterior(model, first_available_dim, temperature, *args,
                      **kwargs):
    # For internal use by infer_discrete.

    # Create an enumerated trace.
    with poutine.block(), EnumerateMessenger(first_available_dim):
        enum_trace = poutine.trace(model).get_trace(*args, **kwargs)
    enum_trace = prune_subsample_sites(enum_trace)
    enum_trace.compute_log_prob()
    enum_trace.pack_tensors()
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    log_probs = OrderedDict()
    sum_dims = set()
    queries = []
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            ordinal = frozenset(plate_to_symbol[f.name]
                                for f in node["cond_indep_stack"]
                                if f.vectorized)
            log_prob = node["packed"]["log_prob"]
            log_probs.setdefault(ordinal, []).append(log_prob)
            sum_dims.update(log_prob._pyro_dims)
            for frame in node["cond_indep_stack"]:
                if frame.vectorized:
                    sum_dims.remove(plate_to_symbol[frame.name])
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    ring = _make_ring(temperature)
    log_probs = contract_tensor_tree(log_probs, sum_dims,
                                     ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a constant value for pending queries
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Note: this is quadratic in number of ordinals
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["log_prob"]
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack.
                ordinal = query_to_ordinal[log_prob]
                new_node["cond_indep_stack"] = tuple(
                    f for f in node["cond_indep_stack"]
                    if not f.vectorized or plate_to_symbol[f.name] in ordinal)

                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(node["value"],
                                            node["infer"]["_dim_to_symbol"])
                    for index, dim in zip(jit_iter(sample),
                                          sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(new_value,
                                                      enum_trace.symbol_to_dim)

            collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)
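_sample_posterior is the workhorse behind the public pyro.infer.infer_discrete handler. A minimal usage sketch with a hypothetical single-site model (the pattern follows the Pyro docs):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate, infer_discrete

@config_enumerate
def model(data):
    p = pyro.param("p", torch.tensor(0.25))
    z = pyro.sample("z", dist.Bernoulli(p))  # discrete latent
    pyro.sample("x", dist.Normal(z, 1.), obs=data)
    return z

# temperature=1 draws a posterior sample; temperature=0 gives the MAP value
posterior_model = infer_discrete(model, first_available_dim=-1, temperature=1)
z = posterior_model(torch.tensor(0.5))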
Example No. 45
def main(args):
    # setup hyperparameters for the model
    hypers = {
        'expected_sparsity': max(1.0, args.num_dimensions / 10),
        'alpha1': 3.0,
        'beta1': 1.0,
        'alpha2': 3.0,
        'beta2': 1.0,
        'alpha3': 1.0,
        'c': 1.0
    }

    P = args.num_dimensions
    S = args.active_dimensions
    Q = args.quadratic_dimensions

    # generate artificial dataset
    X, Y, expected_thetas, expected_quad_dims = get_data(N=args.num_data,
                                                         P=P,
                                                         S=S,
                                                         Q=Q,
                                                         sigma_obs=args.sigma)

    loss_fn = Trace_ELBO().differentiable_loss

    # We initialize the AutoDelta guide (for MAP estimation) with args.num_restarts many
    # initial parameters sampled from the vicinity of the median of the prior distribution
    # and then continue optimizing with the best performing initialization.
    init_losses = []
    for restart in range(args.num_restarts):
        pyro.clear_param_store()
        pyro.set_rng_seed(restart)
        guide = AutoDelta(model, init_loc_fn=init_loc_fn)
        with torch.no_grad():
            init_losses.append(loss_fn(model, guide, X, Y, hypers).item())

    pyro.set_rng_seed(np.argmin(init_losses))
    pyro.clear_param_store()
    guide = AutoDelta(model, init_loc_fn=init_loc_fn)

    # Instead of using pyro.infer.SVI and pyro.optim we instead construct our own PyTorch
    # optimizer and take charge of gradient-based optimization ourselves.
    with poutine.block(), poutine.trace(param_only=True) as param_capture:
        guide(X, Y, hypers)
    params = [pyro.param(name).unconstrained() for name in param_capture.trace]
    adam = Adam(params, lr=args.lr)

    report_frequency = 50
    print("Beginning MAP optimization...")

    # the optimization loop
    for step in range(args.num_steps):
        loss = loss_fn(model, guide, X, Y, hypers) / args.num_data
        loss.backward()
        adam.step()
        adam.zero_grad()

        # we manually reduce the learning rate according to this schedule
        if step in [100, 300, 700, 900]:
            adam.param_groups[0]['lr'] *= 0.2

        if step % report_frequency == 0 or step == args.num_steps - 1:
            print("[step %04d]  loss: %.5f" % (step, loss))

    print("Expected singleton thetas:\n", expected_thetas.data.numpy())

    # we do the final computation using double precision
    median = guide.median()  # == mode for MAP inference
    active_dims, active_quad_dims = \
        compute_posterior_stats(X.double(), Y.double(), median['msq'].double(),
                                median['lambda'].double(), median['eta1'].double(),
                                median['xisq'].double(), torch.tensor(hypers['c']).double(),
                                median['sigma'].double())

    expected_active_dims = np.arange(S).tolist()

    tp_singletons = len(set(active_dims) & set(expected_active_dims))
    fp_singletons = len(set(active_dims) - set(expected_active_dims))
    fn_singletons = len(set(expected_active_dims) - set(active_dims))
    singleton_stats = (tp_singletons, fp_singletons, fn_singletons)

    tp_quads = len(set(active_quad_dims) & set(expected_quad_dims))
    fp_quads = len(set(active_quad_dims) - set(expected_quad_dims))
    fn_quads = len(set(expected_quad_dims) - set(active_quad_dims))
    quad_stats = (tp_quads, fp_quads, fn_quads)

    # We report how well we did, i.e. did we recover the sparse set of coefficients
    # that we expected for our artificial dataset?
    print("[SUMMARY STATS]")
    print("Singletons (true positive, false positive, false negative): " +
          "(%d, %d, %d)" % singleton_stats)
    print("Quadratic  (true positive, false positive, false negative): " +
          "(%d, %d, %d)" % quad_stats)
Example No. 46
            ax.add_artist(ell)

    # Save figure
    fig.savefig(figname)


if __name__ == "__main__":
    pyro.enable_validation(True)
    pyro.set_rng_seed(42)

    # Create our model with a fixed number of components
    K = 2

    data = get_samples()

    global_guide = AutoDelta(poutine.block(model, expose=['weights', 'locs', 'scales']))
    global_guide = config_enumerate(global_guide, 'parallel')
    _, svi = initialize(data)

    true_colors = [0] * 100 + [1] * 100
    plot(data, colors=true_colors, figname='pyro_init.png')

    for i in range(151):
        svi.step(data)

        if i % 50 == 0:
            locs = pyro.param('locs')
            scales = pyro.param('scales')
            weights = pyro.param('weights')
            assignment_probs = pyro.param('assignment_probs')