Example #1
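All of the snippets below assume the same preamble. A minimal sketch of the likely imports; `utils` (spelled `util` in some snippets) is a project-local helper module providing AverageMeter, bits_per_dim, and clip_grad_norm, not a PyPI package:

# Assumed preamble for the examples on this page (an editorial sketch).
import os

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from tqdm import tqdm

import utils  # project-local helpers; hypothetical module name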
def train(epoch, net, trainloader, device, optimizer, loss_fn, max_grad_norm,
          writer):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = utils.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, _ in trainloader:
            x = x.to(device)
            optimizer.zero_grad()
            z = net(x)
            sldj = net.module.logdet()
            loss = loss_fn(z, sldj=sldj)
            loss_meter.update(loss.item(), x.size(0))

            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            progress_bar.set_postfix(loss=loss_meter.avg,
                                     bpd=utils.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))
    writer.add_scalar("train/loss", loss_meter.avg, epoch)
    writer.add_scalar("train/bpd", utils.bits_per_dim(x, loss_meter.avg),
                      epoch)

def test(epoch, net, testloader, device, loss_fn, writer, tb_name="test"):
    net.eval()
    loss_meter = utils.AverageMeter()
    loss_list = []
    with torch.no_grad():
        with tqdm(total=len(testloader.dataset)) as progress_bar:
            for x, _ in testloader:
                x = x.to(device)
                z = net(x)
                sldj = net.module.logdet()
                losses = loss_fn(z, sldj=sldj, mean=False)
                loss_list.extend([loss.item() for loss in losses])

                loss = losses.mean()
                loss_meter.update(loss.item(), x.size(0))

                progress_bar.set_postfix(loss=loss_meter.avg,
                                         bpd=utils.bits_per_dim(
                                             x, loss_meter.avg))
                progress_bar.update(x.size(0))

    likelihoods = -torch.from_numpy(np.array(loss_list)).float()
    if writer is not None:
        writer.add_scalar("{}/loss".format(tb_name), loss_meter.avg, epoch)
        writer.add_scalar("{}/bpd".format(tb_name),
                          utils.bits_per_dim(x, loss_meter.avg), epoch)
        writer.add_histogram('{}/likelihoods'.format(tb_name), likelihoods,
                             epoch)
    return likelihoods
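Every example leans on `utils.AverageMeter` and `utils.bits_per_dim`, neither of which is shown on this page. A plausible sketch of both helpers, assuming the loss is a mean negative log-likelihood in nats per example:

class AverageMeter:
    """Tracks a running average of a scalar, weighted by batch size."""
    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def bits_per_dim(x, nll):
    """Convert a per-example NLL in nats to bits per dimension.

    Assumes nll is the average negative log-likelihood of a batch
    shaped like x, i.e. (N, C, H, W): bpd = nll / (log 2 * C*H*W).
    """
    dim = np.prod(x.shape[1:])  # dimensions per example
    return nll / (np.log(2.0) * dim)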
Example #3
def train(epoch, net, trainloader, device, optimizer, loss_fn, max_grad_norm,
          writer):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = utils.AverageMeter()
    loss_unsup_meter = utils.AverageMeter()
    loss_nll_meter = utils.AverageMeter()
    acc_meter = utils.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, y in trainloader:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            z = net(x)
            sldj = net.module.logdet()

            logits = loss_fn.prior.class_logits(z.reshape((len(z), -1)))
            loss_nll = F.cross_entropy(logits, y)

            loss_unsup = loss_fn(z, sldj=sldj)
            loss = loss_nll + loss_unsup

            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            preds = torch.argmax(logits, dim=1)
            acc = (preds == y).float().mean().item()

            acc_meter.update(acc, x.size(0))
            loss_meter.update(loss.item(), x.size(0))
            loss_unsup_meter.update(loss_unsup.item(), x.size(0))
            loss_nll_meter.update(loss_nll.item(), x.size(0))

            progress_bar.set_postfix(loss=loss_meter.avg,
                                     bpd=utils.bits_per_dim(
                                         x, loss_unsup_meter.avg),
                                     acc=acc_meter.avg)
            progress_bar.update(x.size(0))
    x_img = torchvision.utils.make_grid(x[:10],
                                        nrow=2,
                                        padding=2,
                                        pad_value=255)
    writer.add_image("data/x", x_img)
    writer.add_scalar("train/loss", loss_meter.avg, epoch)
    writer.add_scalar("train/loss_unsup", loss_unsup_meter.avg, epoch)
    writer.add_scalar("train/loss_nll", loss_nll_meter.avg, epoch)
    writer.add_scalar("train/acc", acc_meter.avg, epoch)
    writer.add_scalar("train/bpd", utils.bits_per_dim(x, loss_unsup_meter.avg),
                      epoch)
Example #4
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn,
          max_grad_norm, conditional):
    global global_step
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = util.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x in trainloader:
            optimizer.zero_grad()
            if conditional:
                x, x2 = x
                x = x.to(device)
                x2 = x2.to(device)
                z, sldj = net(x, x2, reverse=False)
            else:
                x = x.to(device)
                z, sldj = net(x, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            scheduler.step(global_step)

            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg),
                                     lr=optimizer.param_groups[0]['lr'])
            progress_bar.update(x.size(0))
            global_step += x.size(0)
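`utils.clip_grad_norm` / `util.clip_grad_norm` appears in every training loop above but is not shown either. A sketch of what such a helper most likely does, clipping the global gradient norm over every parameter the optimizer updates:

def clip_grad_norm(optimizer, max_norm, norm_type=2.0):
    # Gather all parameters from the optimizer's param groups and clip
    # their combined gradient norm; a thin wrapper around PyTorch's
    # torch.nn.utils.clip_grad_norm_ (the wrapper itself is an assumption).
    params = [p for group in optimizer.param_groups for p in group['params']]
    torch.nn.utils.clip_grad_norm_(params, max_norm, norm_type)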

Example #5
def train(epoch,
          net,
          trainloader,
          device,
          optimizer,
          loss_fn,
          max_grad_norm,
          writer,
          num_samples=10,
          sampling=True,
          tb_freq=100):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = utils.AverageMeter()
    iter_count = 0
    batch_count = 0
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, _ in trainloader:
            iter_count += 1
            batch_count += x.size(0)
            x = x.to(device)
            optimizer.zero_grad()
            z = net(x)
            sldj = net.module.logdet()
            loss = loss_fn(z, sldj=sldj)
            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            loss_meter.update(loss.item(), x.size(0))
            progress_bar.set_postfix(loss=loss_meter.avg,
                                     bpd=utils.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))

            if iter_count % tb_freq == 0 or batch_count == len(
                    trainloader.dataset):
                tb_step = epoch * (len(trainloader.dataset)) + batch_count
                writer.add_scalar("train/loss", loss_meter.avg, tb_step)
                writer.add_scalar("train/bpd",
                                  utils.bits_per_dim(x, loss_meter.avg),
                                  tb_step)
                if sampling:
                    net.eval()
                    draw_samples(net, writer, loss_fn, num_samples, device,
                                 tuple(x[0].shape), tb_step)
                    net.train()
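`draw_samples` (called here and in the later examples) is also not shown. A minimal sketch under two assumptions about the project's interfaces: that `loss_fn.prior` exposes a `sample` method returning flattened latents, and that `net.module.inverse` maps latents back to image space:

@torch.no_grad()
def draw_samples(net, writer, loss_fn, num_samples, device, img_shape, tb_step):
    # Sample latents from the prior and run the flow's inverse to get
    # images; both interface calls below are assumptions, not confirmed API.
    z = loss_fn.prior.sample((num_samples,)).to(device)
    x = net.module.inverse(z.reshape(num_samples, *img_shape))
    grid = torchvision.utils.make_grid(x, nrow=2, padding=2, pad_value=255)
    writer.add_image("samples/x", grid, tb_step)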
Example #6
def test(epoch, net, testloader, device, loss_fn, num_samples, writer):
    # note: num_samples is accepted but unused in this snippet
    net.eval()
    loss_meter = utils.AverageMeter()
    with torch.no_grad():
        with tqdm(total=len(testloader.dataset)) as progress_bar:
            for x, _ in testloader:
                x = x.to(device)
                z = net(x)
                sldj = net.module.logdet()
                loss = loss_fn(z, sldj=sldj)
                loss_meter.update(loss.item(), x.size(0))

                progress_bar.set_postfix(loss=loss_meter.avg,
                                         bpd=utils.bits_per_dim(
                                             x, loss_meter.avg))
                progress_bar.update(x.size(0))
    if writer is not None:
        writer.add_scalar("test/loss", loss_meter.avg, epoch)
        writer.add_scalar("test/bpd", utils.bits_per_dim(x, loss_meter.avg),
                          epoch)
Example #7
def test(net, testloader, device, loss_fn):
    net.eval()
    ys_lst = []
    preds_lst = []
    probs_lst = []
    probs_x_lst = []
    total_loss = 0.
    with torch.no_grad():
        with tqdm(total=len(testloader.dataset)) as progress_bar:
            for x, y in testloader:
                x = x.to(device)
                y = y.cpu().numpy().reshape((-1, 1))  # labels stay on CPU; no GPU round-trip needed
                z, sldj = net(x, reverse=False)

                loss = loss_fn(z, sldj=sldj)
                total_loss += loss.item() * x.size(0)  # accumulate as a Python float
                z = z.reshape((len(z), -1))

                probs_x = loss_fn.prior.log_prob(z) + sldj
                probs_x = probs_x.cpu().numpy()
                probs_x = probs_x.reshape((-1, 1))

                probs = loss_fn.prior.class_probs(z).cpu().numpy()
                preds = np.argmax(probs, axis=1)
                preds = preds.reshape((-1, 1))

                ys_lst.append(y)
                preds_lst.append(preds)
                probs_lst.append(probs)
                probs_x_lst.append(probs_x)

                progress_bar.update(x.size(0))

    ys = np.vstack(ys_lst)
    probs = np.vstack(probs_lst)
    probs_x = np.vstack(probs_x_lst)
    preds = np.vstack(preds_lst)
    loss = total_loss / len(ys)
    acc = (ys == preds).mean()
    bpd = utils.bits_per_dim(x, loss)

    return acc, bpd, loss, ys, probs, probs_x, preds
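Examples #3 and #7 assume the prior is class-conditional: `class_logits` scores a latent under each class component, `class_probs` normalizes those scores, and `log_prob` gives the marginal density. A sketch of such a prior as an equally weighted Gaussian mixture with one unit-variance component per class (an assumption about the actual implementation):

import math

class GaussianMixturePrior:
    """One unit-variance Gaussian component per class (a sketch)."""
    def __init__(self, means):
        self.means = means  # (num_classes, latent_dim)

    def class_logits(self, z):
        # Unnormalized log-density of each latent under each component;
        # the shared constants cancel inside softmax / cross_entropy.
        diffs = z.unsqueeze(1) - self.means.unsqueeze(0)  # (B, K, D)
        return -0.5 * (diffs ** 2).sum(dim=-1)            # (B, K)

    def class_probs(self, z):
        return F.softmax(self.class_logits(z), dim=1)

    def log_prob(self, z):
        # Equal-weight mixture:
        # log p(z) = logsumexp_k log N(z; mu_k, I) - log K
        k, d = self.means.size(0), z.size(1)
        const = -0.5 * d * math.log(2 * math.pi)
        return (torch.logsumexp(self.class_logits(z), dim=1)
                + const - math.log(k))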
Example #8
def test(epoch, net, testloader, device, loss_fn, num_samples, conditional,
         name):
    global best_loss
    net.eval()
    loss_meter = util.AverageMeter()
    with tqdm(total=len(testloader.dataset)) as progress_bar:
        for x in testloader:
            if conditional:
                # Rebind x to the image tensor so x.size(0) and
                # bits_per_dim(x, ...) below see a tensor, not the tuple.
                x, x2 = x
                x = x.to(device)
                x2 = x2.to(device)
                z, sldj = net(x, x2, reverse=False)
            else:
                x = x.to(device)
                z, sldj = net(x, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))

    # Save checkpoint
    if loss_meter.avg < best_loss:
        print('Saving...')
        state = {
            'net': net.state_dict(),
            'test_loss': loss_meter.avg,
            'epoch': epoch,
        }
        os.makedirs('ckpts', exist_ok=True)
        torch.save(state, 'ckpts/{}_best.pth.tar'.format(name))
        best_loss = loss_meter.avg

    # Save samples and data
    images = sample(net, num_samples, device)
    os.makedirs('samples', exist_ok=True)
    images_concat = torchvision.utils.make_grid(images,
                                                nrow=int(num_samples**0.5),
                                                padding=2,
                                                pad_value=255)
    torchvision.utils.save_image(images_concat,
                                 'samples/{}_epoch_{}.png'.format(name, epoch))
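The `sample` helper used above is the usual flow sampling routine. A minimal sketch, assuming the network's reverse pass maps standard-normal latents to images; the (3, 32, 32) shape and the final sigmoid are assumptions about the data and output parameterization:

@torch.no_grad()
def sample(net, num_samples, device):
    # Standard-normal latents pushed through the reverse pass of the flow.
    z = torch.randn(num_samples, 3, 32, 32, device=device)
    x, _ = net(z, reverse=True)
    return torch.sigmoid(x)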
Example #9
def train(epoch,
          net,
          trainloader,
          device,
          optimizer,
          loss_fn,
          max_grad_norm,
          writer,
          num_samples=10,
          sampling=True,
          tb_freq=100):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = utils.AverageMeter()
    loss_unsup_meter = utils.AverageMeter()
    loss_reconstr_meter = utils.AverageMeter()
    kl_loss_meter = utils.AverageMeter()
    iter_count = 0
    batch_count = 0
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, _ in trainloader:
            iter_count += 1
            batch_count += x.size(0)
            x = x.to(device)
            optimizer.zero_grad()
            z = net(x)
            sldj = net.module.logdet()
            loss_unsup = loss_fn(z, sldj=sldj)

            # if vae_loss:
            #     logvar_z = -logvar_net(z)
            #     z_perturbed = z + torch.randn_like(z) * torch.exp(0.5 * logvar_z)
            #     x_reconstr = net.module.inverse(z_perturbed)
            #     if decoder_likelihood == 'binary_ce':
            #         loss_reconstr = F.binary_cross_entropy(x_reconstr, x, reduction='sum') / x.size(0)
            #     else:
            #         loss_reconstr = F.mse_loss(x_reconstr, x, reduction='sum') / x.size(0)
            #     kl_loss = -0.5 * (logvar_z - logvar_z.exp()).sum(dim=[1])
            #     kl_loss = kl_loss.mean()
            #     loss = loss_unsup + loss_reconstr * reconstr_weight + kl_loss * reconstr_weight
            # else: reconstruction term disabled; zero placeholders keep the
            # meter updates and TensorBoard logging below working.
            logvar_z = torch.tensor([0.])
            loss_reconstr = torch.tensor([0.])
            kl_loss = torch.tensor([0.])
            loss = loss_unsup

            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            loss_unsup_meter.update(loss_unsup.item(), x.size(0))
            loss_reconstr_meter.update(loss_reconstr.item(), x.size(0))
            kl_loss_meter.update(kl_loss.item(), x.size(0))
            loss_meter.update(loss.item(), x.size(0))

            progress_bar.set_postfix(loss=loss_meter.avg,
                                     bpd=utils.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))

            if iter_count % tb_freq == 0 or batch_count == len(
                    trainloader.dataset):
                tb_step = epoch * (len(trainloader.dataset)) + batch_count
                writer.add_scalar("train/loss", loss_meter.avg, tb_step)
                writer.add_scalar("train/loss_unsup", loss_unsup_meter.avg,
                                  tb_step)
                writer.add_scalar("train/loss_reconstr",
                                  loss_reconstr_meter.avg, tb_step)
                writer.add_scalar("train/kl_loss", kl_loss_meter.avg, tb_step)
                writer.add_scalar("train/bpd",
                                  utils.bits_per_dim(x, loss_unsup_meter.avg),
                                  tb_step)
                writer.add_histogram('train/logvar_z', logvar_z, tb_step)
                if sampling:
                    net.eval()
                    draw_samples(net, writer, loss_fn, num_samples, device,
                                 tuple(x[0].shape), tb_step)
                    net.train()

def train(
    epoch,
    net,
    trainloader,
    device,
    optimizer,
    loss_fn,
    label_weight,
    max_grad_norm,
    writer,
    use_unlab=True,
):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = utils.AverageMeter()
    loss_unsup_meter = utils.AverageMeter()
    loss_nll_meter = utils.AverageMeter()
    jaclogdet_meter = utils.AverageMeter()
    acc_meter = utils.AverageMeter()
    with tqdm(total=trainloader.batch_sampler.num_labeled) as progress_bar:
        for x1, y in trainloader:
            x1 = x1.to(device)
            y = y.to(device)

            labeled_mask = (y != NO_LABEL)

            optimizer.zero_grad()

            z1 = net(x1)
            sldj = net.module.logdet()

            z_labeled = z1.reshape((len(z1), -1))
            z_labeled = z_labeled[labeled_mask]
            y_labeled = y[labeled_mask]

            logits_labeled = loss_fn.prior.class_logits(z_labeled)
            loss_nll = F.cross_entropy(logits_labeled, y_labeled)

            if use_unlab:
                loss_unsup = loss_fn(z1, sldj=sldj)
                loss = loss_nll * label_weight + loss_unsup
            else:
                loss_unsup = torch.tensor([0.])
                loss = loss_nll

            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            preds = torch.argmax(logits_labeled, dim=1)
            acc = (preds == y_labeled).float().mean().item()

            acc_meter.update(acc, x1.size(0))
            loss_meter.update(loss.item(), x1.size(0))
            loss_unsup_meter.update(loss_unsup.item(), x1.size(0))
            loss_nll_meter.update(loss_nll.item(), x1.size(0))
            jaclogdet_meter.update(sldj.mean().item(), x1.size(0))

            progress_bar.set_postfix(loss=loss_meter.avg,
                                     bpd=utils.bits_per_dim(
                                         x1, loss_unsup_meter.avg),
                                     acc=acc_meter.avg)
            progress_bar.update(y_labeled.size(0))

    writer.add_scalar("train/loss", loss_meter.avg, epoch)
    writer.add_scalar("train/loss_unsup", loss_unsup_meter.avg, epoch)
    writer.add_scalar("train/loss_nll", loss_nll_meter.avg, epoch)
    writer.add_scalar("train/jaclogdet", jaclogdet_meter.avg, epoch)
    writer.add_scalar("train/acc", acc_meter.avg, epoch)
    writer.add_scalar("train/bpd",
                      utils.bits_per_dim(x1, loss_unsup_meter.avg), epoch)
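This semi-supervised loop relies on a `NO_LABEL` sentinel and on a custom batch sampler (`trainloader.batch_sampler.num_labeled`), neither of which appears on this page. A sketch of the label-masking side only, under the assumption that unlabeled points carry a sentinel label; the batch sampler itself is not reconstructed here:

NO_LABEL = -1  # sentinel assumed by labeled_mask above (an assumption)

class SemiSupervisedDataset(torch.utils.data.Dataset):
    """Wraps a labeled dataset and masks labels outside a kept subset,
    so batches mix labeled and unlabeled points as the loop above expects."""
    def __init__(self, base, labeled_indices):
        self.base = base
        self.labeled = set(labeled_indices)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        x, y = self.base[i]
        return x, (y if i in self.labeled else NO_LABEL)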

def train(epoch, net, trainloader, ood_loader, device, optimizer, loss_fn,
          max_grad_norm, writer, negative_val=-1e5, num_samples=10,
          tb_freq=100):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = utils.AverageMeter()
    loss_positive_meter = utils.AverageMeter()
    loss_negative_meter = utils.AverageMeter()
    iter_count = 0
    batch_count = 0
    pooler = MedianPool2d(7, padding=3)  # instantiated but never used in this snippet
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for (x, _), (x_transposed, _) in zip(trainloader, ood_loader):

            bs = x.shape[0]
            x = torch.cat((x, x_transposed), dim=0)
            iter_count += 1
            batch_count += bs
            x = x.to(device)
            optimizer.zero_grad()
            z = net(x)
            sldj = net.module.logdet()
            loss = loss_fn(z, sldj=sldj, mean=False)
            loss[bs:] *= (-1)
            loss_positive = loss[:bs]
            loss_negative = loss[bs:]
            if (loss_negative > negative_val).sum() > 0:
                loss_negative = loss_negative[loss_negative > negative_val]
                loss_negative = loss_negative.mean()
                loss_positive = loss_positive.mean()
                loss = 0.5 * (loss_positive + loss_negative)
            else:
                loss_negative = torch.tensor(0.)
                loss_positive = loss_positive.mean()
                loss = loss_positive
            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            loss_meter.update(loss.item(), bs)
            loss_positive_meter.update(loss_positive.item(), bs)
            loss_negative_meter.update(loss_negative.item(), bs)
            progress_bar.set_postfix(
                pos_bpd=utils.bits_per_dim(x[:bs], loss_positive_meter.avg),
                neg_bpd=utils.bits_per_dim(x[bs:], -loss_negative_meter.avg),
                neg_loss=loss_negative.mean().item())
            progress_bar.update(bs)

            if iter_count % tb_freq == 0 or batch_count == len(
                    trainloader.dataset):
                tb_step = epoch * (len(trainloader.dataset)) + batch_count
                writer.add_scalar("train/loss", loss_meter.avg, tb_step)
                writer.add_scalar("train/loss_positive",
                                  loss_positive_meter.avg, tb_step)
                writer.add_scalar("train/loss_negative",
                                  loss_negative_meter.avg, tb_step)
                writer.add_scalar("train/bpd_positive",
                                  utils.bits_per_dim(x[:bs],
                                                     loss_positive_meter.avg),
                                  tb_step)
                writer.add_scalar("train/bpd_negative",
                                  utils.bits_per_dim(x[bs:],
                                                     -loss_negative_meter.avg),
                                  tb_step)
                x1_img = torchvision.utils.make_grid(x[:10], nrow=2, padding=2,
                                                     pad_value=255)
                x2_img = torchvision.utils.make_grid(x[-10:], nrow=2,
                                                     padding=2, pad_value=255)
                writer.add_image("data/x", x1_img)
                writer.add_image("data/x_transposed", x2_img)
                net.eval()
                draw_samples(net, writer, loss_fn, num_samples, device,
                             tuple(x[0].shape), tb_step)
                net.train()
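Finally, a hedged sketch of how these train/test pairs are typically driven from a main loop. The DataParallel wrap is what makes the `net.module.logdet()` calls above work; `Flow`, `FlowLoss`, `trainloader`, and `testloader` are placeholders, not names from the source:

from torch.utils.tensorboard import SummaryWriter

def main(num_epochs=100, lr=1e-3, max_grad_norm=100.0):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # DataParallel exposes the wrapped model as net.module, matching the
    # net.module.logdet() calls in the examples above.
    net = torch.nn.DataParallel(Flow().to(device))  # Flow: placeholder model
    loss_fn = FlowLoss()                            # placeholder NLL loss
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    writer = SummaryWriter(log_dir='runs/flow')
    for epoch in range(num_epochs):
        # trainloader / testloader stand in for the usual DataLoaders.
        train(epoch, net, trainloader, device, optimizer, loss_fn,
              max_grad_norm, writer)
        test(epoch, net, testloader, device, loss_fn, writer)
    writer.close()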