Example #1
0
def eval(model, embedder, test_loader):
    """Draw a sample image grid from the model; returns a placeholder bpd.

    NOTE(review): the original body hard-returned 0.0 right after sampling
    ("just testing sampling"), which made the entire validation
    bits-per-dim loop below it unreachable dead code.  The dead code has
    been removed; restore the bpd computation once sampling-only
    debugging is finished.

    Args:
        model: generative model exposing ``eval()``; only sampling is
            exercised here.
        embedder: callable mapping (labels, captions) to a conditioning
            embedding, forwarded to ``sample_image``.
        test_loader: dataloader handed to ``sample_image``.

    Returns:
        float: always 0.0 (placeholder while only sampling is tested).

    Relies on module-level globals ``opt``, ``device`` and the
    ``sample_image`` helper defined elsewhere in this file.
    """
    print("sampling_images")
    model = model.eval()
    sample_image(model,
                 embedder,
                 opt.output_dir,
                 n_row=4,
                 batches_done=0,
                 dataloader=test_loader,
                 device=device)
    return 0.0
Example #2
0
def train(model, embedder, optimizer, scheduler, train_loader, val_loader,
          opt):
    """Main training loop: per-batch optimization with bpd logging,
    periodic sampling/checkpointing, and per-epoch validation.

    Args:
        model: model whose ``forward(imgs, cond)`` returns a dict
            containing a 'loss' tensor.
        embedder: callable mapping (labels, captions) to a conditioning
            embedding; run under ``no_grad`` (treated as frozen).
        optimizer: stepped once per batch.
        scheduler: LR scheduler — NOTE(review): stepped only ONCE, after
            all epochs complete (last line); likely intended to sit
            inside the epoch loop — confirm.
        train_loader / val_loader: dataloaders yielding
            (imgs, labels, captions) tuples.
        opt: namespace providing n_epochs, print_every, sample_interval,
            output_dir.

    Relies on module-level globals: ``device``, ``writer``,
    ``sample_image``, ``eval``.
    """
    print("TRAINING STARTS")
    for epoch in range(opt.n_epochs):
        model = model.train()
        loss_to_log = 0.0
        for i, (imgs, labels, captions) in enumerate(train_loader):
            start_batch = time.time()
            imgs = imgs.to(device)
            labels = labels.to(device)

            # Conditioning embedding is computed without gradients.
            with torch.no_grad():
                condition_embd = embedder(labels, captions)

            optimizer.zero_grad()
            outputs = model.forward(imgs.float(), condition_embd.float())
            loss = outputs['loss'].mean()
            loss.backward()
            optimizer.step()
            batches_done = epoch * len(train_loader) + i
            # bits-per-dim = nats / ln(2)
            writer.add_scalar('train/bpd', loss / np.log(2), batches_done)
            loss_to_log += loss.item()
            if (i + 1) % opt.print_every == 0:
                # Window-averaged loss, converted to bits.
                loss_to_log = loss_to_log / (np.log(2) * opt.print_every)
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [bpd: %f] [Time/batch %.3f]" %
                    (epoch + 1, opt.n_epochs, i + 1, len(train_loader),
                     loss_to_log, time.time() - start_batch))
                # Only sample once bpd drops below 3.95 — the threshold
                # appears empirical; confirm its origin.
                if (loss_to_log < 3.95):
                    sample_image(model,
                                 embedder,
                                 opt.output_dir,
                                 n_row=4,
                                 batches_done=batches_done,
                                 dataloader=val_loader,
                                 device=device)
                loss_to_log = 0.0
                # NOTE(review): checkpoint name is keyed on batch index
                # only, so files are overwritten across epochs.
                torch.save(
                    model.state_dict(),
                    os.path.join(opt.output_dir, 'models',
                                 '2batch_{}.pt'.format(i)))
            if (batches_done + 1) % opt.sample_interval == 0:
                print("sampling_images")
                # NOTE(review): model is switched to eval mode here and
                # never switched back to train() — confirm intended.
                model = model.eval()
                #sample_image(model, embedder, opt.output_dir, n_row=4,
                #            batches_done=batches_done,
                #           dataloader=val_loader, device=device)

            print('saved', i, '/', len(train_loader))
        val_bpd = eval(model, embedder, val_loader)
        writer.add_scalar("val/bpd", val_bpd, (epoch + 1) * len(train_loader))

        torch.save(
            model.state_dict(),
            os.path.join(opt.output_dir, 'models',
                         'epoch_{}.pt'.format(epoch)))

    # NOTE(review): executes once after ALL epochs, not per epoch.
    scheduler.step()
