Beispiel #1
0
def test_auto_diagonal_gaussians(auto_class, Elbo):
    """Check that an autoguide recovers two independent Gaussian sites.

    The model draws ``x ~ Normal(-0.2, 1.2)`` and ``y ~ Normal(0.2, 0.7)``.
    After ``n_steps`` SVI steps the guide's location is compared with the
    true means (``prec=0.05``) and its scale with the true standard
    deviations (``prec=0.08``).

    :param auto_class: autoguide class under test.
    :param Elbo: ELBO class, instantiated without arguments as the SVI loss.
    """
    n_steps = 3001

    def model():
        pyro.sample("x", dist.Normal(-0.2, 1.2))
        pyro.sample("y", dist.Normal(0.2, 0.7))

    # Only the low-rank guide takes an extra constructor argument.
    extra = {"rank": 1} if auto_class is AutoLowRankMultivariateNormal else {}
    guide = auto_class(model, **extra)

    # Per-step decay chosen so the learning rate shrinks 10x over the run.
    adam_args = {
        "lr": 0.01,
        "betas": (0.95, 0.999),
        "lrd": 0.1 ** (1 / n_steps),
    }
    svi = SVI(model, guide, optim.ClippedAdam(adam_args), loss=Elbo())

    for _ in range(n_steps):
        step_loss = svi.step()
        assert np.isfinite(step_loss), step_loss

    # The Laplace guide is swapped for the guide returned by
    # laplace_approximation() before its moments are read.
    if auto_class is AutoLaplaceApproximation:
        guide = guide.laplace_approximation()

    loc, scale = guide._loc_scale()

    def report(header, expected, actual):
        # Human-readable failure message: header, expected values, actual values.
        lines = [
            header,
            str(expected.cpu().numpy()),
            "Actual:",
            str(actual.cpu().numpy()),
        ]
        return "\n".join(lines)

    expected_loc = torch.tensor([-0.2, 0.2])
    assert_equal(
        loc.detach(),
        expected_loc,
        prec=0.05,
        msg=report("Incorrect guide loc. Expected:", expected_loc, loc.detach()),
    )

    expected_scale = torch.tensor([1.2, 0.7])
    assert_equal(
        scale.detach(),
        expected_scale,
        prec=0.08,
        msg=report("Incorrect guide scale. Expected:", expected_scale,
                   scale.detach()),
    )
Beispiel #2
0
            def initialise_svi(x_data, extra_data):
                """(Re)create the SVI object for initialisation ``name``.

                Clears the global Pyro param store, resets initial values via
                ``self.set_initial_values()``, calls ``self.init_guide`` for
                ``name`` and wires ``self.guide_i[name]`` to a fresh
                ``JitTrace_ELBO`` loss and a gradient-clipped Adam optimiser.
                ``self``, ``name`` and ``learning_rate`` come from the
                enclosing scope.
                """
                pyro.clear_param_store()
                self.set_initial_values()
                self.init_guide(name, x_data, extra_data)

                elbo = JitTrace_ELBO()
                self.trace_elbo_i[name] = elbo

                # Clip the total gradient norm so a single step cannot become
                # too large.
                clipped_adam = optim.ClippedAdam({
                    'lr': learning_rate,
                    'clip_norm': self.total_grad_norm_constraint,
                })
                self.svi[name] = SVI(self.model, self.guide_i[name],
                                     clipped_adam, loss=elbo)
