Exemplo n.º 1
0
def _pixel_targets(images):
    """Convert shifted float images back to integer pixel intensities.

    Assumes the dataset transform (``shift``, defined outside this block)
    moved ``ToTensor`` output by -0.5, so ``(images + 0.5) * 255`` is meant
    to be an integer in 0..255 -- TODO confirm against ``shift``.

    The value is rounded before the cast: float32 arithmetic can leave a
    result such as 0.99999994 where 1 is intended, and ``.long()`` truncates
    toward zero, which would silently turn that pixel into 0.
    """
    return torch.round((images + 0.5) * 255).long()


def train_vqvae(args):
    """Train a VQ-VAE on CIFAR10 and log train/test metrics to TensorBoard.

    Reads the following attributes from ``args``: ``model``, ``channels``,
    ``latent_dim``, ``num_embeddings``, ``embedding_dim``, ``learning_rate``,
    ``resume`` (checkpoint path or ``None``), ``batch_size``, ``num_workers``
    and ``num_training_steps``.

    Side effects:
      * creates a checkpoint directory named after the hyper-parameters and
        calls ``save_checkpoint`` every 25000 global steps;
      * writes scalars and a reconstruction image grid under
        ``runs/<model_name>``;
      * downloads CIFAR10 into ``./CIFAR10`` if it is not present;
      * prints a one-line summary of the test metrics after every epoch.

    Training always runs whole epochs, so the final step count may exceed
    ``args.num_training_steps``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = VQVAE(args.channels, args.latent_dim, args.num_embeddings,
                  args.embedding_dim)
    model.to(device)

    model_name = "{}_C_{}_N_{}_M_{}_D_{}".format(args.model, args.channels,
                                                 args.latent_dim,
                                                 args.num_embeddings,
                                                 args.embedding_dim)

    checkpoint_dir = Path(model_name)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    writer = SummaryWriter(log_dir=Path("runs") / model_name)

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    if args.resume is not None:
        print("Resume checkpoint from: {}:".format(args.resume))
        # Load tensors onto the CPU regardless of where they were saved;
        # load_state_dict then copies them into the already-placed model.
        checkpoint = torch.load(args.resume,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        global_step = checkpoint["step"]
    else:
        global_step = 0

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(shift)])
    training_dataset = datasets.CIFAR10("./CIFAR10",
                                        train=True,
                                        download=True,
                                        transform=transform)

    test_dataset = datasets.CIFAR10("./CIFAR10",
                                    train=False,
                                    download=True,
                                    transform=transform)

    training_dataloader = DataLoader(training_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers,
                                     pin_memory=True)

    test_dataloader = DataLoader(test_dataset,
                                 batch_size=64,
                                 shuffle=True,
                                 drop_last=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    # Epochs are 1-based; when resuming, continue from the epoch that the
    # restored global step falls into.
    num_epochs = args.num_training_steps // len(training_dataloader) + 1
    start_epoch = global_step // len(training_dataloader) + 1

    # Number of values per image (3 x 32 x 32), used to normalise per dimension.
    N = 3 * 32 * 32
    # Constant KL term added to -logp to form the ELBO.
    # NOTE(review): assumes an 8x8 latent grid per latent_dim and a uniform
    # prior over the codebook -- confirm against the model definition.
    KL = args.latent_dim * 8 * 8 * np.log(args.num_embeddings)

    for epoch in range(start_epoch, num_epochs + 1):
        model.train()
        average_logp = average_vq_loss = average_elbo = average_bpd = average_perplexity = 0
        for i, (images, _) in enumerate(tqdm(training_dataloader), 1):
            images = images.to(device)

            dist, vq_loss, perplexity = model(images)
            targets = _pixel_targets(images)
            # Log-likelihood summed over every value of an image, averaged
            # over the batch.
            logp = dist.log_prob(targets).sum((1, 2, 3)).mean()
            loss = -logp / N + vq_loss
            elbo = (KL - logp) / N
            # Divide by ln(2) to express the per-dimension ELBO in bits.
            bpd = elbo / np.log(2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            global_step += 1

            if global_step % 25000 == 0:
                save_checkpoint(model, optimizer, global_step, checkpoint_dir)

            # Incremental running means over the batches seen so far.
            average_logp += (logp.item() - average_logp) / i
            average_vq_loss += (vq_loss.item() - average_vq_loss) / i
            average_elbo += (elbo.item() - average_elbo) / i
            average_bpd += (bpd.item() - average_bpd) / i
            average_perplexity += (perplexity.item() - average_perplexity) / i

        writer.add_scalar("logp/train", average_logp, epoch)
        writer.add_scalar("kl/train", KL, epoch)
        writer.add_scalar("vqloss/train", average_vq_loss, epoch)
        writer.add_scalar("elbo/train", average_elbo, epoch)
        writer.add_scalar("bpd/train", average_bpd, epoch)
        writer.add_scalar("perplexity/train", average_perplexity, epoch)

        model.eval()
        average_logp = average_vq_loss = average_elbo = average_bpd = average_perplexity = 0
        for i, (images, _) in enumerate(test_dataloader, 1):
            images = images.to(device)

            with torch.no_grad():
                dist, vq_loss, perplexity = model(images)

            targets = _pixel_targets(images)
            logp = dist.log_prob(targets).sum((1, 2, 3)).mean()
            elbo = (KL - logp) / N
            bpd = elbo / np.log(2)

            average_logp += (logp.item() - average_logp) / i
            average_vq_loss += (vq_loss.item() - average_vq_loss) / i
            average_elbo += (elbo.item() - average_elbo) / i
            average_bpd += (bpd.item() - average_bpd) / i
            average_perplexity += (perplexity.item() - average_perplexity) / i

        writer.add_scalar("logp/test", average_logp, epoch)
        writer.add_scalar("kl/test", KL, epoch)
        writer.add_scalar("vqloss/test", average_vq_loss, epoch)
        writer.add_scalar("elbo/test", average_elbo, epoch)
        writer.add_scalar("bpd/test", average_bpd, epoch)
        writer.add_scalar("perplexity/test", average_perplexity, epoch)

        # Most likely pixel value per position for the last test batch,
        # scaled by 1/255 for the image grid.
        samples = torch.argmax(dist.logits, dim=-1)
        grid = utils.make_grid(samples.float() / 255)
        writer.add_image("reconstructions", grid, epoch)

        print(
            "epoch:{}, logp:{:.3E}, vq loss:{:.3E}, elbo:{:.3f}, bpd:{:.3f}, perplexity:{:.3f}"
            .format(epoch, average_logp, average_vq_loss, average_elbo,
                    average_bpd, average_perplexity))

    # Flush and release the event file once training has finished.
    writer.close()
Exemplo n.º 2
0
from model import VQVAE

# Run options and the image preprocessing pipeline both come from the
# project's config module.
args = config.get_args()
transform = config.get_transform()

# Images are read from the folder given by args.path and served in shuffled
# batches; num_workers=0 keeps data loading in the main process.
dataset = datasets.ImageFolder(args.path, transform=transform)
loader = DataLoader(dataset, batch_size=args.batch, shuffle=True, num_workers=0)

# Instantiate the model with its default arguments and move it to the GPU.
model = VQVAE().cuda()

# Mean-squared-error criterion and an Adam optimizer over all model parameters.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)

from torch.autograd import Variable

# Main loop: one pass over `loader` per epoch, for args.epoch epochs.
# The body of the inner loop continues past the end of this excerpt.
for epoch in range(args.epoch):

    # Wrap the loader in a tqdm progress bar.
    # NOTE(review): `loader` is rebound here on every epoch, so from the
    # second epoch on tqdm wraps the previous tqdm object rather than the
    # original DataLoader -- consider wrapping into a separate name.
    loader = tqdm(loader)

    for i, (img, _) in enumerate(loader):
        # Move the image batch to the GPU; the second item of each batch is
        # discarded.
        img = img.cuda()

        # Generate the attention (saliency) regions for the images.
        saliency = utilities.compute_saliency_maps(img, model)

        # Normalize computes (x - mean) / std per channel, so these arguments
        # give (x + 1) / 2 -- presumably to map [-1, 1] onto [0, 1]; its use
        # lies beyond this excerpt.
        norm = transforms.Normalize((-1, -1, -1), (2, 2, 2))
Exemplo n.º 3
0
# Options shared by both MNIST loaders; `kwargs` is defined outside this
# excerpt.
_loader_options = dict(batch_size=args.batch_size, shuffle=True, **kwargs)

# Training split (downloaded on first use) and test split, both converted to
# tensors with ToTensor.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    **_loader_options)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False,
                   transform=transforms.ToTensor()),
    **_loader_options)

# Build the model from the command-line options and move it to the GPU when
# CUDA use was requested.
model = VQVAE(args.input_dim, args.emb_dim, args.emb_num, args.batch_size)
if args.cuda:
    model.cuda()

optimizer = optim.Adam(model.parameters(), lr=1e-3)


def train(epoch):
    """run one epoch of model to train with data loader"""
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = Variable(data).view(-1, 784)
        if args.cuda:
            data = data.cuda()
        # run forward
        # compute losses
        recon_batch, reconst_loss, embed_loss, commit_loss = model(data)

        # clear gradients and run backward