Example #1
def train(model,
          train_loader,
          device,
          tqdm,
          writer,
          lr,
          lr_gamma,
          lr_milestones,
          iw,
          iter_max=np.inf,
          iter_save=np.inf,
          model_name='model',
          reinitialize=False):
    # Optimization
    if reinitialize:
        model.apply(ut.reset_weights)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=lr_milestones,
                                                     gamma=lr_gamma)
    i = 0
    # model.warmup = True
    # print("warmup", model.warmup)
    with tqdm(total=iter_max) as pbar:
        while True:
            for batch_idx, sample in enumerate(train_loader):
                i += 1  # i is num of gradient steps taken by end of loop iteration
                if i == (iter_max // 4):
                    # start learning variance
                    model.warmup = False
                optimizer.zero_grad()
                x = torch.tensor(sample).float().to(device)
                loss, summaries = model.loss(x, iw)

                loss.backward()
                optimizer.step()
                scheduler.step()

                # Feel free to modify the progress bar
                pbar.set_postfix(loss='{:.2e}'.format(loss))
                pbar.update(1)

                # Log summaries
                if i % 50 == 0:
                    ut.log_summaries(writer, summaries, i)

                # Save model
                if i % iter_save == 0:
                    ut.save_model_by_name(model, i)
                    # print(optimizer.param_groups[0]['lr'])
                    # print("warmup", model.warmup)
                    print("\n",
                          [(key, v.item()) for key, v in summaries.items()])

                if i == iter_max:
                    return
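The `ut` helpers used above (`reset_weights`, `log_summaries`, `save_model_by_name`) come from the project's own utility module and are not shown here. As a point of reference, a minimal sketch of a `reset_weights` function compatible with `model.apply(ut.reset_weights)` might look like this; the actual helper may differ.

def reset_weights(m):
    # Re-initialize any submodule that defines reset_parameters();
    # meant to be passed to model.apply(...).
    if hasattr(m, 'reset_parameters') and callable(m.reset_parameters):
        m.reset_parameters()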
Example #2
def train(model, train_loader, labeled_subset, device, tqdm, writer,
          iter_max=np.inf, iter_save=np.inf,
          model_name='model', reinitialize=False):
    # Optimization
    if reinitialize:
        model.apply(ut.reset_weights)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    i = 0
    epoch_count = 0
    mus = []
    its = []
    with tqdm(total=iter_max) as pbar:
        while True:
            for batch_idx, (xu, yu) in enumerate(train_loader):
                i += 1 # i is num of gradient steps taken by end of loop iteration
                optimizer.zero_grad()

                xu = torch.bernoulli(xu.to(device).reshape(xu.size(0), -1))
                # one-hot encode the labels (not used by the unsupervised loss below)
                yu = yu.new(np.eye(10)[yu]).to(device).float()
                loss, summaries = model.loss(xu, epoch_count)

                loss.backward()
                # for name, param in model.named_parameters():
                #     if name in ['mu']:
                #         param.retain_grad()
                #         print(param.requires_grad, param.grad)

                # # start debugger
                # import pdb; pdb.set_trace()            
                optimizer.step()

                # Feel free to modify the progress bar
                pbar.set_postfix(loss='{:.2e}'.format(loss))
                pbar.update(1)

                # Log summaries
                if i % 50 == 0:
                    ut.log_summaries(writer, summaries, i)
                    # Snapshot the `mu` parameter; detach and clone so later
                    # optimizer steps do not overwrite the stored values.
                    for name, param in model.named_parameters():
                        if name == 'mu':
                            mus.append(param.detach().clone())
                            its.append(i)
                    
                # Save model
                if i % iter_save == 0:
                    ut.save_model_by_name(model, i)

                if i == iter_max:
                    plot_mus(its, mus)
                    return
            epoch_count += 1
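`plot_mus` is not defined in this snippet. A minimal matplotlib sketch that plots the norm of each stored `mu` snapshot against the gradient step could look like the following; the function name comes from the call above, everything else is an assumption.

import matplotlib.pyplot as plt

def plot_mus(its, mus):
    # Plot the L2 norm of each saved `mu` snapshot over training iterations.
    norms = [m.detach().norm().item() for m in mus]
    plt.figure()
    plt.plot(its, norms, marker='o')
    plt.xlabel('gradient step')
    plt.ylabel('||mu||')
    plt.savefig('mu_trajectory.png')
    plt.close()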
Example #3
def train(model,
          train_loader,
          labeled_subset,
          device,
          tqdm,
          writer,
          iter_max=np.inf,
          iter_save=np.inf,
          model_name='model',
          y_status='none',
          reinitialize=False):
    # Optimization
    if reinitialize:
        model.apply(ut.reset_weights)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    i = 0
    with tqdm(total=iter_max) as pbar:
        while True:
            for batch_idx, xu in enumerate(train_loader):
                i += 1  # i is num of gradient steps taken by end of loop iteration
                optimizer.zero_grad()

                if y_status == 'none':
                    # only the unsupervised path is implemented in this variant
                    xu = torch.bernoulli(xu.to(device).reshape(xu.size(0), -1))
                    loss, summaries = model.loss(xu)

                loss.backward()
                optimizer.step()

                # Feel free to modify the progress bar
                if y_status == 'none':
                    pbar.set_postfix(loss='{:.2e}'.format(loss))
                pbar.update(1)

                # Log summaries
                if i % 50 == 0: ut.log_summaries(writer, summaries, i)

                # Save model
                if i % iter_save == 0:
                    ut.save_model_by_name(model, i)

                if i == iter_max:
                    return
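`ut.log_summaries(writer, summaries, i)` presumably pushes the scalar summaries to TensorBoard. A plausible sketch, assuming `writer` is a `torch.utils.tensorboard.SummaryWriter` and `summaries` maps tag names to scalar values:

def log_summaries(writer, summaries, global_step):
    # Write every entry of the summaries dict as a TensorBoard scalar.
    for key, value in summaries.items():
        writer.add_scalar(key, value, global_step)
    writer.flush()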
Example #4
def train_c(model,
            train_loader,
            train_loader_ev,
            device,
            tqdm,
            writer,
            lr,
            lr_gamma,
            lr_milestones,
            iw,
            iter_max=np.inf,
            iter_save=np.inf,
            model_name='model',
            reinitialize=False):
    assert isinstance(model, CVAE)

    # Optimization
    if reinitialize:
        model.apply(ut.reset_weights)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=lr_milestones,
                                                     gamma=lr_gamma)

    # model.warmup = True
    # print("warmup", model.warmup)

    iterator = iter(train_loader)

    iterator_ev = iter(train_loader_ev)
    i = 0
    with tqdm(total=iter_max) as pbar:
        while True:
            i += 1  # i is num of gradient steps taken by end of loop iteration
            if i == (iter_max // 4):
                # start learning variance
                model.warmup = False
            optimizer.zero_grad()

            # must handle two data-loader queues...
            try:
                sample = next(iterator)
            except StopIteration:
                iterator = iter(train_loader)
                sample = next(iterator)
            try:
                sample_ev = next(iterator_ev)
            except StopIteration:
                iterator_ev = iter(train_loader_ev)
                sample_ev = next(iterator_ev)

            # merge both batches into a single dict; any key present in both
            # loaders ends up with the value from the evaluation batch
            for k, v in sample.items():
                sample[k] = torch.tensor(v).float().to(device)
            for k, v in sample_ev.items():
                sample[k] = torch.tensor(v).float().to(device)
            sample_ev = None

            # run model
            loss, summaries = model.loss(sample, iw)

            loss.backward()
            optimizer.step()
            scheduler.step()

            # Feel free to modify the progress bar
            pbar.set_postfix(loss='{:.2e}'.format(loss))
            pbar.update(1)

            # Log summaries
            if i % 50 == 0:
                ut.log_summaries(writer, summaries, i)

            # Save model
            if i % iter_save == 0:
                ut.save_model_by_name(model, i)
                # print(optimizer.param_groups[0]['lr'])
                # print("warmup", model.warmup)
                print("\n", [(key, v.item()) for key, v in summaries.items()])

            if i == iter_max:
                return
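The try/except StopIteration blocks above restart each DataLoader once it is exhausted. The same behaviour can be factored into a small generator; this is a sketch of an alternative, not part of the original code.

def cycle(loader):
    # Yield batches forever, re-entering the DataLoader whenever it is
    # exhausted so that shuffling still happens once per pass.
    while True:
        for batch in loader:
            yield batch

# usage: iterator = cycle(train_loader); sample = next(iterator)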
Example #5
def train2(model,
           train_loader,
           val_set,
           tqdm,
           lr,
           lr_gamma,
           lr_milestone_every,
           iw,
           num_epochs,
           reinitialize=False,
           is_car_model=False):
    assert isinstance(model, VAE2)
    # Optimization
    if reinitialize:
        model.apply(ut.reset_weights)

    num_batches_per_epoch = len(
        train_loader.dataset) // train_loader.batch_size
    lr_milestones = [(1 + lr_milestone_every * i) * num_batches_per_epoch
                     for i in range(num_epochs // lr_milestone_every + 1)]
    print("len(train_loader.dataset)", len(train_loader.dataset),
          "num_batches_per_epoch", num_batches_per_epoch)
    print("lr_milestones", lr_milestones, "lr",
          [lr * lr_gamma**i for i in range(len(lr_milestones))])

    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=lr_milestones,
                                                     gamma=lr_gamma)

    random.seed(1234)  # Initialize the random seed
    # model.warmup = True
    # print("warmup", model.warmup)
    i = 0
    epoch = 0
    print_every = 100
    summaries = OrderedDict({
        'epoch': 0,
        'loss': 0,
        'kl_z': 0,
        'rec_mse': 0,
        'rec_var': 0,
        'loss_type': iw,
        'lr': optimizer.param_groups[0]['lr'],
        'varpen': model.var_pen,
    })

    with tqdm(total=num_batches_per_epoch * num_epochs) as pbar:
        while epoch < num_epochs:
            if epoch >= 1:
                # start learning variance
                model.warmup = False

            summaries['loss_type'] = iw
            summaries['lr'] = optimizer.param_groups[0]['lr']
            summaries['varpen'] = model.var_pen

            for batch_idx, sample in enumerate(train_loader):
                i += 1  # i is num of gradient steps taken by end of loop iteration
                optimizer.zero_grad()

                # run model
                loss, info = model.loss(
                    x=sample["other"] if not is_car_model else sample["car"],
                    meta=sample["meta"],
                    c=None if not is_car_model else sample["other"],
                    iw=iw)

                pbar.set_postfix(loss='{:.2e}'.format(loss))
                pbar.update(1)
                # Log summaries
                for key, value in info.items():
                    summaries[key] += value.item() / (1.0 * print_every)

                if i % print_every == 0:
                    summaries["epoch"] = epoch + (batch_idx *
                                                  1.0) / num_batches_per_epoch
                    ut.log_summaries(summaries,
                                     'train',
                                     model_name=model.name,
                                     verbose=True)
                    for key in ['loss', 'kl_z', 'rec_mse', 'rec_var']:
                        summaries[key] = 0.0
                loss.backward()
                optimizer.step()
                scheduler.step()

            # validate at each epoch end
            val_summaries = copy.deepcopy(summaries)
            for key in ['loss', 'kl_z', 'rec_mse', 'rec_var']:
                val_summaries[key] = 0.0
            val_summaries["epoch"] = epoch + 1
            ut.evaluate_lower_bound2(model,
                                     val_set,
                                     run_iwae=(iw >= 1),
                                     mode='val',
                                     verbose=False,
                                     summaries=val_summaries)

            epoch += 1
            if epoch % 10 == 0:  # save interim model
                ut.save_model_by_name(model, epoch)

        # save in the end
        ut.save_model_by_name(model, epoch)
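`ut.save_model_by_name(model, step)` is used by every example to checkpoint the model. A minimal sketch, assuming the model exposes a `name` attribute and checkpoints live under `checkpoints/<name>/` (both are assumptions):

import os
import torch

def save_model_by_name(model, global_step):
    # Save the model weights to checkpoints/<model.name>/model-<step>.pt.
    save_dir = os.path.join('checkpoints', model.name)
    os.makedirs(save_dir, exist_ok=True)
    file_path = os.path.join(save_dir, 'model-{:08d}.pt'.format(global_step))
    torch.save(model.state_dict(), file_path)
    print('Saved to {}'.format(file_path))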
Example #6
File: train.py Project: qqhann/VAEs
def train(model,
          train_loader,
          labeled_subset,
          device,
          tqdm,
          writer,
          iter_max=np.inf,
          iter_save=np.inf,
          model_name='model',
          y_status='none',
          reinitialize=False):
    # Optimization
    if reinitialize:
        model.apply(ut.reset_weights)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    i = 0
    #beta = ut.DeterministicWarmup(n=50, t_max = 1)
    beta = 0
    with tqdm(total=iter_max) as pbar:
        while True:
            for batch_idx, (xu, yu) in enumerate(train_loader):
                i += 1  # i is num of gradient steps taken by end of loop iteration
                optimizer.zero_grad()
                if i % 600 == 0:
                    beta += 0.02
                #print(beta)
                if y_status == 'none':
                    xu = torch.bernoulli(xu.to(device).reshape(xu.size(0), -1))
                    yu = yu.new(np.eye(10)[yu]).to(device).float()
                    loss, summaries = model.loss(xu, beta)
                    #print(beta)
                    #print(next(beta))
                    #loss, summaries = model.loss(xu)
                elif y_status == 'semisup':
                    xu = torch.bernoulli(xu.to(device).reshape(xu.size(0), -1))
                    yu = yu.new(np.eye(10)[yu]).to(device).float()
                    # xl and yl already preprocessed
                    xl, yl = labeled_subset
                    xl = torch.bernoulli(xl)
                    loss, summaries = model.loss(xu, xl, yl)

                    # Add training accuracy computation
                    pred = model.cls.classify(xu).argmax(1)
                    true = yu.argmax(1)
                    acc = (pred == true).float().mean()
                    summaries['class/acc'] = acc

                elif y_status == 'fullsup':
                    # Janky code: fullsup is only for SVHN
                    # xu is not bernoulli for SVHN
                    xu = xu.to(device).reshape(xu.size(0), -1)
                    yu = yu.new(np.eye(10)[yu]).to(device).float()
                    loss, summaries = model.loss(xu, yu)

                loss.backward()
                optimizer.step()

                # Feel free to modify the progress bar
                if y_status == 'none':
                    pbar.set_postfix(loss='{:.2e}'.format(loss),
                                     kl='{:.2f}'.format(summaries['gen/kl_z']),
                                     rec='{:.2g}'.format(summaries['gen/rec']))
                elif y_status == 'semisup':
                    pbar.set_postfix(loss='{:.2e}'.format(loss),
                                     acc='{:.2e}'.format(acc))
                elif y_status == 'fullsup':
                    pbar.set_postfix(loss='{:.2e}'.format(loss),
                                     kl='{:.2e}'.format(summaries['gen/kl_z']))
                pbar.update(1)

                # Log summaries
                if i % 50 == 0: ut.log_summaries(writer, summaries, i)

                # Save model
                if i % iter_save == 0:
                    ut.save_model_by_name(model, i)

                if i == iter_max:
                    return
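The commented-out line near the top of this example refers to `ut.DeterministicWarmup(n=50, t_max=1)`. A common implementation is an iterator that linearly anneals a coefficient from 0 to `t_max` over `n` calls to `next()`; this sketch is an assumption, and the project's own version may differ.

class DeterministicWarmup:
    # Linear warm-up schedule: each call to next() increases the coefficient
    # by t_max / n until it saturates at t_max.
    def __init__(self, n=100, t_max=1.0):
        self.t = 0.0
        self.t_max = t_max
        self.inc = t_max / n

    def __iter__(self):
        return self

    def __next__(self):
        self.t = min(self.t + self.inc, self.t_max)
        return self.t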
def refine(train_loader_set,
           mean_set,
           variance_set,
           z_dim,
           device,
           tqdm,
           writer,
           iter_max=np.inf,
           iter_save=np.inf,
           model_name='model',
           y_status='none',
           reinitialize=False):
    # Optimization

    i = 0
    with tqdm(total=iter_max) as pbar:
        while True:
            for index, train_loader in enumerate(train_loader_set):
                print("Iteration:", i)
                print("index: ", index)

                z_prior_m = torch.nn.Parameter(mean_set[index].cpu(),
                                               requires_grad=False).to(device)
                z_prior_v = torch.nn.Parameter(variance_set[index].cpu(),
                                               requires_grad=False).to(device)
                vae = VAE(z_dim=z_dim,
                          name=model_name,
                          z_prior_m=z_prior_m,
                          z_prior_v=z_prior_v).to(device)
                optimizer = optim.Adam(vae.parameters(), lr=1e-3)
                if i == 0:
                    print("Load model")
                    ut.load_model_by_name(vae, global_step=20000)
                else:
                    print("Load model")
                    ut.load_model_by_name(vae, global_step=iter_save)
                for batch_idx, (xu, yu) in enumerate(train_loader):
                    # i is num of gradient steps taken by end of loop iteration
                    optimizer.zero_grad()

                    xu = torch.bernoulli(xu.to(device).reshape(xu.size(0), -1))
                    yu = yu.new(np.eye(10)[yu]).to(device).float()
                    loss, summaries = vae.loss_encoder(xu)

                    loss.backward()
                    optimizer.step()

                    # Feel free to modify the progress bar

                    pbar.set_postfix(loss='{:.2e}'.format(loss))

                    pbar.update(1)

                    i += 1
                    # Log summaries
                    if i % 50 == 0: ut.log_summaries(writer, summaries, i)

                    if i == iter_max:
                        ut.save_model_by_name(vae, 0)
                        return

                # Save model
                ut.save_model_by_name(vae, iter_save)
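`ut.load_model_by_name` is the counterpart of the save helper used in refine() above; a sketch matching the checkpoint layout assumed earlier (path format and directory are assumptions):

import os
import torch

def load_model_by_name(model, global_step):
    # Load the weights previously written by save_model_by_name for this step.
    file_path = os.path.join('checkpoints', model.name,
                             'model-{:08d}.pt'.format(global_step))
    model.load_state_dict(torch.load(file_path))
    print('Loaded from {}'.format(file_path))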