Beispiel #3
0
    def _train_full_data(self, x_data, obs2sample, n_epochs=20000, lr=0.002,
                         device=None):
        """Train the model by full-batch SVI (every step uses all observations).

        :param x_data: observations; the first axis indexes observations and
            ``arange(x_data.shape[0])`` is passed to model and guide as ``idx``.
        :param obs2sample: passed unchanged (as a tensor) to model and guide;
            presumably the observation-to-sample mapping — confirm with the model.
        :param n_epochs: number of SVI steps.
        :param lr: ClippedAdam learning rate (gradient norm clipped at 200).
        :param device: torch device (or device string) to train on. ``None``
            selects CUDA when available and CPU otherwise; previously CUDA was
            required unconditionally.
        :return: None; the per-step loss history is stored in ``self.hist``.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            device = torch.device(device)
        on_cuda = device.type == "cuda"

        idx = np.arange(x_data.shape[0]).astype("int64")
        idx = torch.tensor(idx).to(device)
        x_data = torch.tensor(x_data).to(device)
        obs2sample = torch.tensor(obs2sample).to(device)

        self.to(device)

        pyro.clear_param_store()
        # One guide pass before building SVI, presumably so all guide
        # parameters exist (on ``device``) before the first step.
        self.guide(x_data, idx, obs2sample)

        svi = SVI(
            self.model,
            self.guide,
            optim.ClippedAdam({
                "lr": lr,
                "clip_norm": 200
            }),
            loss=Trace_ELBO(),
        )

        progress = tqdm(range(n_epochs))
        hist = []
        for it in progress:
            loss = svi.step(x_data, idx, obs2sample)
            progress.set_description(f"Epoch {it:d}, -ELBO: {loss:.4e}")
            hist.append(loss)

            # Periodically release cached GPU memory; pointless on CPU.
            if on_cuda and it % 500 == 0:
                torch.cuda.empty_cache()

        self.hist = hist
Beispiel #4
0
def _inducing_points(u, num_inducing_point):
    """Select evenly strided rows of ``u`` as initial inducing inputs.

    The stride is ``int(u.shape[0] / num_inducing_point)`` clamped to at
    least 1. Requesting more inducing points than rows previously passed
    ``step=0`` to ``torch.arange`` and raised; now every row is used.
    """
    step = max(1, int(u.shape[0] / num_inducing_point))
    return u[torch.arange(0, u.shape[0], step=step).long()]


def _build_gp_model(args, dataset, u, gp_name, net_cls, module_name,
                    lik_name, out_dim, jitter):
    """Build a VariationalSparseGP whose Matern52 kernel inputs are warped by a net.

    :param u: training inputs; ``u.shape[2]`` is the network input width.
    :param gp_name: name given to the GP model.
    :param net_cls: feature-network class (``FNET`` or ``HNET``).
    :param module_name: name under which the network is registered with
        ``pyro.module``.
    :param lik_name: name of the Gaussian likelihood.
    :param out_dim: number of GP outputs.
    :param jitter: ``jitter`` passed to ``VariationalSparseGP``.
    :return: ``(gp_model, net)``.
    """
    net = net_cls(args, u.shape[2], args.kernel_dim)

    def warp_fn(x):
        # pyro.module registers the network's parameters with Pyro on each call.
        return pyro.module(module_name, net)(x)

    lik = gp.likelihoods.Gaussian(name=lik_name,
                                  variance=0.1 * torch.ones(out_dim, 1))
    # NOTE: a full-covariance MultiVariateGaussian(name=lik_name, dim=out_dim)
    # likelihood could replace this once lower_triangular_constraint exists.
    kernel = gp.kernels.Matern52(
        input_dim=args.kernel_dim,
        lengthscale=torch.ones(args.kernel_dim)).warp(iwarping_fn=warp_fn)
    Xu = _inducing_points(u, args.num_inducing_point)
    gp_model = gp.models.VariationalSparseGP(u,
                                             torch.zeros(out_dim, u.shape[0]),
                                             kernel,
                                             Xu,
                                             num_data=dataset.num_data,
                                             likelihood=lik,
                                             mean_function=None,
                                             name=gp_name,
                                             whiten=True,
                                             jitter=jitter)
    return gp_model, net


def train_gp(args, dataset, gp_class):
    """Build a GP of type ``gp_class``, train it with SVI and save it.

    ``gp_class.name == 'GpOdoFog'`` builds a 6-output GP warped by ``FNET``
    (initial jitter 1e-3); any other class builds a 9-output GP warped by
    ``HNET`` (jitter 1e-4).
    """
    # Only the inputs are used here (per the original note, this load exists
    # to get the correct dimension); the second element is discarded.
    if args.nclt:
        u, _ = dataset.get_train_data(0, gp_class.name)
    else:
        u, _ = dataset.get_test_data(1, gp_class.name)

    is_odo_fog = gp_class.name == 'GpOdoFog'
    if is_odo_fog:
        gp_model, net = _build_gp_model(args, dataset, u, gp_class.name, FNET,
                                        "FNET", 'lik_f', 6, 1e-3)
    else:
        gp_model, net = _build_gp_model(args, dataset, u, gp_class.name, HNET,
                                        "HNET", 'lik_h', 9, 1e-4)

    gp_instance = gp_class(args, gp_model, dataset)
    args.mate = preprocessing(args, dataset, gp_instance)

    optimizer = optim.ClippedAdam({"lr": args.lr, "lrd": args.lr_decay})
    svi = infer.SVI(gp_instance.model, gp_instance.guide, optimizer,
                    infer.Trace_ELBO())

    print("Start of training " + dataset.name + ", " + gp_class.name)
    for epoch in range(1, args.epochs + 1):
        train_loop(dataset, gp_instance, svi, epoch)
        if epoch == 10:
            # After 10 epochs lower the jitter to 1e-4 (GpOdoFog starts at 1e-3).
            if is_odo_fog:
                gp_instance.gp_f.jitter = 1e-4
            else:
                gp_instance.gp_h.jitter = 1e-4

    save_gp(args, gp_instance, net)

def per_param_args(module_name, param_name):
    """Return the ClippedAdam options for one parameter.

    Pyro calls this with ``(module_name, param_name)`` for each parameter.
    Every parameter (``*_loc``, ``*_scale`` and all others) currently uses the
    same learning rate, so the former per-suffix branches, which all returned
    identical values, are collapsed. To tune a group separately, branch on
    ``param_name`` here. A new dict is returned on every call, so callers may
    mutate it.
    """
    return {"lr": 0.005}


# In[16]:

# SVI with per-parameter ClippedAdam options: Pyro calls per_param_args with
# (module_name, param_name) and uses the returned dict for that parameter.
# NOTE(review): recent Pyro versions deprecate SVI's `num_samples` argument
# (only used by the legacy TracePosterior.run API) — confirm it is still needed.
svi = SVI(model,
          guide,
          optim.ClippedAdam(per_param_args),
          loss=Trace_ELBO(),
          num_samples=1000)
# Start from an empty param store. SVI creates parameters lazily on the first
# step, so clearing after construction presumably still takes effect — confirm.
pyro.clear_param_store()

# Training-loop configuration and bookkeeping (the loop body is not shown here).
num_epochs = 10000
# Per-epoch histories; presumably appended to inside the training loop.
elbo_losses = []
alpha_errors = []
beta_errors = []
betaInd_errors = []
track_loglik = True
# Presumably early-stopping state: best ELBO seen so far and a patience
# counter compared with `patience_thre` — confirm against the loop body.
best_elbo = np.inf
patience_thre = 3
patience_count = 0
# Wall-clock start time of training.
tic = time.time()
for j in range(num_epochs):
Beispiel #6
0
    def fit_advi_iterative(self,
                           n=3,
                           method='advi',
                           n_type='restart',
                           n_iter=None,
                           learning_rate=None,
                           progressbar=True,
                           num_workers=2,
                           train_proportion=None,
                           stratify_cv=None,
                           l2_weight=False,
                           sample_scaling_weight=0.5,
                           checkpoints=None,
                           checkpoint_dir='./checkpoints',
                           tracking=False):
        r""" Train posterior using ADVI method.
        (maximising likelihood of the data and minimising KL-divergence of posterior to prior)

        :param n: number of independent initialisations
        :param method: which guide to use:
            ``'advi'`` - AutoGuideList of AutoNormal (all sites except ``self.point_estim`` and the
            ``self.custom_guides`` sites), AutoDelta (``self.point_estim`` sites) and the custom guides;
            ``'custom'`` - ``self.guide``.
            Any other value raises ``ValueError``.
        :param n_type: type of repeated initialisation:
                                  'restart' to pick different initial value,
                                  'cv' for molecular cross-validation - splits counts into n datasets,
                                         for now, only n=2 is implemented
                                  'bootstrap' for fitting the model to multiple downsampled datasets.
                                         Run `mod.bootstrap_data()` to generate variants of data
        :param n_iter: number of iterations, supersedes self.n_iter
        :param learning_rate: ClippedAdam learning rate, supersedes self.learning_rate
        :param progressbar: currently unused
        :param num_workers: currently unused (the minibatch DataLoader uses num_workers=0)
        :param train_proportion: if not None, which proportion of cells to use for training and which for validation.
        :param stratify_cv: labels for a stratified train/validation split; stored as ``self.stratify_cv``
            when given. If ``self.stratify_cv`` was never set, the split is unstratified.
        :param l2_weight: stored as ``self.l2_weight``
        :param sample_scaling_weight: stored as ``self.sample_scaling_weight``
        :param checkpoints: int, list of int's or None, number of checkpoints to save while model training or list of
            iterations to save checkpoints on
        :param checkpoint_dir: str, directory to save checkpoints in
        :param tracking: bool, track all latent variables during training - if True makes training 2 times slower
        :return: None. Results are stored on the instance, keyed 'init_1', 'init_2', ...:
            ``self.svi``, ``self.guide_i`` and ``self.hist``. When ``train_proportion`` is given,
            ``self.hist`` holds the rescaled validation loss and ``self.training_hist`` the training loss.
        :raises ValueError: if ``method`` is not ``'advi'`` or ``'custom'``.
        """
        # Fail fast: previously an unknown method left guide_i[name] unset and
        # crashed later with a KeyError.
        if method not in ('advi', 'custom'):
            raise ValueError(
                f"method must be 'advi' or 'custom', got {method!r}")

        # initialise parameter store
        self.svi = {}
        self.hist = {}
        self.guide_i = {}
        self.samples = {}
        self.node_samples = {}

        if tracking:
            self.logp_hist = {}

        if n_iter is None:
            n_iter = self.n_iter

        if type(checkpoints) is int:
            if n_iter < checkpoints:
                checkpoints = n_iter
            checkpoints = np.linspace(0, n_iter, checkpoints + 1,
                                      dtype=int)[1:]
            self.checkpoints = list(checkpoints)
        else:
            self.checkpoints = checkpoints

        self.checkpoint_dir = checkpoint_dir

        self.n_type = n_type
        self.l2_weight = l2_weight
        self.sample_scaling_weight = sample_scaling_weight
        self.train_proportion = train_proportion

        if stratify_cv is not None:
            self.stratify_cv = stratify_cv

        if train_proportion is not None:
            self.validation_hist = {}
            self.training_hist = {}
            if tracking:
                self.logp_hist_val = {}
                self.logp_hist_train = {}

        if learning_rate is None:
            learning_rate = self.learning_rate

        if np.isin(n_type, ['bootstrap']):
            if self.X_data_sample is None:
                self.bootstrap_data(n=n)
        elif np.isin(n_type, ['cv']):
            self.generate_cv_data()  # cv data added to self.X_data_sample

        init_names = ['init_' + str(i + 1) for i in np.arange(n)]

        for i, name in enumerate(init_names):
            ################### Initialise parameters & optimiser ###################
            # initialise Variational distribution = guide
            # (compare by value: `is` on string literals is identity, not equality)
            if method == 'advi':
                self.guide_i[name] = AutoGuideList(self.model)
                normal_guide_block = poutine.block(
                    self.model,
                    expose_all=True,
                    hide_all=False,
                    hide=self.point_estim +
                    flatten_iterable(self.custom_guides.keys()))
                self.guide_i[name].append(
                    AutoNormal(normal_guide_block, init_loc_fn=init_to_mean))
                self.guide_i[name].append(
                    AutoDelta(
                        poutine.block(self.model,
                                      hide_all=True,
                                      expose=self.point_estim)))
                for k, v in self.custom_guides.items():
                    self.guide_i[name].append(v)

            else:  # method == 'custom' (validated above)
                self.guide_i[name] = self.guide

            # initialise SVI inference method
            self.svi[name] = SVI(
                self.model,
                self.guide_i[name],
                optim.ClippedAdam({
                    'lr': learning_rate,
                    # limit the gradient step from becoming too large
                    'clip_norm': self.total_grad_norm_constraint
                }),
                loss=JitTrace_ELBO())

            pyro.clear_param_store()

            self.set_initial_values()

            # record ELBO Loss history here
            self.hist[name] = []
            if tracking:
                self.logp_hist[name] = defaultdict(list)

            if train_proportion is not None:
                self.validation_hist[name] = []
                if tracking:
                    self.logp_hist_val[name] = defaultdict(list)

            ################### Select data for this iteration ###################
            if np.isin(n_type, ['cv', 'bootstrap']):
                X_data = self.X_data_sample[i].astype(self.data_type)
            else:
                X_data = self.X_data.astype(self.data_type)

            ################### Training / validation split ###################
            # split into training and validation
            if train_proportion is not None:
                idx = np.arange(len(X_data))
                # stratify_cv is only stored when given, so it may not exist.
                train_idx, val_idx = train_test_split(
                    idx,
                    train_size=train_proportion,
                    shuffle=True,
                    stratify=getattr(self, 'stratify_cv', None))

                extra_data_val = {
                    k: torch.FloatTensor(v[val_idx]).to(self.device)
                    for k, v in self.extra_data.items()
                }
                extra_data_train = {
                    k: torch.FloatTensor(v[train_idx])
                    for k, v in self.extra_data.items()
                }

                x_data_val = torch.FloatTensor(X_data[val_idx]).to(self.device)
                x_data = torch.FloatTensor(X_data[train_idx])
            else:
                # just convert data to CPU tensors
                x_data = torch.FloatTensor(X_data)
                extra_data_train = {
                    k: torch.FloatTensor(v)
                    for k, v in self.extra_data.items()
                }

            ################### Move data to cuda - FULL data ###################
            # if not minibatch do this:
            if self.minibatch_size is None:
                # move tensors to CUDA
                x_data = x_data.to(self.device)
                for k in extra_data_train.keys():
                    extra_data_train[k] = extra_data_train[k].to(self.device)

            ################### MINIBATCH data ###################
            else:
                # create minibatches
                dataset = MiniBatchDataset(x_data,
                                           extra_data_train,
                                           return_idx=True)
                loader = DataLoader(dataset,
                                    batch_size=self.minibatch_size,
                                    num_workers=0)  # TODO num_workers

            ################### Training the model ###################
            # start training in epochs
            epochs_iterator = tqdm(range(n_iter))
            for epoch in epochs_iterator:

                if self.minibatch_size is None:
                    ################### Training FULL data ###################
                    iter_loss = self.step_train(name, x_data, extra_data_train)

                    self.hist[name].append(iter_loss)
                    # save data for posterior sampling
                    self.x_data = x_data
                    self.extra_data_train = extra_data_train

                    if tracking:
                        guide_tr, model_tr = self.step_trace(
                            name, x_data, extra_data_train)
                        self.logp_hist[name]['guide'].append(
                            guide_tr.log_prob_sum().item())
                        self.logp_hist[name]['model'].append(
                            model_tr.log_prob_sum().item())

                        for k, v in model_tr.nodes.items():
                            if "log_prob_sum" in v:
                                self.logp_hist[name][k].append(
                                    v["log_prob_sum"].item())

                else:
                    ################### Training MINIBATCH data ###################
                    aver_loss = []
                    if tracking:
                        aver_logp_guide = []
                        aver_logp_model = []
                        aver_logp = defaultdict(list)

                    for batch in loader:

                        x_data_batch, extra_data_batch = batch
                        x_data_batch = x_data_batch.to(self.device)
                        extra_data_batch = {
                            k: v.to(self.device)
                            for k, v in extra_data_batch.items()
                        }

                        loss = self.step_train(name, x_data_batch,
                                               extra_data_batch)

                        if tracking:
                            guide_tr, model_tr = self.step_trace(
                                name, x_data_batch, extra_data_batch)
                            aver_logp_guide.append(
                                guide_tr.log_prob_sum().item())
                            aver_logp_model.append(
                                model_tr.log_prob_sum().item())

                            for k, v in model_tr.nodes.items():
                                if "log_prob_sum" in v:
                                    aver_logp[k].append(
                                        v["log_prob_sum"].item())

                        aver_loss.append(loss)

                    # epoch loss = sum of the per-minibatch losses
                    iter_loss = np.sum(aver_loss)

                    # save data for posterior sampling
                    self.x_data = x_data_batch
                    self.extra_data_train = extra_data_batch

                    self.hist[name].append(iter_loss)

                    if tracking:
                        iter_logp_guide = np.sum(aver_logp_guide)
                        iter_logp_model = np.sum(aver_logp_model)
                        self.logp_hist[name]['guide'].append(iter_logp_guide)
                        self.logp_hist[name]['model'].append(iter_logp_model)

                        for k, v in aver_logp.items():
                            self.logp_hist[name][k].append(np.sum(v))

                if self.checkpoints is not None:
                    if (epoch + 1) in self.checkpoints:
                        self.save_checkpoint(epoch + 1, prefix=name)

                ################### Evaluating cross-validation loss ###################
                if train_proportion is not None:

                    iter_loss_val = self.step_eval_loss(
                        name, x_data_val, extra_data_val)

                    if tracking:
                        guide_tr, model_tr = self.step_trace(
                            name, x_data_val, extra_data_val)
                        self.logp_hist_val[name]['guide'].append(
                            guide_tr.log_prob_sum().item())
                        self.logp_hist_val[name]['model'].append(
                            model_tr.log_prob_sum().item())

                        for k, v in model_tr.nodes.items():
                            if "log_prob_sum" in v:
                                self.logp_hist_val[name][k].append(
                                    v["log_prob_sum"].item())

                    self.validation_hist[name].append(iter_loss_val)
                    epochs_iterator.set_description(
                        f'ELBO Loss: {iter_loss:.4e}: Val loss: {iter_loss_val:.4e}')
                else:
                    epochs_iterator.set_description(f'ELBO Loss: {iter_loss:.4e}')

                if epoch % 20 == 0:
                    torch.cuda.empty_cache()

            if train_proportion is not None:
                # rescale losses by the fraction of cells each split contains
                self.validation_hist[name] = [
                    i / (1 - train_proportion)
                    for i in self.validation_hist[name]
                ]
                self.hist[name] = [
                    i / train_proportion for i in self.hist[name]
                ]

                # reassign the main loss to be displayed
                self.training_hist[name] = self.hist[name]
                self.hist[name] = self.validation_hist[name]

                if tracking:
                    for k, v in self.logp_hist[name].items():
                        self.logp_hist[name][k] = [
                            i / train_proportion
                            for i in self.logp_hist[name][k]
                        ]
                        self.logp_hist_val[name][k] = [
                            i / (1 - train_proportion)
                            for i in self.logp_hist_val[name][k]
                        ]

                    self.logp_hist_train[name] = self.logp_hist[name]
                    self.logp_hist[name] = self.logp_hist_val[name]

            if self.verbose:
                print(plt.plot(np.log10(self.hist[name][0:])))
Beispiel #7
0
    def fit_advi_iterative_simple(
        self,
        n: int = 3,
        method='advi',
        n_type='restart',
        n_iter=None,
        learning_rate=None,
        progressbar=True,
    ):
        r""" Find posterior using ADVI (deprecated)
        (maximising likelihood of the data and minimising KL-divergence of posterior to prior)

        :param n: number of independent initialisations
        :param method: which approximation of the posterior (guide) to use:
            * ``'advi'`` - AutoGuideList combining AutoNormal for all sites except
              ``self.point_estim`` and AutoDelta (point estimates) for ``self.point_estim`` sites
            * ``'custom'`` - custom guide ``self.guide``
            Any other value raises ``ValueError``.
        :param n_type: 'restart' refits ``self.X_data`` each time; 'cv' / 'bootstrap' fit
            ``self.X_data_sample[i]`` for initialisation ``i`` (created by ``generate_cv_data`` /
            ``bootstrap_data``).
        :param n_iter: number of SVI steps per initialisation, supersedes self.n_iter
        :param learning_rate: ClippedAdam learning rate, supersedes self.learning_rate
        :param progressbar: currently unused
        :return: None. ``self.svi`` holds the SVI object and ``self.hist`` the per-step ELBO
            loss history of each initialisation, keyed 'init_1', 'init_2', ...
        :raises ValueError: if ``method`` is not ``'advi'`` or ``'custom'``.
        """
        # Fail fast: previously an unknown method left guide_i[name] unset and
        # crashed later with a KeyError.
        if method not in ('advi', 'custom'):
            raise ValueError(
                f"method must be 'advi' or 'custom', got {method!r}")

        # Pass data to pyro / pytorch
        self.x_data = torch.tensor(self.X_data.astype(
            self.data_type))  # .double()

        # initialise parameter store
        self.svi = {}
        self.hist = {}
        self.guide_i = {}
        self.samples = {}
        self.node_samples = {}

        self.n_type = n_type

        if n_iter is None:
            n_iter = self.n_iter

        if learning_rate is None:
            learning_rate = self.learning_rate

        if np.isin(n_type, ['bootstrap']):
            if self.X_data_sample is None:
                self.bootstrap_data(n=n)
        elif np.isin(n_type, ['cv']):
            self.generate_cv_data()  # cv data added to self.X_data_sample

        init_names = ['init_' + str(i + 1) for i in np.arange(n)]

        for i, name in enumerate(init_names):

            # initialise Variational distribution = guide
            # (compare by value: `is` on string literals is identity, not equality)
            if method == 'advi':
                self.guide_i[name] = AutoGuideList(self.model)
                self.guide_i[name].append(
                    AutoNormal(poutine.block(self.model,
                                             expose_all=True,
                                             hide_all=False,
                                             hide=self.point_estim),
                               init_loc_fn=init_to_mean))
                self.guide_i[name].append(
                    AutoDelta(
                        poutine.block(self.model,
                                      hide_all=True,
                                      expose=self.point_estim)))
            else:  # method == 'custom' (validated above)
                self.guide_i[name] = self.guide

            # initialise SVI inference method
            self.svi[name] = SVI(
                self.model,
                self.guide_i[name],
                optim.ClippedAdam({
                    'lr': learning_rate,
                    # limit the gradient step from becoming too large
                    'clip_norm': self.total_grad_norm_constraint
                }),
                loss=JitTrace_ELBO())

            pyro.clear_param_store()

            # record ELBO Loss history here
            self.hist[name] = []

            # pick dataset depending on the training mode and move to GPU
            if np.isin(n_type, ['cv', 'bootstrap']):
                self.x_data = torch.tensor(self.X_data_sample[i].astype(
                    self.data_type))
            else:
                self.x_data = torch.tensor(self.X_data.astype(self.data_type))

            if self.use_cuda:
                # move tensors and modules to CUDA
                self.x_data = self.x_data.cuda()

            # train for n_iter
            it_iterator = tqdm(range(n_iter))
            for it in it_iterator:

                hist = self.svi[name].step(self.x_data)
                it_iterator.set_description('ELBO Loss: ' +
                                            str(np.round(hist, 3)))
                self.hist[name].append(hist)

                if it % 500 == 0:
                    torch.cuda.empty_cache()
Beispiel #8
0
def run_svi(
    model,
    guide,
    train_data,
    unsplit_data,
    subsample=False,
    num_iters=10000,
    lr=1e-2,
    zero_inflated=False,
):
    """Fit ``guide`` to ``model`` with SVI (Trace_ELBO loss, ClippedAdam).

    :param model: Pyro model, called as ``model(p_data, t_data, s_data, r_data,
        y, p_types, p_stories, p_subreddits, zero_inflated)``.
    :param guide: Pyro guide with the same call signature.
    :param train_data: 6-tuple ``(_, p_data, y, p_types, p_stories,
        p_subreddits)``; the first element is ignored.
    :param unsplit_data: 3-tuple ``(t_data, s_data, r_data)``, passed unchanged.
    :param subsample: if True, keep only the first 250 entries of every
        element of ``train_data`` that is used (quick runs).
    :param num_iters: number of SVI steps.
    :param lr: ClippedAdam learning rate.
    :param zero_inflated: flag forwarded to model and guide.
    :return: ``(svi, losses)``: the SVI object and a ``(num_iters,)`` numpy
        array of per-step losses.
    """
    pyro.clear_param_store()

    _, p_data, y, p_types, p_stories, p_subreddits = train_data
    t_data, s_data, r_data = unsplit_data

    if subsample:
        p_data, y, p_types, p_stories, p_subreddits = (
            arr[:250] for arr in (p_data, y, p_types, p_stories, p_subreddits)
        )

    svi = SVI(model, guide, optim.ClippedAdam({"lr": lr}), loss=Trace_ELBO())

    losses = np.zeros((num_iters, ))
    step_args = (
        p_data,
        t_data,
        s_data,
        r_data,
        y,
        p_types,
        p_stories,
        p_subreddits,
        zero_inflated,
    )

    start_time = time()

    for i in range(num_iters):
        elbo = svi.step(*step_args)
        losses[i] = elbo
        if i % 100 == 99:
            done = i + 1
            elapsed = time() - start_time
            # Mean time per finished step times the steps still to run
            # (previously counted one step too many).
            remaining = (elapsed / done) * (num_iters - done)
            print(
                f"Iter {done}/{num_iters}"
                "\t||\t"
                f"Elbo loss:{elbo:.2e}"
                "\t||\t"
                f"Time Elapsed:{int(elapsed) // 60:02}:{int(elapsed) % 60:02}"
                "\t||\t"
                f"Est Remaining:{int(remaining) // 60:02}:{int(remaining) % 60:02}",
                end="\r",
                flush=True,
            )
    return svi, losses
Beispiel #9
0
                       constraint=constraints.positive)
    w_scale = pyro.param("w_scale", torch.rand(n_predictors), 
                         constraint=constraints.positive)

    w = pyro.sample("w", dist.Gamma(w_loc, w_scale))

    b_loc = pyro.param("b_loc", torch.rand(1))
    b_scale = pyro.param("b_scale", torch.rand(1), constraint=constraints.positive)

    b = pyro.sample("b", dist.LogNormal(b_loc, b_scale))


pyro.clear_param_store()

# Fit death_guide to death_model with ClippedAdam (lr=0.01) and Trace_ELBO.
death_svi = SVI(model=death_model, guide=death_guide, 
                optim=optim.ClippedAdam({'lr' : 0.01}), 
                loss=Trace_ELBO())

for step in range(2000):
    # step() returns the loss estimate for the whole batch; dividing by
    # len(X_train) reports it per training example.
    loss = death_svi.step(X_train, y_train)/len(X_train)
    if step % 100 == 0:
        print(f"Step {step} : loss = {loss}")


print("Inferred params:", list(pyro.get_param_store().keys()), end="\n\n")
# Raw variational parameters for w and b.
# NOTE(review): these are NOT posterior means. The guide above (presumably
# death_guide) samples w ~ Gamma(w_loc, w_scale), and Pyro's Gamma is
# parameterised by (concentration, rate), so E[w] = w_loc / w_scale. It samples
# b ~ LogNormal(b_loc, b_scale), so E[b] = exp(b_loc + b_scale**2 / 2).
# Confirm which quantity the code below expects.
inferred_w = pyro.get_param_store()["w_loc"]
inferred_b = pyro.get_param_store()["b_loc"]


for i,w in enumerate(inferred_w):
def main(model, guide, args):
    """Train ``model.main`` / ``guide.main`` with SVI on the JSB chorales data.

    Each epoch's loss (summed over mini-batches, divided by the total number
    of training time steps) is logged as ``-loss`` in the "elbo" column.
    Training stops early if the loss becomes NaN.
    """
    # init
    if args.seed is not None: pyro.set_rng_seed(args.seed)
    logger = get_logger(args.log, __name__)
    logger.info(args)

    # load data
    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths    = data['train']['sequence_lengths']
    training_data_sequences = data['train']['sequences']

    # Drop trailing sequences so the training set size is an exact multiple
    # of the mini-batch size.
    d1_training = int(len(training_seq_lengths)/args.mini_batch_size)*args.mini_batch_size
    training_seq_lengths    = training_seq_lengths   [:d1_training]
    training_data_sequences = training_data_sequences[:d1_training]

    N_train_data = len(training_seq_lengths)
    # Total number of time steps over all training sequences; used to
    # normalise the summed epoch loss.
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    # Ceiling division; the remainder term is 0 after the truncation above.
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    logger.info(f"N_train_data: {N_train_data}\t"
                f"avg. training seq. length: {training_seq_lengths.float().mean():.2f}\t"
                f"N_mini_batches: {N_mini_batches}")
    
    # setup svi
    pyro.clear_param_store()
    # ClippedAdam with gradient-norm clipping (clip_norm), per-step learning
    # rate decay (lrd) and weight decay.
    opt = optim.ClippedAdam({"lr": args.learning_rate, "betas": (args.beta1, args.beta2),
                             "clip_norm": args.clip_norm, "lrd": args.lr_decay,
                             "weight_decay": args.weight_decay})
    svi = SVI(model.main, guide.main, opt,
              loss=Trace_ELBO())
    # Second SVI sharing the same optimiser; it is only used through
    # evaluate_loss (in eval_epoch), which takes no gradient step.
    svi_eval = SVI(model.main, guide.main, opt,
                   loss=Trace_ELBO())

    # train minibatch
    def proc_minibatch(svi_proc, epoch, which_mini_batch, shuffled_indices):
        """Run ``svi_proc`` (svi.step or svi_eval.evaluate_loss) on one mini-batch and return its loss."""
        # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
        # (linear ramp from minimum_annealing_factor up to 1.0).
        # NOTE(review): epochs are numbered from 1 in the training loop below, so the
        # ramp starts above min_af and reaches 1.0 after annealing_epochs - 1 epochs.
        # Confirm whether epoch numbering should start at 0 here.
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.annealing_epochs * N_mini_batches))
        else:
            annealing_factor = 1.0

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size, N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)

        # do an actual gradient step (or eval)
        loss = svi_proc(mini_batch, mini_batch_reversed, mini_batch_mask,
                        mini_batch_seq_lengths, annealing_factor)
        return loss

    # train epoch
    def train_epoch(epoch):
        """One shuffled pass with gradient steps; returns the loss per training time step."""
        # take gradient
        loss, shuffled_indices = 0.0, torch.randperm(N_train_data)
        for which_mini_batch in range(N_mini_batches):
            loss += proc_minibatch(svi.step, epoch,
                                   which_mini_batch, shuffled_indices)
        loss /= N_train_time_slices
        return loss

    # eval loss of epoch
    def eval_epoch():
        """Loss over the training data without gradient steps, per training time step.

        NOTE(review): passes epoch=0, so when annealing_epochs > 0 the
        evaluation uses an annealed KL factor rather than 1.0 — confirm this is
        intended. Not called within the lines shown here.
        """
        # put the RNN into evaluation mode (i.e. turn off drop-out if applicable)
        guide.rnn.eval()

        # eval loss
        loss, shuffled_indices = 0.0, torch.randperm(N_train_data)
        for which_mini_batch in range(N_mini_batches):
            loss += proc_minibatch(svi_eval.evaluate_loss, 0,
                                   which_mini_batch, shuffled_indices)
        loss /= N_train_time_slices

        # put the RNN back into training mode (i.e. turn on drop-out if applicable)
        guide.rnn.train()
        return loss

    # train
    #elbo_l = []
    #param_state_l = []
    time_l = [time.time()]
    logger.info(f"\nepoch\t"+"elbo\t"+"time(sec)")

    for epoch in range(1, args.num_epochs+1):
        loss = train_epoch(epoch)
        #elbo_l.append(-loss)

        # param_state = copy.deepcopy(pyro.get_param_store().get_state())
        # param_state_l.append(param_state)

        time_l.append(time.time())
        # SVI losses are negative ELBO estimates, so -loss is logged as the ELBO.
        logger.info(f"{epoch:04d}\t"
                    f"{-loss:.4f}\t"
                    f"{time_l[-1]-time_l[-2]:.3f}")

        # Stop as soon as training diverges.
        if math.isnan(loss): break