Example #1
def AutoMixed(model_full, init_loc={}, delta=None):
    guide = AutoGuideList(model_full)

    marginalised_guide_block = poutine.block(model_full,
                                             expose_all=True,
                                             hide_all=False,
                                             hide=['tau'])
    if delta is None:
        guide.append(
            AutoNormal(marginalised_guide_block,
                       init_loc_fn=autoguide.init_to_value(values=init_loc),
                       init_scale=0.05))
    elif delta == 'part' or delta == 'all':
        guide.append(
            AutoDelta(marginalised_guide_block,
                      init_loc_fn=autoguide.init_to_value(values=init_loc)))

    full_rank_guide_block = poutine.block(model_full,
                                          hide_all=True,
                                          expose=['tau'])
    if delta is None or delta == 'part':
        guide.append(
            AutoMultivariateNormal(
                full_rank_guide_block,
                init_loc_fn=autoguide.init_to_value(values=init_loc),
                init_scale=0.05))
    else:
        guide.append(
            AutoDelta(full_rank_guide_block,
                      init_loc_fn=autoguide.init_to_value(values=init_loc)))

    return guide
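
A minimal usage sketch for the factory above. The model below is a hypothetical stand-in (the real model_full is not shown); it only needs a site named 'tau' plus at least one other latent site, and the names used inside AutoMixed (poutine, autoguide, AutoGuideList, AutoNormal, AutoDelta, AutoMultivariateNormal) are assumed to be imported as in the original source.

import pyro
import pyro.distributions as dist
import torch
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model_full():
    # hypothetical model: 'tau' gets the full-rank block, 'beta' the mean-field block
    tau = pyro.sample("tau", dist.LogNormal(0., 1.))
    beta = pyro.sample("beta", dist.Normal(0., tau))
    pyro.sample("obs", dist.Normal(beta, 1.), obs=torch.tensor(0.5))

guide = AutoMixed(model_full)  # delta=None: AutoNormal + AutoMultivariateNormal
svi = SVI(model_full, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
for _ in range(100):
    svi.step()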
Example #2
def test_guide_list(auto_class):

    def model():
        pyro.sample("x", dist.Normal(0., 1.).expand([2]))
        pyro.sample("y", dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5)))

    guide = AutoGuideList(model)
    guide.append(auto_class(poutine.block(model, expose=["x"])))
    guide.append(auto_class(poutine.block(model, expose=["y"])))
    guide()
Example #3
def nested_auto_guide_callable(model):
    guide = AutoGuideList(model)
    guide.append(AutoDelta(poutine.block(model, expose=['x'])))
    guide_y = AutoGuideList(poutine.block(model, expose=['y']))
    guide_y.z = AutoIAFNormal(poutine.block(model, expose=['y']))
    guide.append(guide_y)
    return guide
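
A sketch of exercising the nested guide above on a toy model. The model is illustrative: any model with a latent 'x' and a multivariate latent 'y' works ('y' is given three dimensions here because AutoIAFNormal is intended for multivariate latent spaces).

def model():
    pyro.sample("x", dist.Normal(0., 1.))
    pyro.sample("y", dist.Normal(torch.zeros(3), 1.).to_event(1))

guide = nested_auto_guide_callable(model)
samples = guide()  # the outer list yields {'x': ...}, the nested list yields {'y': ...}
assert set(samples) == {"x", "y"}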
Example #4
def auto_guide_module_callable(model):
    class GuideX(AutoGuide):
        def __init__(self, model):
            super().__init__(model)
            self.x_loc = nn.Parameter(torch.tensor(1.))
            self.x_scale = PyroParam(torch.tensor(.1), constraint=constraints.positive)

        def forward(self, *args, **kwargs):
            return {"x": pyro.sample("x", dist.Normal(self.x_loc, self.x_scale))}

        def median(self, *args, **kwargs):
            return {"x": self.x_loc.detach()}

    guide = AutoGuideList(model)
    guide.custom = GuideX(model)
    guide.diagnorm = AutoDiagonalNormal(poutine.block(model, hide=["x"]))
    return guide
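
A brief sketch of calling the composite guide above. The toy model is illustrative; the AutoDiagonalNormal part needs at least one latent site besides 'x'.

def model():
    x = pyro.sample("x", dist.Normal(0., 1.))
    z = pyro.sample("z", dist.Normal(0., 1.))
    pyro.sample("obs", dist.Normal(x + z, 1.), obs=torch.tensor(0.3))

guide = auto_guide_module_callable(model)
samples = guide()              # {'x': ...} from GuideX plus {'z': ...} from AutoDiagonalNormal
point = guide.custom.median()  # {'x': tensor(1.)} as defined in GuideX.median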
Example #5
def auto_guide_callable(model):
    def guide_x():
        x_loc = pyro.param("x_loc", torch.tensor(1.))
        x_scale = pyro.param("x_scale", torch.tensor(.1), constraint=constraints.positive)
        pyro.sample("x", dist.Normal(x_loc, x_scale))

    def median_x():
        return {"x": pyro.param("x_loc", torch.tensor(1.))}

    guide = AutoGuideList(model)
    guide.append(AutoCallable(model, guide_x, median_x))
    guide.append(AutoDiagonalNormal(poutine.block(model, hide=["x"])))
    return guide
Example #6
def test_callable(auto_class):

    def model():
        pyro.sample("x", dist.Normal(0., 1.))
        pyro.sample("y", dist.MultivariateNormal(torch.zeros(5), torch.eye(5, 5)))

    def guide_x():
        x_loc = pyro.param("x_loc", torch.tensor(0.))
        pyro.sample("x", dist.Delta(x_loc))

    guide = AutoGuideList(model)
    guide.append(guide_x)
    guide.append(auto_class(poutine.block(model, expose=["y"])))
    values = guide()
    assert set(values) == set(["y"])
Example #7
def test_subsample_guide(auto_class, init_fn):

    # The model from tutorial/source/easyguide.ipynb
    def model(batch, subsample, full_size):
        num_time_steps = len(batch)
        result = [None] * num_time_steps
        drift = pyro.sample("drift", dist.LogNormal(-1, 0.5))
        plate = pyro.plate("data", full_size, subsample=subsample)
        assert plate.size == 50
        with plate:
            z = 0.
            for t in range(num_time_steps):
                z = pyro.sample("state_{}".format(t), dist.Normal(z, drift))
                result[t] = pyro.sample("obs_{}".format(t),
                                        dist.Bernoulli(logits=z),
                                        obs=batch[t])

        return torch.stack(result)

    def create_plates(batch, subsample, full_size):
        return pyro.plate("data", full_size, subsample=subsample)

    if auto_class == AutoGuideList:
        guide = AutoGuideList(model, create_plates=create_plates)
        guide.append(AutoDelta(poutine.block(model, expose=["drift"])))
        guide.append(AutoNormal(poutine.block(model, hide=["drift"])))
    else:
        guide = auto_class(model, create_plates=create_plates)

    full_size = 50
    batch_size = 20
    num_time_steps = 8
    pyro.set_rng_seed(123456789)
    data = model([None] * num_time_steps, torch.arange(full_size), full_size)
    assert data.shape == (num_time_steps, full_size)

    pyro.get_param_store().clear()
    pyro.set_rng_seed(123456789)
    svi = SVI(model, guide, Adam({"lr": 0.02}), Trace_ELBO())
    for epoch in range(2):
        beg = 0
        while beg < full_size:
            end = min(full_size, beg + batch_size)
            subsample = torch.arange(beg, end)
            batch = data[:, beg:end]
            beg = end
            svi.step(batch, subsample, full_size=full_size)
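
After the loop above, the fitted guide can be used for posterior sampling. A hedged sketch with pyro.infer.Predictive, reusing the names defined in the test body (model, guide, data, full_size); the number of samples is illustrative.

from pyro.infer import Predictive

predictive = Predictive(model, guide=guide, num_samples=10, return_sites=["drift"])
posterior = predictive(data, torch.arange(full_size), full_size)
# posterior["drift"] holds num_samples posterior draws of the global drift site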
Example #8
def test_discrete_parallel(continuous_class):
    K = 2
    data = torch.tensor([0., 1., 10., 11., 12.])

    def model(data):
        weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
        locs = pyro.sample('locs', dist.Normal(0, 10).expand_by([K]).to_event(1))
        scale = pyro.sample('scale', dist.LogNormal(0, 1))

        with pyro.plate('data', len(data)):
            weights = weights.expand(torch.Size((len(data),)) + weights.shape)
            assignment = pyro.sample('assignment', dist.Categorical(weights))
            pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

    guide = AutoGuideList(model)
    guide.append(continuous_class(poutine.block(model, hide=["assignment"])))
    guide.append(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))

    elbo = TraceEnum_ELBO(max_plate_nesting=1)
    loss = elbo.loss_and_grads(model, guide, data)
    assert np.isfinite(loss), loss
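
To actually optimise this mixed continuous/discrete guide rather than evaluate a single loss, a small training loop could follow; the learning rate and iteration count are illustrative, and SVI/Adam are imported as in Example #7.

svi = SVI(model, guide, Adam({"lr": 0.05}), loss=TraceEnum_ELBO(max_plate_nesting=1))
for _ in range(200):
    svi.step(data)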
Example #9
def auto_guide_list_x(model):
    guide = AutoGuideList(model)
    guide.append(AutoDelta(poutine.block(model, expose=["x"])))
    guide.append(AutoDiagonalNormal(poutine.block(model, hide=["x"])))
    return guide
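
A compact sketch of training this MAP-plus-mean-field split and extracting point estimates. The toy model is hypothetical: it needs a latent 'x' for the AutoDelta part plus at least one more latent site for the AutoDiagonalNormal part; SVI, Adam and Trace_ELBO imports are assumed as in the examples above.

def model():
    x = pyro.sample("x", dist.Normal(0., 1.))
    scale = pyro.sample("scale", dist.LogNormal(0., 1.))
    pyro.sample("obs", dist.Normal(x, scale), obs=torch.tensor(0.5))

guide = auto_guide_list_x(model)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
for _ in range(200):
    svi.step()
medians = guide.median()  # each sub-guide contributes: the MAP value of 'x' and the median of 'scale'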
Example #10
    def fit_advi_iterative(self,
                           n=3,
                           method='advi',
                           n_type='restart',
                           n_iter=None,
                           learning_rate=None,
                           progressbar=True,
                           num_workers=2,
                           train_proportion=None,
                           stratify_cv=None,
                           l2_weight=False,
                           sample_scaling_weight=0.5,
                           checkpoints=None,
                           checkpoint_dir='./checkpoints',
                           tracking=False):
        r""" Train posterior using ADVI method.
        (maximising likehood of the data and minimising KL-divergence of posterior to prior)
        :param n: number of independent initialisations
        :param method: to allow for potential use of SVGD or MCMC (currently only ADVI implemented).
        :param n_type: type of repeated initialisation:
                                  'restart' to pick different initial value,
                                  'cv' for molecular cross-validation - splits counts into n datasets,
                                         for now, only n=2 is implemented
                                  'bootstrap' for fitting the model to multiple downsampled datasets.
                                         Run `mod.bootstrap_data()` to generate variants of data
        :param n_iter: number of iterations, supersedes self.n_iter
        :param train_proportion: if not None, which proportion of cells to use for training and which for validation.
        :param checkpoints: int, list of int's or None, number of checkpoints to save while model training or list of
            iterations to save checkpoints on
        :param checkpoint_dir: str, directory to save checkpoints in
        :param tracking: bool, track all latent variables during training - if True makes training 2 times slower
        :return: None
        """

        # initialise parameter store
        self.svi = {}
        self.hist = {}
        self.guide_i = {}
        self.samples = {}
        self.node_samples = {}

        if tracking:
            self.logp_hist = {}

        if n_iter is None:
            n_iter = self.n_iter

        if type(checkpoints) is int:
            if n_iter < checkpoints:
                checkpoints = n_iter
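            # an int is expanded to evenly spaced iterations,
            # e.g. n_iter=1000, checkpoints=4 -> save at iterations [250, 500, 750, 1000]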
            checkpoints = np.linspace(0, n_iter, checkpoints + 1,
                                      dtype=int)[1:]
            self.checkpoints = list(checkpoints)
        else:
            self.checkpoints = checkpoints

        self.checkpoint_dir = checkpoint_dir

        self.n_type = n_type
        self.l2_weight = l2_weight
        self.sample_scaling_weight = sample_scaling_weight
        self.train_proportion = train_proportion

        if stratify_cv is not None:
            self.stratify_cv = stratify_cv

        if train_proportion is not None:
            self.validation_hist = {}
            self.training_hist = {}
            if tracking:
                self.logp_hist_val = {}
                self.logp_hist_train = {}

        if learning_rate is None:
            learning_rate = self.learning_rate

        if np.isin(n_type, ['bootstrap']):
            if self.X_data_sample is None:
                self.bootstrap_data(n=n)
        elif np.isin(n_type, ['cv']):
            self.generate_cv_data()  # cv data added to self.X_data_sample

        init_names = ['init_' + str(i + 1) for i in np.arange(n)]

        for i, name in enumerate(init_names):
            ################### Initialise parameters & optimiser ###################
            # initialise Variational distribution = guide
            if method == 'advi':
                self.guide_i[name] = AutoGuideList(self.model)
                normal_guide_block = poutine.block(
                    self.model,
                    expose_all=True,
                    hide_all=False,
                    hide=self.point_estim +
                    flatten_iterable(self.custom_guides.keys()))
                self.guide_i[name].append(
                    AutoNormal(normal_guide_block, init_loc_fn=init_to_mean))
                self.guide_i[name].append(
                    AutoDelta(
                        poutine.block(self.model,
                                      hide_all=True,
                                      expose=self.point_estim)))
                for k, v in self.custom_guides.items():
                    self.guide_i[name].append(v)

            elif method == 'custom':
                self.guide_i[name] = self.guide

            # initialise SVI inference method
            self.svi[name] = SVI(
                self.model,
                self.guide_i[name],
                optim.ClippedAdam({
                    'lr': learning_rate,
                    # limit the gradient step from becoming too large
                    'clip_norm': self.total_grad_norm_constraint
                }),
                loss=JitTrace_ELBO())

            pyro.clear_param_store()

            self.set_initial_values()

            # record ELBO Loss history here
            self.hist[name] = []
            if tracking:
                self.logp_hist[name] = defaultdict(list)

            if train_proportion is not None:
                self.validation_hist[name] = []
                if tracking:
                    self.logp_hist_val[name] = defaultdict(list)

            ################### Select data for this iteration ###################
            if np.isin(n_type, ['cv', 'bootstrap']):
                X_data = self.X_data_sample[i].astype(self.data_type)
            else:
                X_data = self.X_data.astype(self.data_type)

            ################### Training / validation split ###################
            # split into training and validation
            if train_proportion is not None:
                idx = np.arange(len(X_data))
                train_idx, val_idx = train_test_split(
                    idx,
                    train_size=train_proportion,
                    shuffle=True,
                    stratify=self.stratify_cv)

                extra_data_val = {
                    k: torch.FloatTensor(v[val_idx]).to(self.device)
                    for k, v in self.extra_data.items()
                }
                extra_data_train = {
                    k: torch.FloatTensor(v[train_idx])
                    for k, v in self.extra_data.items()
                }

                x_data_val = torch.FloatTensor(X_data[val_idx]).to(self.device)
                x_data = torch.FloatTensor(X_data[train_idx])
            else:
                # just convert data to CPU tensors
                x_data = torch.FloatTensor(X_data)
                extra_data_train = {
                    k: torch.FloatTensor(v)
                    for k, v in self.extra_data.items()
                }

            ################### Move data to cuda - FULL data ###################
            # if not minibatch do this:
            if self.minibatch_size is None:
                # move tensors to CUDA
                x_data = x_data.to(self.device)
                for k in extra_data_train.keys():
                    extra_data_train[k] = extra_data_train[k].to(self.device)
                # extra_data_train = {k: v.to(self.device) for k, v in extra_data_train.items()}

            ################### MINIBATCH data ###################
            else:
                # create minibatches
                dataset = MiniBatchDataset(x_data,
                                           extra_data_train,
                                           return_idx=True)
                loader = DataLoader(dataset,
                                    batch_size=self.minibatch_size,
                                    num_workers=0)  # TODO num_workers

            ################### Training the model ###################
            # start training in epochs
            epochs_iterator = tqdm(range(n_iter))
            for epoch in epochs_iterator:

                if self.minibatch_size is None:
                    ################### Training FULL data ###################
                    iter_loss = self.step_train(name, x_data, extra_data_train)

                    self.hist[name].append(iter_loss)
                    # save data for posterior sampling
                    self.x_data = x_data
                    self.extra_data_train = extra_data_train

                    if tracking:
                        guide_tr, model_tr = self.step_trace(
                            name, x_data, extra_data_train)
                        self.logp_hist[name]['guide'].append(
                            guide_tr.log_prob_sum().item())
                        self.logp_hist[name]['model'].append(
                            model_tr.log_prob_sum().item())

                        for k, v in model_tr.nodes.items():
                            if "log_prob_sum" in v:
                                self.logp_hist[name][k].append(
                                    v["log_prob_sum"].item())

                else:
                    ################### Training MINIBATCH data ###################
                    aver_loss = []
                    if tracking:
                        aver_logp_guide = []
                        aver_logp_model = []
                        aver_logp = defaultdict(list)

                    for batch in loader:

                        x_data_batch, extra_data_batch = batch
                        x_data_batch = x_data_batch.to(self.device)
                        extra_data_batch = {
                            k: v.to(self.device)
                            for k, v in extra_data_batch.items()
                        }

                        loss = self.step_train(name, x_data_batch,
                                               extra_data_batch)

                        if tracking:
                            guide_tr, model_tr = self.step_trace(
                                name, x_data_batch, extra_data_batch)
                            aver_logp_guide.append(
                                guide_tr.log_prob_sum().item())
                            aver_logp_model.append(
                                model_tr.log_prob_sum().item())

                            for k, v in model_tr.nodes.items():
                                if "log_prob_sum" in v:
                                    aver_logp[k].append(
                                        v["log_prob_sum"].item())

                        aver_loss.append(loss)

                    iter_loss = np.sum(aver_loss)

                    # save data for posterior sampling
                    self.x_data = x_data_batch
                    self.extra_data_train = extra_data_batch

                    self.hist[name].append(iter_loss)

                    if tracking:
                        iter_logp_guide = np.sum(aver_logp_guide)
                        iter_logp_model = np.sum(aver_logp_model)
                        self.logp_hist[name]['guide'].append(iter_logp_guide)
                        self.logp_hist[name]['model'].append(iter_logp_model)

                        for k, v in aver_logp.items():
                            self.logp_hist[name][k].append(np.sum(v))

                if self.checkpoints is not None:
                    if (epoch + 1) in self.checkpoints:
                        self.save_checkpoint(epoch + 1, prefix=name)

                ################### Evaluating cross-validation loss ###################
                if train_proportion is not None:

                    iter_loss_val = self.step_eval_loss(
                        name, x_data_val, extra_data_val)

                    if tracking:
                        guide_tr, model_tr = self.step_trace(
                            name, x_data_val, extra_data_val)
                        self.logp_hist_val[name]['guide'].append(
                            guide_tr.log_prob_sum().item())
                        self.logp_hist_val[name]['model'].append(
                            model_tr.log_prob_sum().item())

                        for k, v in model_tr.nodes.items():
                            if "log_prob_sum" in v:
                                self.logp_hist_val[name][k].append(
                                    v["log_prob_sum"].item())

                    self.validation_hist[name].append(iter_loss_val)
                    epochs_iterator.set_description(
                        f'ELBO Loss: {iter_loss:.4e}: Val loss: {iter_loss_val:.4e}')
                else:
                    epochs_iterator.set_description(f'ELBO Loss: {iter_loss:.4e}')

                if epoch % 20 == 0:
                    torch.cuda.empty_cache()

            if train_proportion is not None:
                # rescale loss
                self.validation_hist[name] = [
                    i / (1 - train_proportion)
                    for i in self.validation_hist[name]
                ]
                self.hist[name] = [
                    i / train_proportion for i in self.hist[name]
                ]

                # reassign the main loss to be displayed
                self.training_hist[name] = self.hist[name]
                self.hist[name] = self.validation_hist[name]

                if tracking:
                    for k, v in self.logp_hist[name].items():
                        self.logp_hist[name][k] = [
                            i / train_proportion
                            for i in self.logp_hist[name][k]
                        ]
                        self.logp_hist_val[name][k] = [
                            i / (1 - train_proportion)
                            for i in self.logp_hist_val[name][k]
                        ]

                    self.logp_hist_train[name] = self.logp_hist[name]
                    self.logp_hist[name] = self.logp_hist_val[name]

            if self.verbose:
                print(plt.plot(np.log10(self.hist[name][0:])))
Example #11
    def fit_advi_iterative_simple(
        self,
        n: int = 3,
        method='advi',
        n_type='restart',
        n_iter=None,
        learning_rate=None,
        progressbar=True,
    ):
        r""" Find posterior using ADVI (deprecated)
        (maximising likehood of the data and minimising KL-divergence of posterior to prior)
        :param n: number of independent initialisations
        :param method: which approximation of the posterior (guide) to use?.
            * ``'advi'`` - Univariate normal approximation (pyro.infer.autoguide.AutoDiagonalNormal)
            * ``'custom'`` - Custom guide using conjugate posteriors
        :return: self.svi dictionary with svi pyro objects for each n, and sefl.elbo dictionary storing training history. 
        """

        # Pass data to pyro / pytorch
        self.x_data = torch.tensor(self.X_data.astype(
            self.data_type))  # .double()

        # initialise parameter store
        self.svi = {}
        self.hist = {}
        self.guide_i = {}
        self.samples = {}
        self.node_samples = {}

        self.n_type = n_type

        if n_iter is None:
            n_iter = self.n_iter

        if learning_rate is None:
            learning_rate = self.learning_rate

        if np.isin(n_type, ['bootstrap']):
            if self.X_data_sample is None:
                self.bootstrap_data(n=n)
        elif np.isin(n_type, ['cv']):
            self.generate_cv_data()  # cv data added to self.X_data_sample

        init_names = ['init_' + str(i + 1) for i in np.arange(n)]

        for i, name in enumerate(init_names):

            # initialise variational distribution = guide
            if method == 'advi':
                self.guide_i[name] = AutoGuideList(self.model)
                self.guide_i[name].append(
                    AutoNormal(poutine.block(self.model,
                                             expose_all=True,
                                             hide_all=False,
                                             hide=self.point_estim),
                               init_loc_fn=init_to_mean))
                self.guide_i[name].append(
                    AutoDelta(
                        poutine.block(self.model,
                                      hide_all=True,
                                      expose=self.point_estim)))
            elif method == 'custom':
                self.guide_i[name] = self.guide

            # initialise SVI inference method
            self.svi[name] = SVI(
                self.model,
                self.guide_i[name],
                optim.ClippedAdam({
                    'lr': learning_rate,
                    # limit the gradient step from becoming too large
                    'clip_norm': self.total_grad_norm_constraint
                }),
                loss=JitTrace_ELBO())

            pyro.clear_param_store()

            # record ELBO Loss history here
            self.hist[name] = []

            # pick dataset depending on the training mode and move to GPU
            if np.isin(n_type, ['cv', 'bootstrap']):
                self.x_data = torch.tensor(self.X_data_sample[i].astype(
                    self.data_type))
            else:
                self.x_data = torch.tensor(self.X_data.astype(self.data_type))

            if self.use_cuda:
                # move tensors and modules to CUDA
                self.x_data = self.x_data.cuda()

            # train for n_iter
            it_iterator = tqdm(range(n_iter))
            for it in it_iterator:

                hist = self.svi[name].step(self.x_data)
                it_iterator.set_description('ELBO Loss: ' +
                                            str(np.round(hist, 3)))
                self.hist[name].append(hist)

                # if it % 50 == 0 & self.verbose:
                # logging.info("Elbo loss: {}".format(hist))
                if it % 500 == 0:
                    torch.cuda.empty_cache()
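
Both Example #10 and Example #11 assemble the same guide structure: an AutoNormal over most sites plus an AutoDelta over the point-estimated sites. A standalone sketch of that pattern on a toy model (the model, site names and optimiser settings are illustrative, not taken from the original class; poutine, optim, AutoGuideList, AutoNormal, AutoDelta, init_to_mean, SVI and Trace_ELBO are assumed to be imported as in the surrounding examples).

point_estim = ["w"]

def model(x_data):
    # hypothetical model: 'w' is point-estimated, 'sigma' gets a normal approximation
    w = pyro.sample("w", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.LogNormal(0., 1.))
    with pyro.plate("obs_plate", x_data.shape[0]):
        pyro.sample("obs", dist.Normal(w, sigma), obs=x_data)

guide = AutoGuideList(model)
guide.append(AutoNormal(
    poutine.block(model, expose_all=True, hide_all=False, hide=point_estim),
    init_loc_fn=init_to_mean))
guide.append(AutoDelta(
    poutine.block(model, hide_all=True, expose=point_estim)))

svi = SVI(model, guide, optim.ClippedAdam({"lr": 0.01, "clip_norm": 200}),
          loss=Trace_ELBO())
x_data = torch.randn(100) * 0.3 + 1.0
for _ in range(500):
    svi.step(x_data)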