Example #1
def train(model,
          train_loader,
          device,
          tqdm,
          writer,
          lr,
          lr_gamma,
          lr_milestones,
          iw,
          iter_max=np.inf,
          iter_save=np.inf,
          model_name='model',
          reinitialize=False):
    # Optimization
    if reinitialize:
        model.apply(ut.reset_weights)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=lr_milestones,
                                                     gamma=lr_gamma)
    i = 0
    # model.warmup = True
    # print("warmup", model.warmup)
    with tqdm(total=iter_max) as pbar:
        while True:
            for batch_idx, sample in enumerate(train_loader):
                i += 1  # i is num of gradient steps taken by end of loop iteration
                if i == (iter_max // 4):
                    # start learning variance
                    model.warmup = False
                optimizer.zero_grad()
                x = torch.as_tensor(sample).float().to(device)
                loss, summaries = model.loss(x, iw)

                loss.backward()
                optimizer.step()
                scheduler.step()

                # Feel free to modify the progress bar
                pbar.set_postfix(loss='{:.2e}'.format(loss))
                pbar.update(1)

                # Log summaries
                if i % 50 == 0:
                    ut.log_summaries(writer, summaries, i)

                # Save model
                if i % iter_save == 0:
                    ut.save_model_by_name(model, i)
                    # print(optimizer.param_groups[0]['lr'])
                    # print("warmup", model.warmup)
                    print("\n",
                          [(key, v.item()) for key, v in summaries.items()])

                if i == iter_max:
                    return
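
A note on shared helpers: all of these trainers lean on a project-local ut utilities module that is never shown. Below is a minimal sketch of the two helpers Example #1 assumes, ut.reset_weights and ut.log_summaries; the signatures are inferred from the call sites, so the originals may differ.

import torch

def reset_weights(m):
    # Re-initialize any submodule that exposes reset_parameters()
    # (Linear, Conv2d, ...); applied via model.apply(reset_weights).
    if hasattr(m, 'reset_parameters'):
        m.reset_parameters()

def log_summaries(writer, summaries, global_step):
    # Write each scalar summary to TensorBoard under its key.
    for key, value in summaries.items():
        scalar = value.item() if torch.is_tensor(value) else value
        writer.add_scalar(key, scalar, global_step)
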
Example #2
def train(model, train_loader, labeled_subset, device, tqdm, writer,
          iter_max=np.inf, iter_save=np.inf,
          model_name='model', reinitialize=False):
    # Optimization
    if reinitialize:
        model.apply(ut.reset_weights)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    i = 0
    epoch_count = 0
    mus = []
    its = []
    with tqdm(total=iter_max) as pbar:
        while True:
            for batch_idx, (xu, yu) in enumerate(train_loader):
                i += 1 # i is num of gradient steps taken by end of loop iteration
                optimizer.zero_grad()

                xu = torch.bernoulli(xu.to(device).reshape(xu.size(0), -1))
                yu = yu.new(np.eye(10)[yu]).to(device).float()
                loss, summaries = model.loss(xu, epoch_count)

                loss.backward()
                # for name, param in model.named_parameters():
                #     if name in ['mu']:
                #         param.retain_grad()
                #         print(param.requires_grad, param.grad)

                # # start debugger
                # import pdb; pdb.set_trace()            
                optimizer.step()

                # Feel free to modify the progress bar
                pbar.set_postfix(loss='{:.2e}'.format(loss))
                pbar.update(1)

                # Log summaries
                if i % 50 == 0:
                    ut.log_summaries(writer, summaries, i)
                    for name, param in model.named_parameters():
                        if name == 'mu':
                            # store a detached copy; appending the live
                            # parameter would leave every snapshot pointing
                            # at the same final values
                            mus.append(param.detach().clone())
                            its.append(i)
                    
                # Save model
                if i % iter_save == 0:
                    ut.save_model_by_name(model, i)

                if i == iter_max:
                    plot_mus(its, mus)
                    return
            epoch_count += 1
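
Example #2 ends by calling plot_mus(its, mus), which is not defined here. A plausible matplotlib sketch inferred from the call site; the name and behavior are assumptions:

import matplotlib.pyplot as plt

def plot_mus(its, mus):
    # its: gradient-step indices; mus: snapshots of the 'mu' parameter.
    # Plot every component of mu over the course of training.
    values = [mu.detach().cpu().reshape(-1).numpy() for mu in mus]
    plt.figure()
    plt.plot(its, values)
    plt.xlabel('gradient step')
    plt.ylabel('mu components')
    plt.savefig('mu_history.png')
    plt.close()
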
Example #3
def train(model,
          train_loader,
          labeled_subset,
          device,
          tqdm,
          writer,
          iter_max=np.inf,
          iter_save=np.inf,
          model_name='model',
          y_status='none',
          reinitialize=False):
    # Optimization
    if reinitialize:
        model.apply(ut.reset_weights)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    i = 0
    with tqdm(total=iter_max) as pbar:
        while True:
            for batch_idx, xu in enumerate(train_loader):
                i += 1  # i is num of gradient steps taken by end of loop iteration
                optimizer.zero_grad()

                if y_status == 'none':
                    xu = torch.bernoulli(xu.to(device).reshape(xu.size(0), -1))
                    loss, summaries = model.loss(xu)
                else:
                    # avoid a NameError on loss below for unsupported modes
                    raise ValueError("unsupported y_status: {!r}".format(y_status))

                loss.backward()
                optimizer.step()

                # Feel free to modify the progress bar
                if y_status == 'none':
                    pbar.set_postfix(loss='{:.2e}'.format(loss))
                pbar.update(1)

                # Log summaries
                if i % 50 == 0:
                    ut.log_summaries(writer, summaries, i)

                # Save model
                if i % iter_save == 0:
                    ut.save_model_by_name(model, i)

                if i == iter_max:
                    return
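
Examples #2, #3, and #7 all binarize inputs with torch.bernoulli, drawing a fresh binary image every step (dynamic binarization). A self-contained illustration of what that call does:

import torch

x = torch.rand(2, 4)        # pixel intensities in [0, 1]
x_bin = torch.bernoulli(x)  # each entry -> 1. with probability x, else 0.
print(x_bin)                # e.g. tensor([[0., 1., 1., 0.], ...])
# Repeated calls give different samples, so the model never sees
# exactly the same binarized image twice across epochs.
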
Example #4
def train(model,
          data,
          device,
          tqdm,
          kernel,
          iter_max=np.inf,
          iter_save=np.inf,
          iter_plot=np.inf,
          reinitialize=False):
    # Optimization
    if reinitialize:
        model.apply(ut.reset_weights)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    i = 0
    loss_list = []
    with tqdm(total=iter_max) as pbar:
        while True:
            for batch in data:
                i += 1  # i is num of gradient steps taken by end of loop iteration
                optimizer.zero_grad()

                loss = model.loss(batch)
                # detach so the stored history doesn't keep autograd graphs alive
                loss_list.append(loss.detach().cpu())

                loss.backward()
                optimizer.step()

                pbar.set_postfix(loss='{:.2e}'.format(loss))
                pbar.update(1)

                # Save model
                if i % iter_save == 0:
                    ut.save_model_by_name(model, i)

                if i % iter_plot == 0:
                    if model.input_feat_dim <= 2:
                        test_plot(model, i, kernel)
                    ut.plot_log_loss(model, loss_list, i)

                if i == iter_max:
                    return
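
test_plot and ut.plot_log_loss are project helpers that are not shown. A plausible sketch of the latter, inferred from the call; the name and signature are assumptions:

import matplotlib.pyplot as plt

def plot_log_loss(model, loss_list, i):
    # loss_list holds detached scalar losses; i is the current step.
    losses = [float(l) for l in loss_list]
    plt.figure()
    plt.semilogy(range(1, len(losses) + 1), losses)
    plt.xlabel('gradient step')
    plt.ylabel('loss (log scale)')
    plt.savefig('loss_{:06d}.png'.format(i))
    plt.close()
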
Example #5
def train_c(model,
            train_loader,
            train_loader_ev,
            device,
            tqdm,
            writer,
            lr,
            lr_gamma,
            lr_milestones,
            iw,
            iter_max=np.inf,
            iter_save=np.inf,
            model_name='model',
            reinitialize=False):
    assert isinstance(model, CVAE)

    # Optimization
    if reinitialize:
        model.apply(ut.reset_weights)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=lr_milestones,
                                                     gamma=lr_gamma)

    # model.warmup = True
    # print("warmup", model.warmup)

    iterator = iter(train_loader)
    iterator_ev = iter(train_loader_ev)
    i = 0
    with tqdm(total=iter_max) as pbar:
        while True:
            i += 1  # i is num of gradient steps taken by end of loop iteration
            if i == (iter_max // 4):
                # start learning variance
                model.warmup = False
            optimizer.zero_grad()

            # must handle two data-loader queues...
            try:
                sample = next(iterator)
            except StopIteration:
                iterator = iter(train_loader)
                sample = next(iterator)
            try:
                sample_ev = next(iterator_ev)
            except StopIteration:
                iterator_ev = iter(train_loader_ev)
                sample_ev = next(iterator_ev)

            # combine the two batches along the batch dimension
            for k, v in sample.items():
                sample[k] = torch.as_tensor(v).float().to(device)
            for k, v in sample_ev.items():
                v = torch.as_tensor(v).float().to(device)
                sample[k] = torch.cat([sample[k], v], dim=0)
            sample_ev = None

            # run model
            loss, summaries = model.loss(sample, iw)

            loss.backward()
            optimizer.step()
            scheduler.step()

            # Feel free to modify the progress bar
            pbar.set_postfix(loss='{:.2e}'.format(loss))
            pbar.update(1)

            # Log summaries
            if i % 50 == 0:
                ut.log_summaries(writer, summaries, i)

            # Save model
            if i % iter_save == 0:
                ut.save_model_by_name(model, i)
                # print(optimizer.param_groups[0]['lr'])
                # print("warmup", model.warmup)
                print("\n", [(key, v.item()) for key, v in summaries.items()])

            if i == iter_max:
                return
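
The try/except blocks around next() implement an infinite stream over each DataLoader: re-creating the iterator on StopIteration triggers a reshuffle, which itertools.cycle would not (it replays the batches cached from the first pass). The same pattern, factored into a reusable generator:

def infinite_batches(loader):
    # Yield batches forever, restarting (and reshuffling, if the
    # loader was built with shuffle=True) whenever it is exhausted.
    it = iter(loader)
    while True:
        try:
            yield next(it)
        except StopIteration:
            it = iter(loader)
            yield next(it)

# usage:
#   batches = infinite_batches(train_loader)
#   sample = next(batches)
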
Example #6
def train2(model,
           train_loader,
           val_set,
           tqdm,
           lr,
           lr_gamma,
           lr_milestone_every,
           iw,
           num_epochs,
           reinitialize=False,
           is_car_model=False):
    assert isinstance(model, VAE2)
    # Optimization
    if reinitialize:
        model.apply(ut.reset_weights)

    num_batches_per_epoch = len(
        train_loader.dataset) // train_loader.batch_size
    lr_milestones = [(1 + lr_milestone_every * i) * num_batches_per_epoch
                     for i in range(num_epochs // lr_milestone_every + 1)]
    print("len(train_loader.dataset)", len(train_loader.dataset),
          "num_batches_per_epoch", num_batches_per_epoch)
    print("lr_milestones", lr_milestones, "lr",
          [lr * lr_gamma**i for i in range(len(lr_milestones))])

    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=lr_milestones,
                                                     gamma=lr_gamma)

    random.seed(1234)  # Initialize the random seed
    # model.warmup = True
    # print("warmup", model.warmup)
    i = 0
    epoch = 0
    print_every = 100
    summaries = OrderedDict({
        'epoch': 0,
        'loss': 0,
        'kl_z': 0,
        'rec_mse': 0,
        'rec_var': 0,
        'loss_type': iw,
        'lr': optimizer.param_groups[0]['lr'],
        'varpen': model.var_pen,
    })

    with tqdm(total=num_batches_per_epoch * num_epochs) as pbar:
        while epoch < num_epochs:
            if epoch >= 1:
                # start learning variance
                model.warmup = False

            summaries['loss_type'] = iw
            summaries['lr'] = optimizer.param_groups[0]['lr']
            summaries['varpen'] = model.var_pen

            for batch_idx, sample in enumerate(train_loader):
                i += 1  # i is num of gradient steps taken by end of loop iteration
                optimizer.zero_grad()

                # run model
                loss, info = model.loss(
                    x=sample["other"] if not is_car_model else sample["car"],
                    meta=sample["meta"],
                    c=None if not is_car_model else sample["other"],
                    iw=iw)

                pbar.set_postfix(loss='{:.2e}'.format(loss))
                pbar.update(1)
                # Log summaries
                for key, value in info.items():
                    summaries[key] += value.item() / (1.0 * print_every)

                if i % print_every == 0:
                    summaries["epoch"] = epoch + batch_idx / num_batches_per_epoch
                    ut.log_summaries(summaries,
                                     'train',
                                     model_name=model.name,
                                     verbose=True)
                    for key in ['loss', 'kl_z', 'rec_mse', 'rec_var']:
                        summaries[key] = 0.0
                loss.backward()
                optimizer.step()
                scheduler.step()

            # validate at each epoch end
            val_summaries = copy.deepcopy(summaries)
            for key in ['loss', 'kl_z', 'rec_mse', 'rec_var']:
                val_summaries[key] = 0.0
            val_summaries["epoch"] = epoch + 1
            ut.evaluate_lower_bound2(model,
                                     val_set,
                                     run_iwae=(iw >= 1),
                                     mode='val',
                                     verbose=False,
                                     summaries=val_summaries)

            epoch += 1
            if epoch % 10 == 0:  # save interim model
                ut.save_model_by_name(model, epoch)

        # save in the end
        ut.save_model_by_name(model, epoch)
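
A concrete check of the milestone arithmetic in train2: with 50,000 samples, batch size 100, lr_milestone_every=2, and num_epochs=6, the learning rate halves every two epochs (assuming lr_gamma=0.5; the numbers below just instantiate the formulas above):

dataset_size, batch_size = 50_000, 100
lr, lr_gamma = 1e-3, 0.5
lr_milestone_every, num_epochs = 2, 6

num_batches_per_epoch = dataset_size // batch_size  # 500
lr_milestones = [(1 + lr_milestone_every * i) * num_batches_per_epoch
                 for i in range(num_epochs // lr_milestone_every + 1)]
print(lr_milestones)
# [500, 1500, 2500, 3500]
print([lr * lr_gamma ** i for i in range(len(lr_milestones))])
# [0.001, 0.0005, 0.00025, 0.000125]
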
Example #7
File: train.py Project: qqhann/VAEs
def train(model,
          train_loader,
          labeled_subset,
          device,
          tqdm,
          writer,
          iter_max=np.inf,
          iter_save=np.inf,
          model_name='model',
          y_status='none',
          reinitialize=False):
    # Optimization
    if reinitialize:
        model.apply(ut.reset_weights)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    i = 0
    #beta = ut.DeterministicWarmup(n=50, t_max = 1)
    beta = 0
    with tqdm(total=iter_max) as pbar:
        while True:
            for batch_idx, (xu, yu) in enumerate(train_loader):
                i += 1  # i is num of gradient steps taken by end of loop iteration
                optimizer.zero_grad()
                if i % 600 == 0:
                    beta += 0.02
                #print(beta)
                if y_status == 'none':
                    xu = torch.bernoulli(xu.to(device).reshape(xu.size(0), -1))
                    yu = yu.new(np.eye(10)[yu]).to(device).float()
                    loss, summaries = model.loss(xu, beta)
                    #print(beta)
                    #print(next(beta))
                    #loss, summaries = model.loss(xu)
                elif y_status == 'semisup':
                    xu = torch.bernoulli(xu.to(device).reshape(xu.size(0), -1))
                    yu = yu.new(np.eye(10)[yu]).to(device).float()
                    # xl and yl already preprocessed
                    xl, yl = labeled_subset
                    xl = torch.bernoulli(xl)
                    loss, summaries = model.loss(xu, xl, yl)

                    # Add training accuracy computation
                    pred = model.cls.classify(xu).argmax(1)
                    true = yu.argmax(1)
                    acc = (pred == true).float().mean()
                    summaries['class/acc'] = acc

                elif y_status == 'fullsup':
                    # Janky code: fullsup is only for SVHN
                    # xu is not bernoulli for SVHN
                    xu = xu.to(device).reshape(xu.size(0), -1)
                    yu = yu.new(np.eye(10)[yu]).to(device).float()
                    loss, summaries = model.loss(xu, yu)

                loss.backward()
                optimizer.step()

                # Feel free to modify the progress bar
                if y_status == 'none':
                    pbar.set_postfix(loss='{:.2e}'.format(loss),
                                     kl='{:.2f}'.format(summaries['gen/kl_z']),
                                     rec='{:.2g}'.format(summaries['gen/rec']))
                elif y_status == 'semisup':
                    pbar.set_postfix(loss='{:.2e}'.format(loss),
                                     acc='{:.2e}'.format(acc))
                elif y_status == 'fullsup':
                    pbar.set_postfix(loss='{:.2e}'.format(loss),
                                     kl='{:.2e}'.format(summaries['gen/kl_z']))
                pbar.update(1)

                # Log summaries
                if i % 50 == 0:
                    ut.log_summaries(writer, summaries, i)

                # Save model
                if i % iter_save == 0:
                    ut.save_model_by_name(model, i)

                if i == iter_max:
                    return
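
The commented-out ut.DeterministicWarmup(n=50, t_max=1) hints at the KL-annealing schedule this hand-rolled beta replaces. A common implementation of that iterator; a sketch, the project's version may differ:

class DeterministicWarmup:
    # Iterator that ramps beta linearly from 0 to t_max over n steps,
    # then holds it at t_max; intended use: beta = next(warmup).
    def __init__(self, n=50, t_max=1.0):
        self.t = 0.0
        self.inc = t_max / n
        self.t_max = t_max

    def __iter__(self):
        return self

    def __next__(self):
        self.t = min(self.t + self.inc, self.t_max)
        return self.t
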
Example #8
def train(model, train_data, batch_size, n_batches,
          lr=1e-3,
          clip_grad=None,
          iter_max=np.inf,
          iter_save=np.inf,
          iter_plot=np.inf,
          reinitialize=False,
          kernel=None):

    # Optimization
    if reinitialize:
        model.apply(ut.reset_weights)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    mse = nn.MSELoss()

    # # Model
    # hidden = model.init_hidden(batch_size)

    i = 0 # i is num of gradient steps taken by end of loop iteration
    loss_list = []
    mse_list = []
    with tqdm.tqdm(total=iter_max) as pbar:
        while True:
            for batch in train_data:
                i += 1 
                # print(psutil.virtual_memory())
                optimizer.zero_grad()

                inputs = batch[:model.n_input_steps, :, :]
                targets = batch[model.n_input_steps:, :, :2]
                
                # Since the data is not continued from batch to batch,
                # reinit hidden every batch. (using zeros)
                outputs = model.forward(inputs, targets=targets)
                batch_mean_nll, KL, KL_sharp = model.get_loss(outputs, targets)
                # print(batch_mean_nll, KL, KL_sharp)
                
                # Re-weighting for minibatches
                NLL_term = batch_mean_nll * model.n_pred_steps

                # Here B = n_batches, C = 1 (since each sequence is complete)
                KL_term = KL / n_batches

                loss = NLL_term + KL_term

                if model.sharpen:
                    KL_sharp /= n_batches
                    loss += KL_sharp

                loss_list.append(loss.cpu().detach())

                # Print progress
                if model.likelihood_cost_form == 'gaussian':
                    if model.constant_var:
                        mse_val = mse(outputs, targets) * model.n_pred_steps
                    else:
                        if model.rnn_cell_type == 'FF':
                            mean, var = ut.gaussian_parameters_ff(outputs, dim=0)

                        else:
                            mean, var = ut.gaussian_parameters(outputs, dim=-1)
                        mse_val = mse(mean, targets) * model.n_pred_steps

                elif model.likelihood_cost_form == 'mse':
                    mse_val = batch_mean_nll * model.n_pred_steps
                    
                mse_list.append(mse_val.cpu().detach())

                if i % iter_plot == 0:
                    with torch.no_grad():
                        model.eval()
                        if model.input_feat_dim <= 2:
                            ut.test_plot(model, i, kernel)

                        elif model.input_feat_dim == 4:
                            rand_idx = random.sample(range(batch.shape[1]), 4)
                            full_true_traj = batch[:, rand_idx, :]
                            if not model.BBB:
                                if model.constant_var:
                                    pred_traj = outputs[:, rand_idx, :]
                                    std_pred = None
                                else:
                                    pred_traj = mean[:, rand_idx, :]
                                    std_pred = var.sqrt()

                                ut.plot_highd_traj(model, i, full_true_traj, 
                                    pred_traj, std_pred=std_pred)
                            else:
                                # resample a few forward passes
                                ut.plot_highd_traj_BBB(model, i, full_true_traj, 
                                                        n_resample_weights=10)
    
                        ut.plot_history(model, loss_list, i, obj='loss')
                        ut.plot_history(model, mse_list, i, obj='mse')
                        model.train()

                
                # loss.backward(retain_graph=True)
                loss.backward()
                if clip_grad is not None:
                    # clip after backward (grads now exist), before step
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   clip_grad)
                optimizer.step()
                
                pbar.set_postfix(loss='{:.2e}'.format(loss), 
                                 mse='{:.2e}'.format(mse_val))
                pbar.update(1)

                # Save model
                if i % iter_save == 0:
                    ut.save_model_by_name(model, i, only_latest=True)

                if i == iter_max:
                    return
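
Gradient clipping only has an effect between loss.backward() (which populates .grad) and optimizer.step() (which consumes it), which is why the clip call above sits between the two. The ordering in isolation, on a toy model:

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
x, y = torch.randn(8, 4), torch.randn(8, 1)

optimizer.zero_grad()
loss = ((model(x) - y) ** 2).mean()
loss.backward()                                   # populate .grad
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()                                  # apply clipped grads
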
Example #9
def refine(train_loader_set,
           mean_set,
           variance_set,
           z_dim,
           device,
           tqdm,
           writer,
           iter_max=np.inf,
           iter_save=np.inf,
           model_name='model',
           y_status='none',
           reinitialize=False):
    # Optimization

    i = 0
    with tqdm(total=iter_max) as pbar:
        while True:
            for index, train_loader in enumerate(train_loader_set):
                print("Iteration:", i)
                print("index: ", index)

                z_prior_m = torch.nn.Parameter(mean_set[index].cpu(),
                                               requires_grad=False).to(device)
                z_prior_v = torch.nn.Parameter(variance_set[index].cpu(),
                                               requires_grad=False).to(device)
                vae = VAE(z_dim=z_dim,
                          name=model_name,
                          z_prior_m=z_prior_m,
                          z_prior_v=z_prior_v).to(device)
                optimizer = optim.Adam(vae.parameters(), lr=1e-3)
                if i == 0:
                    print("Load model")
                    ut.load_model_by_name(vae, global_step=20000)
                else:
                    print("Load model")
                    ut.load_model_by_name(vae, global_step=iter_save)
                for batch_idx, (xu, yu) in enumerate(train_loader):
                    # i is num of gradient steps taken by end of loop iteration
                    optimizer.zero_grad()

                    xu = torch.bernoulli(xu.to(device).reshape(xu.size(0), -1))
                    yu = yu.new(np.eye(10)[yu]).to(device).float()
                    loss, summaries = vae.loss_encoder(xu)

                    loss.backward()
                    optimizer.step()

                    # Feel free to modify the progress bar
                    pbar.set_postfix(loss='{:.2e}'.format(loss))
                    pbar.update(1)

                    i += 1
                    # Log summaries
                    if i % 50 == 0:
                        ut.log_summaries(writer, summaries, i)

                    if i == iter_max:
                        ut.save_model_by_name(vae, 0)
                        return

                # Save model
                ut.save_model_by_name(vae, iter_save)
Example #10
        # fragment: body of the per-batch loop, inside an outer epoch loop
        L, kl, rec, reconstructed_image, _ = lvae.negative_elbo_bound(
            u, l, sample=False)

        dag_param = lvae.dag.A

        h_a = _h_A(dag_param, dag_param.size()[0])
        L = L + 3 * h_a + 0.5 * h_a * h_a

        L.backward()
        optimizer.step()
        total_loss += L.item()
        total_kl += kl.item()
        total_rec += rec.item()

        m = len(train_dataset)
        save_image(u[0],
                   'figs_vae/reconstructed_image_true_{}.png'.format(epoch),
                   normalize=True)
        save_image(reconstructed_image[0],
                   'figs_vae/reconstructed_image_{}.png'.format(epoch),
                   normalize=True)

    if epoch % 1 == 0:
        print(f"Epoch: {epoch+1}\tL: {total_loss/m:.2f}\t"
              f"kl: {total_kl/m:.2f}\trec: {total_rec/m:.2f}\tm: {m}")

    if epoch % args.iter_save == 0:
        ut.save_model_by_name(lvae, epoch)
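
_h_A is the acyclicity penalty from the NOTEARS family of DAG learners: it is zero exactly when the weighted adjacency matrix encodes a directed acyclic graph, which is what lets Example #10 push the learned dag.A toward DAG-ness by adding it to the loss. A sketch of the standard form h(A) = tr((I + A∘A/m)^m) − m; the project's exact variant may differ:

import torch

def _h_A(A, m):
    # A: (m, m) weighted adjacency matrix of the learned graph.
    # h(A) >= 0, with equality iff the graph A encodes is acyclic.
    M = torch.eye(m, device=A.device) + (A * A) / m  # I + (A ∘ A)/m
    return torch.trace(torch.matrix_power(M, m)) - m
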