Example #3
0
def sample(model, embedder, batches_done, val_loader, device):
    """Put the model in eval mode and write one grid of sampled images.

    Thin wrapper around the module-level ``sample_image`` helper; grid
    size and output directory come from the module-level ``opt``.
    """
    print("sampling_images")
    model = model.eval()
    grid_kwargs = dict(n_row=opt.n_row,
                       batches_done=batches_done,
                       dataloader=val_loader,
                       device=device)
    sample_image(model, embedder, opt.output_dir, **grid_kwargs)
Example #4
0
def eval(model, embedder, test_loader):
    """Sample an image grid from the model; bpd evaluation is stubbed out.

    Bug fix: the original printed and returned ``bpd`` without ever
    binding it — the loop that computed it was entirely commented out —
    so every call raised NameError.  ``bpd`` is now initialised to 0.0
    as an explicit placeholder.

    Args:
        model: generative model exposing ``eval()``.
        embedder: conditioning embedder forwarded to ``sample_image``.
        test_loader: dataloader handed to ``sample_image``.

    Returns:
        float: 0.0 placeholder.  TODO(review): reinstate the
        bits-per-dim loop over ``test_loader`` to report a real value.
    """
    print("EVALUATING ON VAL")
    model = model.eval()
    sample_image(model,
                 embedder,
                 'outputs/pixelcnn',
                 n_row=4,
                 batches_done=1,
                 dataloader=test_loader,
                 device=torch.device('cuda'))
    bpd = 0.0  # placeholder — real computation removed (was commented out)
    print("VAL bpd : {}".format(bpd))
    return bpd
