Example #1
def AutoMixed(model_full, init_loc={}, delta=None):
    guide = AutoGuideList(model_full)

    marginalised_guide_block = poutine.block(model_full,
                                             expose_all=True,
                                             hide_all=False,
                                             hide=['tau'])
    if delta is None:
        guide.append(
            AutoNormal(marginalised_guide_block,
                       init_loc_fn=autoguide.init_to_value(values=init_loc),
                       init_scale=0.05))
    elif delta == 'part' or delta == 'all':
        guide.append(
            AutoDelta(marginalised_guide_block,
                      init_loc_fn=autoguide.init_to_value(values=init_loc)))

    full_rank_guide_block = poutine.block(model_full,
                                          hide_all=True,
                                          expose=['tau'])
    if delta is None or delta == 'part':
        guide.append(
            AutoMultivariateNormal(
                full_rank_guide_block,
                init_loc_fn=autoguide.init_to_value(values=init_loc),
                init_scale=0.05))
    else:
        guide.append(
            AutoDelta(full_rank_guide_block,
                      init_loc_fn=autoguide.init_to_value(values=init_loc)))

    return guide
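
A minimal usage sketch follows (an illustration added here, not part of the original example): toy_model, its sites, and the optimiser settings are hypothetical; the only assumption AutoMixed makes about the model is that it contains a latent site named "tau" for the full-rank block, while all other latents go to the mean-field block. Passing delta='part' or delta='all' swaps in AutoDelta (MAP) blocks as defined above.

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO, autoguide
from pyro.infer.autoguide import (AutoDelta, AutoGuideList,
                                  AutoMultivariateNormal, AutoNormal)
from pyro.optim import Adam

# Hypothetical toy model: "tau" is covered by the full-rank block,
# every other latent (here just "loc") by the mean-field block.
def toy_model():
    tau = pyro.sample("tau", dist.HalfNormal(1.0))
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    pyro.sample("obs", dist.Normal(loc, tau), obs=torch.tensor(1.0))

guide = AutoMixed(toy_model, init_loc={"tau": torch.tensor(1.0)}, delta=None)
svi = SVI(toy_model, guide, Adam({"lr": 0.01}), Trace_ELBO())
for step in range(100):
    svi.step()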
Example #2
def test_svi_model_side_enumeration(model, temperature):
    # Perform fake inference.
    # This has the wrong distribution but the right type for tests.
    guide = AutoNormal(
        handlers.enum(
            handlers.block(infer.config_enumerate(model),
                           expose=["loc", "scale"])))
    guide()  # Initialize but don't bother to train.
    guide_trace = handlers.trace(guide).get_trace()
    guide_data = {
        name: site["value"]
        for name, site in guide_trace.nodes.items() if site["type"] == "sample"
    }

    # MAP-estimate the discrete sites, conditioned on posterior-sampled continuous latents.
    actual_trace = handlers.trace(
        infer.infer_discrete(
            # TODO support replayed sites in infer_discrete.
            # handlers.replay(infer.config_enumerate(model), guide_trace)
            handlers.condition(infer.config_enumerate(model), guide_data),
            temperature=temperature,
        )).get_trace()

    # Check site names and shapes.
    expected_trace = handlers.trace(model).get_trace()
    assert set(actual_trace.nodes) == set(expected_trace.nodes)
    assert "z1" not in actual_trace.nodes["scale"]["funsor"]["value"].inputs
Example #3
    def __init__(self, **kwargs):

        super().__init__()
        self._model = BayesianRegressionPyroModel(**kwargs)
        self._guide = AutoNormal(self.model,
                                 init_loc_fn=init_to_mean,
                                 create_plates=self.model.create_plates)
        self._get_fn_args_from_batch = self._model._get_fn_args_from_batch
Example #4
    def __init__(self, model, data, covariates, *,
                 guide=None,
                 init_loc_fn=init_to_sample,
                 init_scale=0.1,
                 create_plates=None,
                 optim=None,
                 learning_rate=0.01,
                 betas=(0.9, 0.99),
                 learning_rate_decay=0.1,
                 clip_norm=10.0,
                 dct_gradients=False,
                 subsample_aware=False,
                 num_steps=1001,
                 num_particles=1,
                 vectorize_particles=True,
                 warm_start=False,
                 log_every=100):
        assert data.size(-2) == covariates.size(-2)
        super().__init__()
        self.model = model
        if guide is None:
            guide = AutoNormal(self.model, init_loc_fn=init_loc_fn, init_scale=init_scale,
                               create_plates=create_plates)
        self.guide = guide

        # Initialize.
        if warm_start:
            model = PrefixWarmStartMessenger()(model)
            guide = PrefixWarmStartMessenger()(guide)
        if dct_gradients:
            model = MarkDCTParamMessenger("time")(model)
            guide = MarkDCTParamMessenger("time")(guide)
        elbo = Trace_ELBO(num_particles=num_particles,
                          vectorize_particles=vectorize_particles)
        elbo._guess_max_plate_nesting(model, guide, (data, covariates), {})
        elbo.max_plate_nesting = max(elbo.max_plate_nesting, 1)  # force a time plate

        losses = []
        if num_steps:
            if optim is None:
                optim = DCTAdam({"lr": learning_rate, "betas": betas,
                                 "lrd": learning_rate_decay ** (1 / num_steps),
                                 "clip_norm": clip_norm,
                                 "subsample_aware": subsample_aware})
            svi = SVI(self.model, self.guide, optim, elbo)
            for step in range(num_steps):
                loss = svi.step(data, covariates) / data.numel()
                if log_every and step % log_every == 0:
                    logger.info("step {: >4d} loss = {:0.6g}".format(step, loss))
                    print("step {: >4d} loss = {:0.6g}".format(step, loss))
                losses.append(loss)

        self.guide.create_plates = None  # Disable subsampling after training.
        self.max_plate_nesting = elbo.max_plate_nesting
        self.losses = losses
Example #5
def test_subsample_smoke(Reparam, subsample):
    def model():
        with poutine.reparam(config={"x": Reparam()}):
            with pyro.plate("plate", 10):
                return pyro.sample("x", dist.Stable(1.5, 0))

    def create_plates():
        return pyro.plate("plate", 10, subsample_size=3)

    guide = AutoNormal(model, create_plates=create_plates if subsample else None)
    Trace_ELBO().loss(model, guide)  # smoke test
Example #6
def test_end_to_end(model):
    # Test training.
    model = AutoReparam()(model)
    guide = AutoNormal(model)
    svi = SVI(model, guide, Adam({"lr": 1e-9}), Trace_ELBO())
    for step in range(3):
        svi.step()

    # Test prediction.
    predictive = Predictive(model, guide=guide, num_samples=2)
    samples = predictive()
    assert set("abc").issubset(samples.keys())
Example #7
    def __init__(self, **kwargs):

        super().__init__()
        self.hist = []

        self._model = LocationModelLinearDependentWMultiExperimentModel(
            **kwargs)
        self._guide = AutoNormal(
            self.model,
            init_loc_fn=init_to_mean,
            create_plates=self.model.create_plates,
        )
Example #8
def test_subsample_guide(auto_class, init_fn):

    # The model from tutorial/source/easyguide.ipynb
    def model(batch, subsample, full_size):
        num_time_steps = len(batch)
        result = [None] * num_time_steps
        drift = pyro.sample("drift", dist.LogNormal(-1, 0.5))
        plate = pyro.plate("data", full_size, subsample=subsample)
        assert plate.size == 50
        with plate:
            z = 0.
            for t in range(num_time_steps):
                z = pyro.sample("state_{}".format(t), dist.Normal(z, drift))
                result[t] = pyro.sample("obs_{}".format(t),
                                        dist.Bernoulli(logits=z),
                                        obs=batch[t])

        return torch.stack(result)

    def create_plates(batch, subsample, full_size):
        return pyro.plate("data", full_size, subsample=subsample)

    if auto_class == AutoGuideList:
        guide = AutoGuideList(model, create_plates=create_plates)
        guide.add(AutoDelta(poutine.block(model, expose=["drift"])))
        guide.add(AutoNormal(poutine.block(model, hide=["drift"])))
    else:
        guide = auto_class(model, create_plates=create_plates)

    full_size = 50
    batch_size = 20
    num_time_steps = 8
    pyro.set_rng_seed(123456789)
    data = model([None] * num_time_steps, torch.arange(full_size), full_size)
    assert data.shape == (num_time_steps, full_size)

    pyro.get_param_store().clear()
    pyro.set_rng_seed(123456789)
    svi = SVI(model, guide, Adam({"lr": 0.02}), Trace_ELBO())
    for epoch in range(2):
        beg = 0
        while beg < full_size:
            end = min(full_size, beg + batch_size)
            subsample = torch.arange(beg, end)
            batch = data[:, beg:end]
            beg = end
            svi.step(batch, subsample, full_size=full_size)
Example #9
def main(args):
    # Create a model, synthetic data, and a guide.
    pyro.set_rng_seed(args.seed)
    model = Model(args.size)
    covariates = torch.randn(args.size)
    data = model(covariates)
    guide = AutoNormal(model)

    if args.horovod:
        # Initialize Horovod and set PyTorch globals.
        import horovod.torch as hvd
        hvd.init()
        torch.set_num_threads(1)
        if args.cuda:
            torch.cuda.set_device(hvd.local_rank())
    if args.cuda:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
    device = torch.tensor(0).device  # infer the current default device (CPU or CUDA)

    if args.horovod:
        # Initialize parameters and broadcast to all workers.
        guide(covariates[:1], data[:1])  # Initializes model and guide.
        hvd.broadcast_parameters(guide.state_dict(), root_rank=0)
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    # Create an ELBO loss and a Pyro optimizer.
    elbo = Trace_ELBO()
    optim = Adam({"lr": args.learning_rate})

    if args.horovod:
        # Wrap the basic optimizer in a distributed optimizer.
        optim = HorovodOptimizer(optim)

    # Create a dataloader.
    dataset = torch.utils.data.TensorDataset(covariates, data)
    if args.horovod:
        # Horovod requires a distributed sampler.
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, hvd.size(), hvd.rank())
    else:
        sampler = torch.utils.data.RandomSampler(dataset)
    config = {"batch_size": args.batch_size, "sampler": sampler}
    if args.cuda:
        config["num_workers"] = 1
        config["pin_memory"] = True
        # Try to use forkserver to spawn workers instead of fork.
        if (hasattr(mp, "_supports_context") and mp._supports_context
                and "forkserver" in mp.get_all_start_methods()):
            config["multiprocessing_context"] = "forkserver"
    dataloader = torch.utils.data.DataLoader(dataset, **config)

    # Run stochastic variational inference.
    svi = SVI(model, guide, optim, elbo)
    for epoch in range(args.num_epochs):
        if args.horovod:
            # Set rng seeds on distributed samplers. This is required.
            sampler.set_epoch(epoch)

        for step, (covariates_batch, data_batch) in enumerate(dataloader):
            loss = svi.step(covariates_batch.to(device), data_batch.to(device))

            if args.horovod:
                # Optionally average loss metric across workers.
                # You can do this with arbitrary torch.Tensors.
                loss = torch.tensor(loss)
                loss = hvd.allreduce(loss, "loss")
                loss = loss.item()

                # Print only on the rank=0 worker.
                if step % 100 == 0 and hvd.rank() == 0:
                    print("epoch {} step {} loss = {:0.4g}".format(
                        epoch, step, loss))
            else:
                if step % 100 == 0:
                    print("epoch {} step {} loss = {:0.4g}".format(
                        epoch, step, loss))

    if args.horovod:
        # After we're done with the distributed parts of the program,
        # we can shutdown all but the rank=0 worker.
        hvd.shutdown()
        if hvd.rank() != 0:
            return

    if args.outfile:
        print("saving to {}".format(args.outfile))
        torch.save({"model": model, "guide": guide}, args.outfile)
Example #10
    def fit_advi_iterative(self,
                           n=3,
                           method='advi',
                           n_type='restart',
                           n_iter=None,
                           learning_rate=None,
                           progressbar=True,
                           num_workers=2,
                           train_proportion=None,
                           stratify_cv=None,
                           l2_weight=False,
                           sample_scaling_weight=0.5,
                           checkpoints=None,
                           checkpoint_dir='./checkpoints',
                           tracking=False):
        r""" Train posterior using ADVI method.
        (maximising likelihood of the data and minimising KL-divergence of posterior to prior)
        :param n: number of independent initialisations
        :param method: to allow for potential use of SVGD or MCMC (currently only ADVI implemented).
        :param n_type: type of repeated initialisation:
                                  'restart' to pick different initial value,
                                  'cv' for molecular cross-validation - splits counts into n datasets,
                                         for now, only n=2 is implemented
                                  'bootstrap' for fitting the model to multiple downsampled datasets.
                                         Run `mod.bootstrap_data()` to generate variants of data
        :param n_iter: number of iterations, supersedes self.n_iter
        :param train_proportion: if not None, which proportion of cells to use for training and which for validation.
        :param checkpoints: int, list of int's or None, number of checkpoints to save while model training or list of
            iterations to save checkpoints on
        :param checkpoint_dir: str, directory to save checkpoints in
        :param tracking: bool, track all latent variables during training - if True makes training 2 times slower
        :return: None
        """

        # initialise parameter store
        self.svi = {}
        self.hist = {}
        self.guide_i = {}
        self.samples = {}
        self.node_samples = {}

        if tracking:
            self.logp_hist = {}

        if n_iter is None:
            n_iter = self.n_iter

        if type(checkpoints) is int:
            if n_iter < checkpoints:
                checkpoints = n_iter
            checkpoints = np.linspace(0, n_iter, checkpoints + 1,
                                      dtype=int)[1:]
            self.checkpoints = list(checkpoints)
        else:
            self.checkpoints = checkpoints

        self.checkpoint_dir = checkpoint_dir

        self.n_type = n_type
        self.l2_weight = l2_weight
        self.sample_scaling_weight = sample_scaling_weight
        self.train_proportion = train_proportion

        if stratify_cv is not None:
            self.stratify_cv = stratify_cv

        if train_proportion is not None:
            self.validation_hist = {}
            self.training_hist = {}
            if tracking:
                self.logp_hist_val = {}
                self.logp_hist_train = {}

        if learning_rate is None:
            learning_rate = self.learning_rate

        if np.isin(n_type, ['bootstrap']):
            if self.X_data_sample is None:
                self.bootstrap_data(n=n)
        elif np.isin(n_type, ['cv']):
            self.generate_cv_data()  # cv data added to self.X_data_sample

        init_names = ['init_' + str(i + 1) for i in np.arange(n)]

        for i, name in enumerate(init_names):
            ################### Initialise parameters & optimiser ###################
            # initialise Variational distribution = guide
            if method == 'advi':
                self.guide_i[name] = AutoGuideList(self.model)
                normal_guide_block = poutine.block(
                    self.model,
                    expose_all=True,
                    hide_all=False,
                    hide=self.point_estim +
                    flatten_iterable(self.custom_guides.keys()))
                self.guide_i[name].append(
                    AutoNormal(normal_guide_block, init_loc_fn=init_to_mean))
                self.guide_i[name].append(
                    AutoDelta(
                        poutine.block(self.model,
                                      hide_all=True,
                                      expose=self.point_estim)))
                for k, v in self.custom_guides.items():
                    self.guide_i[name].append(v)

            elif method == 'custom':
                self.guide_i[name] = self.guide

            # initialise SVI inference method
            self.svi[name] = SVI(
                self.model,
                self.guide_i[name],
                optim.ClippedAdam({
                    'lr': learning_rate,
                    # limit the gradient step from becoming too large
                    'clip_norm': self.total_grad_norm_constraint
                }),
                loss=JitTrace_ELBO())

            pyro.clear_param_store()

            self.set_initial_values()

            # record ELBO Loss history here
            self.hist[name] = []
            if tracking:
                self.logp_hist[name] = defaultdict(list)

            if train_proportion is not None:
                self.validation_hist[name] = []
                if tracking:
                    self.logp_hist_val[name] = defaultdict(list)

            ################### Select data for this iteration ###################
            if np.isin(n_type, ['cv', 'bootstrap']):
                X_data = self.X_data_sample[i].astype(self.data_type)
            else:
                X_data = self.X_data.astype(self.data_type)

            ################### Training / validation split ###################
            # split into training and validation
            if train_proportion is not None:
                idx = np.arange(len(X_data))
                train_idx, val_idx = train_test_split(
                    idx,
                    train_size=train_proportion,
                    shuffle=True,
                    stratify=self.stratify_cv)

                extra_data_val = {
                    k: torch.FloatTensor(v[val_idx]).to(self.device)
                    for k, v in self.extra_data.items()
                }
                extra_data_train = {
                    k: torch.FloatTensor(v[train_idx])
                    for k, v in self.extra_data.items()
                }

                x_data_val = torch.FloatTensor(X_data[val_idx]).to(self.device)
                x_data = torch.FloatTensor(X_data[train_idx])
            else:
                # just convert data to CPU tensors
                x_data = torch.FloatTensor(X_data)
                extra_data_train = {
                    k: torch.FloatTensor(v)
                    for k, v in self.extra_data.items()
                }

            ################### Move data to cuda - FULL data ###################
            # if not minibatch do this:
            if self.minibatch_size is None:
                # move tensors to CUDA
                x_data = x_data.to(self.device)
                for k in extra_data_train.keys():
                    extra_data_train[k] = extra_data_train[k].to(self.device)
                # extra_data_train = {k: v.to(self.device) for k, v in extra_data_train.items()}

            ################### MINIBATCH data ###################
            else:
                # create minibatches
                dataset = MiniBatchDataset(x_data,
                                           extra_data_train,
                                           return_idx=True)
                loader = DataLoader(dataset,
                                    batch_size=self.minibatch_size,
                                    num_workers=0)  # TODO num_workers

            ################### Training the model ###################
            # start training in epochs
            epochs_iterator = tqdm(range(n_iter))
            for epoch in epochs_iterator:

                if self.minibatch_size is None:
                    ################### Training FULL data ###################
                    iter_loss = self.step_train(name, x_data, extra_data_train)

                    self.hist[name].append(iter_loss)
                    # save data for posterior sampling
                    self.x_data = x_data
                    self.extra_data_train = extra_data_train

                    if tracking:
                        guide_tr, model_tr = self.step_trace(
                            name, x_data, extra_data_train)
                        self.logp_hist[name]['guide'].append(
                            guide_tr.log_prob_sum().item())
                        self.logp_hist[name]['model'].append(
                            model_tr.log_prob_sum().item())

                        for k, v in model_tr.nodes.items():
                            if "log_prob_sum" in v:
                                self.logp_hist[name][k].append(
                                    v["log_prob_sum"].item())

                else:
                    ################### Training MINIBATCH data ###################
                    aver_loss = []
                    if tracking:
                        aver_logp_guide = []
                        aver_logp_model = []
                        aver_logp = defaultdict(list)

                    for batch in loader:

                        x_data_batch, extra_data_batch = batch
                        x_data_batch = x_data_batch.to(self.device)
                        extra_data_batch = {
                            k: v.to(self.device)
                            for k, v in extra_data_batch.items()
                        }

                        loss = self.step_train(name, x_data_batch,
                                               extra_data_batch)

                        if tracking:
                            guide_tr, model_tr = self.step_trace(
                                name, x_data_batch, extra_data_batch)
                            aver_logp_guide.append(
                                guide_tr.log_prob_sum().item())
                            aver_logp_model.append(
                                model_tr.log_prob_sum().item())

                            for k, v in model_tr.nodes.items():
                                if "log_prob_sum" in v:
                                    aver_logp[k].append(
                                        v["log_prob_sum"].item())

                        aver_loss.append(loss)

                    iter_loss = np.sum(aver_loss)

                    # save data for posterior sampling
                    self.x_data = x_data_batch
                    self.extra_data_train = extra_data_batch

                    self.hist[name].append(iter_loss)

                    if tracking:
                        iter_logp_guide = np.sum(aver_logp_guide)
                        iter_logp_model = np.sum(aver_logp_model)
                        self.logp_hist[name]['guide'].append(iter_logp_guide)
                        self.logp_hist[name]['model'].append(iter_logp_model)

                        for k, v in aver_logp.items():
                            self.logp_hist[name][k].append(np.sum(v))

                if self.checkpoints is not None:
                    if (epoch + 1) in self.checkpoints:
                        self.save_checkpoint(epoch + 1, prefix=name)

                ################### Evaluating cross-validation loss ###################
                if train_proportion is not None:

                    iter_loss_val = self.step_eval_loss(
                        name, x_data_val, extra_data_val)

                    if tracking:
                        guide_tr, model_tr = self.step_trace(
                            name, x_data_val, extra_data_val)
                        self.logp_hist_val[name]['guide'].append(
                            guide_tr.log_prob_sum().item())
                        self.logp_hist_val[name]['model'].append(
                            model_tr.log_prob_sum().item())

                        for k, v in model_tr.nodes.items():
                            if "log_prob_sum" in v:
                                self.logp_hist_val[name][k].append(
                                    v["log_prob_sum"].item())

                    self.validation_hist[name].append(iter_loss_val)
                    epochs_iterator.set_description(
                        'ELBO Loss: {:.4e}: Val loss: {:.4e}'.format(iter_loss, iter_loss_val))
                else:
                    epochs_iterator.set_description('ELBO Loss: ' +
                                                    '{:.4e}'.format(iter_loss))

                if epoch % 20 == 0:
                    torch.cuda.empty_cache()

            if train_proportion is not None:
                # rescale loss
                self.validation_hist[name] = [
                    i / (1 - train_proportion)
                    for i in self.validation_hist[name]
                ]
                self.hist[name] = [
                    i / train_proportion for i in self.hist[name]
                ]

                # reassign the main loss to be displayed
                self.training_hist[name] = self.hist[name]
                self.hist[name] = self.validation_hist[name]

                if tracking:
                    for k, v in self.logp_hist[name].items():
                        self.logp_hist[name][k] = [
                            i / train_proportion
                            for i in self.logp_hist[name][k]
                        ]
                        self.logp_hist_val[name][k] = [
                            i / (1 - train_proportion)
                            for i in self.logp_hist_val[name][k]
                        ]

                    self.logp_hist_train[name] = self.logp_hist[name]
                    self.logp_hist[name] = self.logp_hist_val[name]

            if self.verbose:
                plt.plot(np.log10(self.hist[name][0:]))
Example #11
    def fit_advi_iterative_simple(
        self,
        n: int = 3,
        method='advi',
        n_type='restart',
        n_iter=None,
        learning_rate=None,
        progressbar=True,
    ):
        r""" Find posterior using ADVI (deprecated)
        (maximising likehood of the data and minimising KL-divergence of posterior to prior)
        :param n: number of independent initialisations
        :param method: which approximation of the posterior (guide) to use?.
            * ``'advi'`` - Univariate normal approximation (pyro.infer.autoguide.AutoDiagonalNormal)
            * ``'custom'`` - Custom guide using conjugate posteriors
        :return: self.svi dictionary with svi pyro objects for each n, and sefl.elbo dictionary storing training history. 
        """

        # Pass data to pyro / pytorch
        self.x_data = torch.tensor(self.X_data.astype(
            self.data_type))  # .double()

        # initialise parameter store
        self.svi = {}
        self.hist = {}
        self.guide_i = {}
        self.samples = {}
        self.node_samples = {}

        self.n_type = n_type

        if n_iter is None:
            n_iter = self.n_iter

        if learning_rate is None:
            learning_rate = self.learning_rate

        if np.isin(n_type, ['bootstrap']):
            if self.X_data_sample is None:
                self.bootstrap_data(n=n)
        elif np.isin(n_type, ['cv']):
            self.generate_cv_data()  # cv data added to self.X_data_sample

        init_names = ['init_' + str(i + 1) for i in np.arange(n)]

        for i, name in enumerate(init_names):

            # initialise variational distribution = guide
            if method == 'advi':
                self.guide_i[name] = AutoGuideList(self.model)
                self.guide_i[name].append(
                    AutoNormal(poutine.block(self.model,
                                             expose_all=True,
                                             hide_all=False,
                                             hide=self.point_estim),
                               init_loc_fn=init_to_mean))
                self.guide_i[name].append(
                    AutoDelta(
                        poutine.block(self.model,
                                      hide_all=True,
                                      expose=self.point_estim)))
            elif method == 'custom':
                self.guide_i[name] = self.guide

            # initialise SVI inference method
            self.svi[name] = SVI(
                self.model,
                self.guide_i[name],
                optim.ClippedAdam({
                    'lr': learning_rate,
                    # limit the gradient step from becoming too large
                    'clip_norm': self.total_grad_norm_constraint
                }),
                loss=JitTrace_ELBO())

            pyro.clear_param_store()

            # record ELBO Loss history here
            self.hist[name] = []

            # pick dataset depending on the training mode and move to GPU
            if np.isin(n_type, ['cv', 'bootstrap']):
                self.x_data = torch.tensor(self.X_data_sample[i].astype(
                    self.data_type))
            else:
                self.x_data = torch.tensor(self.X_data.astype(self.data_type))

            if self.use_cuda:
                # move tensors and modules to CUDA
                self.x_data = self.x_data.cuda()

            # train for n_iter
            it_iterator = tqdm(range(n_iter))
            for it in it_iterator:

                hist = self.svi[name].step(self.x_data)
                it_iterator.set_description('ELBO Loss: ' +
                                            str(np.round(hist, 3)))
                self.hist[name].append(hist)

                # if it % 50 == 0 & self.verbose:
                # logging.info("Elbo loss: {}".format(hist))
                if it % 500 == 0:
                    torch.cuda.empty_cache()
Example #12
    def fit_svi(self, *,
                num_samples=100,
                num_steps=2000,
                num_particles=32,
                learning_rate=0.1,
                learning_rate_decay=0.01,
                betas=(0.8, 0.99),
                haar=True,
                init_scale=0.01,
                guide_rank=0,
                jit=False,
                log_every=200,
                **options):
        """
        Runs stochastic variational inference to generate posterior samples.

        This runs :class:`~pyro.infer.svi.SVI`, setting the ``.samples``
        attribute on completion.

        This approximate inference method is useful for quickly iterating on
        probabilistic models.

        :param int num_samples: Number of posterior samples to draw from the
            trained guide. Defaults to 100.
        :param int num_steps: Number of :class:`~pyro.infer.svi.SVI` steps.
        :param int num_particles: Number of :class:`~pyro.infer.svi.SVI`
            particles per step.
        :param float learning_rate: Learning rate for the
            :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
        :param float learning_rate_decay: Learning rate decay for the
            :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer. Note this
            is decay over the entire schedule, not per-step decay.
        :param tuple betas: Momentum parameters for the
            :class:`~pyro.optim.clipped_adam.ClippedAdam` optimizer.
        :param bool haar: Whether to use a Haar wavelet reparameterizer.
        :param int guide_rank: Rank of the auto normal guide. If zero (default)
            use an :class:`~pyro.infer.autoguide.AutoNormal` guide. If a
            positive integer or None, use an
            :class:`~pyro.infer.autoguide.AutoLowRankMultivariateNormal` guide.
            If the string "full", use an
            :class:`~pyro.infer.autoguide.AutoMultivariateNormal` guide. These
            latter two require more ``num_steps`` to fit.
        :param float init_scale: Initial scale of the
            :class:`~pyro.infer.autoguide.AutoLowRankMultivariateNormal` guide.
        :param bool jit: Whether to use a jit compiled ELBO.
        :param int log_every: How often to log svi losses.
        :param int heuristic_num_particles: Passed to :meth:`heuristic` as
            ``num_particles``. Defaults to 1024.
        :returns: Time series of SVI losses (useful to diagnose convergence).
        :rtype: list
        """
        # Save configuration for .predict().
        self.relaxed = True
        self.num_quant_bins = 1

        # Setup Haar wavelet transform.
        if haar:
            time_dim = -2 if self.is_regional else -1
            dims = {"auxiliary": time_dim}
            supports = {"auxiliary": constraints.interval(-0.5, self.population + 0.5)}
            for name, (fn, is_regional) in self._non_compartmental.items():
                dims[name] = time_dim - fn.event_dim
                supports[name] = fn.support
            haar = _HaarSplitReparam(0, self.duration, dims, supports)

        # Heuristically initialize to feasible latents.
        heuristic_options = {k.replace("heuristic_", ""): options.pop(k)
                             for k in list(options)
                             if k.startswith("heuristic_")}
        assert not options, "unrecognized options: {}".format(", ".join(options))
        init_strategy = self._heuristic(haar, **heuristic_options)

        # Configure variational inference.
        logger.info("Running inference...")
        model = self._relaxed_model
        if haar:
            model = haar.reparam(model)
        if guide_rank == 0:
            guide = AutoNormal(model, init_loc_fn=init_strategy, init_scale=init_scale)
        elif guide_rank == "full":
            guide = AutoMultivariateNormal(model, init_loc_fn=init_strategy,
                                           init_scale=init_scale)
        elif guide_rank is None or isinstance(guide_rank, int):
            guide = AutoLowRankMultivariateNormal(model, init_loc_fn=init_strategy,
                                                  init_scale=init_scale, rank=guide_rank)
        else:
            raise ValueError("Invalid guide_rank: {}".format(guide_rank))
        Elbo = JitTrace_ELBO if jit else Trace_ELBO
        elbo = Elbo(max_plate_nesting=self.max_plate_nesting,
                    num_particles=num_particles, vectorize_particles=True,
                    ignore_jit_warnings=True)
        optim = ClippedAdam({"lr": learning_rate, "betas": betas,
                             "lrd": learning_rate_decay ** (1 / num_steps)})
        svi = SVI(model, guide, optim, elbo)

        # Run inference.
        start_time = default_timer()
        losses = []
        for step in range(1 + num_steps):
            loss = svi.step() / self.duration
            if step % log_every == 0:
                logger.info("step {} loss = {:0.4g}".format(step, loss))
            losses.append(loss)
        elapsed = default_timer() - start_time
        logger.info("SVI took {:0.1f} seconds, {:0.1f} step/sec"
                    .format(elapsed, (1 + num_steps) / elapsed))

        # Draw posterior samples.
        with torch.no_grad():
            particle_plate = pyro.plate("particles", num_samples,
                                        dim=-1 - self.max_plate_nesting)
            guide_trace = poutine.trace(particle_plate(guide)).get_trace()
            model_trace = poutine.trace(
                poutine.replay(particle_plate(model), guide_trace)).get_trace()
            self.samples = {name: site["value"] for name, site in model_trace.nodes.items()
                            if site["type"] == "sample"
                            if not site["is_observed"]
                            if not site_is_subsample(site)}
            if haar:
                haar.aux_to_user(self.samples)
        assert all(v.size(0) == num_samples for v in self.samples.values()), \
            {k: tuple(v.shape) for k, v in self.samples.items()}

        return losses
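
The guide_rank options in the docstring map onto the three guide branches above; a brief usage sketch (hypothetical: my_model stands for an instance of whatever class provides this fit_svi method):

# Hypothetical instance of the class defining fit_svi above.
# Mean-field normal guide (default):
losses = my_model.fit_svi(num_steps=2000, guide_rank=0)
# Low-rank multivariate normal guide of rank 8 (needs more steps to fit):
losses = my_model.fit_svi(num_steps=4000, guide_rank=8, init_scale=0.01)
# Full-rank multivariate normal guide:
losses = my_model.fit_svi(num_steps=4000, guide_rank="full")
samples = my_model.samples  # populated on completion, keyed by latent site name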
Example #13
    def _create_autoguide(
        self,
        model,
        amortised,
        encoder_kwargs,
        data_transform,
        encoder_mode,
        init_loc_fn=init_to_mean,
        n_cat_list: list = [],
        encoder_instance=None,
    ):

        if not amortised:
            _guide = AutoNormal(
                model,
                init_loc_fn=init_loc_fn,
                create_plates=model.create_plates,
            )
        else:
            encoder_kwargs = encoder_kwargs if isinstance(encoder_kwargs, dict) else dict()
            n_hidden = encoder_kwargs["n_hidden"] if "n_hidden" in encoder_kwargs.keys() else 200
            init_param_scale = (
                encoder_kwargs["init_param_scale"] if "init_param_scale" in encoder_kwargs.keys() else 1 / 50
            )
            if "init_param_scale" in encoder_kwargs.keys():
                del encoder_kwargs["init_param_scale"]
            amortised_vars = self.list_obs_plate_vars
            _guide = AutoGuideList(model, create_plates=model.create_plates)
            _guide.append(
                AutoNormal(
                    pyro.poutine.block(model, hide=list(amortised_vars["sites"].keys())),
                    init_loc_fn=init_loc_fn,
                )
            )
            if isinstance(data_transform, np.ndarray):
                # add extra info about gene clusters to the network
                self.register_buffer("gene_clusters", torch.tensor(data_transform.astype("float32")))
                n_in = model.n_vars + data_transform.shape[1]
                data_transform = self.data_transform_clusters()
            elif data_transform == "log1p":
                # use simple log1p transform
                data_transform = torch.log1p
                n_in = self.model.n_vars
            elif (
                isinstance(data_transform, dict)
                and "var_std" in list(data_transform.keys())
                and "var_mean" in list(data_transform.keys())
            ):
                # use data transform by scaling
                n_in = model.n_vars
                self.register_buffer(
                    "var_mean",
                    torch.tensor(data_transform["var_mean"].astype("float32").reshape((1, n_in))),
                )
                self.register_buffer(
                    "var_std",
                    torch.tensor(data_transform["var_std"].astype("float32").reshape((1, n_in))),
                )
                data_transform = self.data_transform_scale()
            else:
                # use custom data transform
                data_transform = data_transform
                n_in = model.n_vars

            if len(amortised_vars["input"]) >= 2:
                encoder_kwargs["n_cat_list"] = n_cat_list
            amortised_vars["input_transform"][0] = data_transform
            _guide.append(
                AutoNormalEncoder(
                    pyro.poutine.block(model, expose=list(amortised_vars["sites"].keys())),
                    amortised_plate_sites=amortised_vars,
                    n_in=n_in,
                    n_hidden=n_hidden,
                    init_param_scale=init_param_scale,
                    encoder_kwargs=encoder_kwargs,
                    encoder_mode=encoder_mode,
                    encoder_instance=encoder_instance,
                )
            )
        return _guide
Example #14
# Load full unnormalised data
# X_train, y_train = torch.load('/Users/ricard/test/pyro/files/MNIST/processed/training.pt')
# X_test, y_test = torch.load('/Users/ricard/test/pyro/files/MNIST/processed/test.pt')

# Load data using torch.utils.data.DataLoader
train_loader, test_loader = setup_data_loaders(batch_size=512, subset=True)

# clear param store
pyro.clear_param_store()

# setup the Factor Analysis model
fa = FA()

# Define the guide automatically using a mean-field normal approximation (AutoNormal)
guide = AutoNormal(fa)

# Define guide manually
# guide = fa.guide

optim = Adam({"lr": LEARNING_RATE})
svi = SVI(fa.forward, guide=guide, optim=optim, loss=Trace_ELBO())

# training loop
train_elbo = []
test_elbo = []
for epoch in range(NUM_EPOCHS):
    total_epoch_loss_train = train(svi, train_loader)
    train_elbo.append(-total_epoch_loss_train)
    # print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))