def train(model_G, model_D, embedder, optimizer_G, optimizer_D, scheduler_G,
          scheduler_D, train_loader, val_loader, adv_loss, opt,
          onehot_encoder):
    """Conditional GAN training loop with label smoothing/noise on the
    discriminator targets and two generator updates per batch.

    Bug fix: the discriminator losses called an undefined/global name
    ``adversarial_loss`` while the loss function is passed in as the
    ``adv_loss`` parameter (used for the generator at both generator
    steps); they now consistently use ``adv_loss``.

    Args:
        model_G / model_D: generator and discriminator modules.
        embedder: (labels, captions) -> embedding, run under no_grad.
        optimizer_G / optimizer_D: per-network optimizers.
        scheduler_G / scheduler_D: LR schedulers — NOTE(review): stepped
            only once after all epochs (see last lines); confirm intent.
        train_loader / val_loader: yield (imgs, labels, captions).
        adv_loss: adversarial criterion (e.g. BCE) for both networks.
        opt: namespace with n_epochs, batch_size, z_dim, print_every,
            sample_interval, output_dir.
        onehot_encoder: (labels, captions) -> one-hot conditioning for D.

    Relies on module-level globals ``device``, ``writer``,
    ``sample_image``.

    NOTE(review): ``valid``/``fake``/``d_noise`` are sized with
    ``opt.batch_size``; a final partial batch from the dataloader will
    mismatch — confirm drop_last=True or handle variable batch sizes.
    """
    print("TRAINING STARTS")
    for epoch in range(opt.n_epochs):
        model_G = model_G.train()
        model_D = model_D.train()
        loss_G_to_log = 0.0
        loss_D_to_log = 0.0
        valid = torch.FloatTensor(opt.batch_size, 1).fill_(1.0).to(device)
        fake = torch.FloatTensor(opt.batch_size, 1).fill_(0.0).to(device)
        for i, (imgs, labels, captions) in enumerate(train_loader):
            start_batch = time.time()
            imgs = imgs.to(device)
            labels = labels.to(device)
            # Label-noise schedule: soften D targets early in training,
            # with a fraction of each batch kept at hard labels.
            if epoch < 10:
                d_noise = 0.2 * torch.FloatTensor(opt.batch_size,
                                                  1).fill_(1.0).to(device)
                d_noise[:int(opt.batch_size / 10), :] = 1.0
            elif epoch >= 10 and epoch < 20:
                d_noise = 0.1 * torch.FloatTensor(opt.batch_size,
                                                  1).fill_(1.0).to(device)
                d_noise[:int(opt.batch_size / 20), :] = 1.0
            else:
                d_noise = 0.1 * torch.FloatTensor(opt.batch_size,
                                                  1).fill_(1.0).to(device)
                d_noise[:int(opt.batch_size / 40), :] = 1.0

            # Conditioning signals are computed without gradients.
            with torch.no_grad():
                onehot = onehot_encoder(labels, captions)
                condition_embd = embedder(labels, captions)
                condition_embd = condition_embd.unsqueeze(dim=2).unsqueeze(
                    dim=3)
                condition_embd_expand = condition_embd

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()
            z = torch.randn(opt.batch_size, opt.z_dim, 1, 1).to(device)
            gen_imgs = model_G.forward(z, condition_embd)
            validity = model_D(gen_imgs, onehot)
            loss_G = adv_loss(validity, valid)
            loss_G.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            # Loss for real images (targets softened by d_noise).
            validity_real = model_D(imgs, onehot)
            d_real_loss = adv_loss(validity_real, valid - d_noise)

            # Loss for fake images (detach G so only D updates).
            validity_fake = model_D(gen_imgs.detach(), onehot)
            d_fake_loss = adv_loss(validity_fake, fake + d_noise)

            # Total discriminator loss
            loss_D = (d_real_loss + d_fake_loss) / 2

            loss_D.backward()
            optimizer_D.step()

            # -----------------
            #  Train Generator (second update per batch, fresh noise)
            # -----------------
            optimizer_G.zero_grad()
            z = torch.randn(opt.batch_size, opt.z_dim, 1, 1).to(device)
            gen_imgs = model_G.forward(z, condition_embd)
            validity = model_D(gen_imgs, onehot)
            loss_G = adv_loss(validity, valid)
            loss_G.backward()
            optimizer_G.step()

            batches_done = epoch * len(train_loader) + i
            writer.add_scalar('train/loss_G', loss_G.item(), batches_done)
            writer.add_scalar('train/loss_D', loss_D.item(), batches_done)
            loss_G_to_log += loss_G.item()
            loss_D_to_log += loss_D.item()
            if (i + 1) % opt.print_every == 0:
                loss_G_to_log = loss_G_to_log / opt.print_every
                loss_D_to_log = loss_D_to_log / opt.print_every
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [Loss G: %f] [Loss D: %f] [Time/batch %.3f]"
                    %
                    (epoch + 1, opt.n_epochs, i + 1, len(train_loader),
                     loss_G_to_log, loss_D_to_log, time.time() - start_batch))
                loss_G_to_log = 0.0
                loss_D_to_log = 0.0
            if (batches_done + 1) % opt.sample_interval == 0:
                print("sampling_images")
                model_G = model_G.eval()
                sample_image(model_G,
                             embedder,
                             opt.output_dir,
                             n_row=4,
                             batches_done=batches_done,
                             dataloader=val_loader,
                             device=device)
                model_G = model_G.train()

        # Checkpoint both networks every 10 epochs.
        if (epoch + 1) % 10 == 0:
            torch.save({
                'G': model_G.state_dict(),
                'D': model_D.state_dict()
            },
                       os.path.join(opt.output_dir, 'models',
                                    'epoch_{}.pt'.format(epoch)))

    # NOTE(review): schedulers step once after ALL epochs — confirm.
    scheduler_G.step()
    scheduler_D.step()
def train(model_G, model_D, embedder, optimizer_G, optimizer_D, scheduler_G,
          scheduler_D, train_loader, val_loader, adv_loss, opt):
    """WGAN-GP style training loop (critic loss with gradient penalty).

    Cleanup: removed the no-op ``imgs = imgs`` / ``labels = labels``
    assignments and several blocks of commented-out duplicate code;
    runtime behavior is unchanged.

    Args:
        model_G / model_D: generator and critic modules.
        embedder: (labels, captions) -> embedding, run under no_grad.
        optimizer_G / optimizer_D: per-network optimizers.
        scheduler_G / scheduler_D: stepped once after all epochs —
            NOTE(review): confirm this placement.
        train_loader / val_loader: yield (imgs, labels, captions).
        adv_loss: unused in this variant (Wasserstein objective instead).
        opt: namespace with n_epochs, z_dim, print_every,
            sample_interval, output_dir.

    Relies on module-level globals ``device``, ``writer``, ``Tensor``,
    ``Variable``, ``sample_image``, ``compute_gradient_penalty``.

    NOTE(review): unlike the sibling loops, imgs/labels are never moved
    to ``device`` here (conversion happens via ``imgs.type(Tensor)``);
    presumably ``Tensor`` is a CUDA tensor type — confirm.
    NOTE(review): ``Variable`` is deprecated since PyTorch 0.4; plain
    tensors suffice.
    """
    print("TRAINING STARTS")
    for epoch in range(opt.n_epochs):
        model_G = model_G.train()
        model_D = model_D.train()
        loss_G_to_log = 0.0
        loss_D_to_log = 0.0

        for i, (imgs, labels, captions) in enumerate(train_loader):
            start_batch = time.time()

            # Conditioning embedding, broadcast to the image plane for D.
            with torch.no_grad():
                condition_embd = embedder(labels, captions)
                condition_embd = condition_embd.unsqueeze(dim=2).unsqueeze(
                    dim=3)
                condition_embd_expand = condition_embd.expand(-1, -1, 32, 32)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            real_imgs = Variable(imgs.type(Tensor))
            # Noise batch sized by the actual batch (handles partial batches).
            z = Variable(
                Tensor(np.random.normal(0, 1,
                                        (imgs.shape[0], opt.z_dim, 1, 1))))
            gen_imgs = model_G.forward(z, condition_embd)
            validity_real = model_D(real_imgs, condition_embd_expand)
            validity_fake = model_D(gen_imgs, condition_embd_expand)
            gradient_penalty = compute_gradient_penalty(
                model_D, condition_embd_expand, real_imgs.data, gen_imgs.data)
            # Wasserstein critic loss + gradient penalty (lambda = 1).
            loss_D = -torch.mean(validity_real) + torch.mean(
                validity_fake) + 1 * gradient_penalty
            loss_D.backward()
            optimizer_D.step()

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()

            # Reuses the same z; the critic was just updated.
            gen_imgs = model_G.forward(z, condition_embd)
            validity_fake = model_D(gen_imgs, condition_embd_expand)
            loss_G = -torch.mean(validity_fake)
            loss_G.backward()
            optimizer_G.step()

            batches_done = epoch * len(train_loader) + i
            writer.add_scalar('train/loss_G', loss_G.item(), batches_done)
            writer.add_scalar('train/loss_D', loss_D.item(), batches_done)
            loss_G_to_log += loss_G.item()
            loss_D_to_log += loss_D.item()
            if (i + 1) % opt.print_every == 0:
                loss_G_to_log = loss_G_to_log / opt.print_every
                loss_D_to_log = loss_D_to_log / opt.print_every
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [Loss G: %f] [Loss D: %f] [Time/batch %.3f]"
                    %
                    (epoch + 1, opt.n_epochs, i + 1, len(train_loader),
                     loss_G_to_log, loss_D_to_log, time.time() - start_batch))
                loss_G_to_log = 0.0
                loss_D_to_log = 0.0
            if (batches_done + 1) % opt.sample_interval == 0:
                print("sampling_images")
                model_G = model_G.eval()
                sample_image(model_G,
                             embedder,
                             opt.output_dir,
                             n_row=4,
                             batches_done=batches_done,
                             dataloader=val_loader,
                             device=device)
                model_G = model_G.train()

        torch.save({
            'G': model_G.state_dict(),
            'D': model_D.state_dict()
        }, os.path.join(opt.output_dir, 'models', 'epoch_{}.pt'.format(epoch)))

    # NOTE(review): schedulers step once after ALL epochs — confirm.
    scheduler_G.step()
    scheduler_D.step()
Example #7
0
def train(model,
          embedder,
          optimizer,
          scheduler,
          train_loader,
          val_loader,
          opt,
          writer,
          device=None):
    """Train a normalizing-flow model with NLL loss, logging bpd and
    sampling periodically.

    Cleanup: the original accumulated ``loss_to_log`` but reset it to
    0.0 at the end of every iteration, so it always equaled the current
    batch loss — the dead accumulator (and the unused ``start_batch``
    timer, referenced only in commented-out code) have been removed;
    logged values are unchanged.

    Args:
        model: flow model; ``forward(imgs, cond, reverse=False)``
            returns (z, sldj).
        embedder: (labels, captions) -> conditioning, run under no_grad;
            skipped when ``opt.conditioning == 'unconditional'``.
        optimizer: stepped per batch, optionally with gradient clipping.
        scheduler: stepped per batch with the running ``global_step``.
        train_loader / val_loader: yield (imgs, labels, captions).
        opt: namespace with n_epochs, conditioning, max_grad_norm,
            sample_interval, output_dir.
        writer: TensorBoard-style summary writer.
        device: torch device for inputs.

    Relies on module-level globals ``global_step``, ``util``, ``tqdm``,
    ``sample_image``, ``eval``.
    """
    print("TRAINING STARTS")
    global global_step
    for epoch in range(opt.n_epochs):
        print("[Epoch %d/%d]" % (epoch + 1, opt.n_epochs))
        model = model.train()
        loss_fn = util.NLLLoss().to(device)
        with tqdm(total=len(train_loader.dataset)) as progress_bar:
            for i, (imgs, labels, captions) in enumerate(train_loader):
                imgs = imgs.to(device)
                labels = labels.to(device)

                # Conditioning is computed without gradients (frozen embedder).
                with torch.no_grad():
                    if opt.conditioning == 'unconditional':
                        condition_embd = None
                    else:
                        condition_embd = embedder(labels, captions)

                optimizer.zero_grad()
                z, sldj = model.forward(imgs, condition_embd, reverse=False)
                # Normalize NLL by the number of dimensions per image.
                loss = loss_fn(z, sldj) / np.prod(imgs.size()[1:])
                loss.backward()
                if opt.max_grad_norm > 0:
                    util.clip_grad_norm(optimizer, opt.max_grad_norm)
                optimizer.step()
                scheduler.step(global_step)

                batches_done = epoch * len(train_loader) + i
                # bits-per-dim = nats / ln(2)
                writer.add_scalar('train/bpd', loss / np.log(2), batches_done)
                progress_bar.set_postfix(bpd=(loss.item() / np.log(2)),
                                         lr=optimizer.param_groups[0]['lr'])
                progress_bar.update(imgs.size(0))
                global_step += imgs.size(0)

                if (batches_done + 1) % opt.sample_interval == 0:
                    print("sampling_images")
                    # NOTE(review): model stays in eval mode afterwards,
                    # matching the original — confirm intended.
                    model = model.eval()
                    sample_image(model,
                                 embedder,
                                 opt.output_dir,
                                 n_row=4,
                                 batches_done=batches_done,
                                 dataloader=val_loader,
                                 device=device)

        val_bpd = eval(model, embedder, val_loader, opt, writer, device=device)
        writer.add_scalar("val/bpd", val_bpd, (epoch + 1) * len(train_loader))

        torch.save(
            model.state_dict(),
            os.path.join(opt.output_dir, 'models',
                         'epoch_{}.pt'.format(epoch)))
Example #8
0
        train_dataset = CIFARDogDataset(train=not opt.train_on_val,
                                        max_size=1 if opt.debug else -1)
        val_dataset = CIFARDogDataset(train=0, max_size=1 if opt.debug else -1)
    elif opt.dataset == "cifarcat":
        train_dataset = CIFARCatDataset(train=not opt.train_on_val,
                                        max_size=1 if opt.debug else -1)
        val_dataset = CIFARCatDataset(train=0, max_size=1 if opt.debug else -1)
    elif opt.dataset == "cifarcatdog":
        train_dataset = CIFARCatDogDataset(train=not opt.train_on_val,
                                           max_size=1 if opt.debug else -1)
        val_dataset = CIFARCatDogDataset(train=0,
                                         max_size=1 if opt.debug else -1)
    else:
        raise Exception('Unknown dataset')

    print("creating dataloaders")
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
    )
    #     val_dataloader = torch.utils.data.DataLoader(
    #         val_dataset,
    #         batch_size=opt.batch_size,
    #         shuffle=True,
    #     )

    #     print("Len train : {}, val : {}".format(len(train_dataloader), len(val_dataloader)))

    sample_image(opt.n_row, train_dataloader)