Example #1
File: main.py Project: textstorm/VQ-VAE
def main(args):
    # Resolve per-model output directories for checkpoints, logs, and images.
    save_dir = os.path.join(args.save_dir, args.model_type)
    img_dir = os.path.join(args.img_dir, args.model_type)
    log_dir = os.path.join(args.log_dir, args.model_type)
    train_dir = args.train_dir

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(img_dir):
        os.makedirs(img_dir)

    mnist = utils.read_data_sets(train_dir)
    summary_writer = tf.summary.FileWriter(log_dir)
    config_proto = utils.get_config_proto()

    sess = tf.Session(config=config_proto)
    model = VQVAE(args, sess, name="vqvae")

    total_batch = mnist.train.num_examples // args.batch_size

    for epoch in range(1, args.nb_epoch + 1):
        print "Epoch %d start with learning rate %f" % (
            epoch, model.learning_rate.eval(sess))
        print "- " * 50
        epoch_start_time = time.time()
        step_start_time = epoch_start_time
        for i in range(1, total_batch + 1):
            x_batch, y_batch = mnist.train.next_batch(args.batch_size)

            _, loss, rec_loss, vq, commit, global_step, summaries = model.train(
                x_batch)
            summary_writer.add_summary(summaries, global_step)

            if i % args.print_step == 0:
                print("epoch %d, step %d, loss %f, rec_loss %f, vq_loss %f, commit_loss %f, time %.2fs"
                      % (epoch, global_step, loss, rec_loss, vq, commit, time.time() - step_start_time))
                step_start_time = time.time()

        if epoch % 50 == 0:
            print("- " * 5)

        if args.anneal and epoch >= args.anneal_start:
            sess.run(model.lr_decay_op)

        if epoch % args.save_epoch == 0:
            x_batch, y_batch = mnist.test.next_batch(100)
            x_recon = model.reconstruct(x_batch)
            utils.save_images(x_batch.reshape(-1, 28, 28, 1), [10, 10],
                              os.path.join(img_dir, "rawImage%s.jpg" % epoch))
            utils.save_images(
                x_recon, [10, 10],
                os.path.join(img_dir, "reconstruct%s.jpg" % epoch))

    model.saver.save(sess, os.path.join(save_dir, "model.ckpt"))
    print "Model stored...."
Example #2
def main(args):
    # Resolve per-model output directories for checkpoints, logs, and images.
    save_dir = os.path.join(args.save_dir, args.model_type)
    img_dir = os.path.join(args.img_dir, args.model_type)
    log_dir = os.path.join(args.log_dir, args.model_type)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(img_dir):
        os.makedirs(img_dir)

    summary_writer = tf.summary.FileWriter(log_dir)
    config_proto = utils.get_config_proto()

    sess = tf.Session(config=config_proto)
    model = VQVAE(args, sess, name="vqvae")

    img_paths = glob.glob('data/img_align_celeba/*.jpg')
    train_paths, test_paths = train_test_split(img_paths,
                                               test_size=0.1,
                                               random_state=args.random_seed)
    celeba = utils.DiskImageData(sess,
                                 train_paths,
                                 args.batch_size,
                                 shape=[218, 178, 3])
    total_batch = celeba.num_examples // args.batch_size

    for epoch in range(1, args.nb_epoch + 1):
        print "Epoch %d start with learning rate %f" % (
            epoch, model.learning_rate.eval(sess))
        print "- " * 50
        epoch_start_time = time.time()
        step_start_time = epoch_start_time
        for i in range(1, total_batch + 1):
            x_batch = celeba.next_batch()

            _, loss, rec_loss, vq, commit, global_step, summaries = model.train(
                x_batch)
            summary_writer.add_summary(summaries, global_step)

            if i % args.print_step == 0:
                print("epoch %d, step %d, loss %f, rec_loss %f, vq_loss %f, commit_loss %f, time %.2fs"
                      % (epoch, global_step, loss, rec_loss, vq, commit, time.time() - step_start_time))
                step_start_time = time.time()

        if args.anneal and epoch >= args.anneal_start:
            sess.run(model.lr_decay_op)

        if epoch % args.save_epoch == 0:
            x_batch = celeba.next_batch()
            x_recon = model.reconstruct(x_batch)
            utils.save_images(x_batch, [10, 10],
                              os.path.join(img_dir, "rawImage%s.jpg" % epoch))
            utils.save_images(
                x_recon, [10, 10],
                os.path.join(img_dir, "reconstruct%s.jpg" % epoch))

    model.saver.save(sess, os.path.join(save_dir, "model.ckpt"))
    print "Model stored...."
Example #3
    loader = tqdm(loader)

    for i, (img, _) in enumerate(loader):
        img = img.cuda()

        # Generate the attention regions for the images.
        saliency = utilities.compute_saliency_maps(img, model)

        # Normalize maps each channel from [-1, 1] to [0, 1]: (x + 1) / 2.
        norm = transforms.Normalize((-1, -1, -1), (2, 2, 2))

        blackout_img = img.clone()

        # Use a fresh index so the outer batch index `i` is not shadowed.
        for j in range(len(saliency)):
            normalised = norm(saliency[j])
            blackout_img[j] = torch.mul(normalised, img[j])

        # Train the model to reconstruct the original image from the masked input.
        model.train()
        model.zero_grad()
        output = model(blackout_img)
        loss = criterion(output, img)  # (input, target) per the usual PyTorch convention

        loss.backward()
        optimizer.step()

    print("EPOCH: ", epoch + 1, "Loss: ", loss)
    torch.save(model.state_dict(), "Models/toy/" + str(epoch + 1) + ".pt")
    utilities.show_tensor(img[0], False, _)
    utilities.show_tensor(output[0], False, _)
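
This example delegates the attention maps to `utilities.compute_saliency_maps`, which isn't shown. A common construction is the absolute gradient of the model output with respect to the input; below is a minimal sketch under that assumption, kept per-channel so the (N, 3, H, W) result matches the three-channel `Normalize` applied in the loop above. The project's utility may compute the maps differently.

import torch

def compute_saliency_maps(img, model):
    # Gradient-based saliency: how strongly each input pixel influences the output.
    img = img.clone().detach().requires_grad_(True)
    model.eval()
    output = model(img)
    # Backprop a scalar summary of the output to populate img.grad.
    output.sum().backward()
    # Keep per-channel gradient magnitudes, shape (N, C, H, W).
    return img.grad.abs().detach()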
Example #4
def train_vqvae(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = VQVAE(args.channels, args.latent_dim, args.num_embeddings,
                  args.embedding_dim)
    model.to(device)

    model_name = "{}_C_{}_N_{}_M_{}_D_{}".format(args.model, args.channels,
                                                 args.latent_dim,
                                                 args.num_embeddings,
                                                 args.embedding_dim)

    checkpoint_dir = Path(model_name)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    writer = SummaryWriter(log_dir=Path("runs") / model_name)

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    if args.resume is not None:
        print("Resume checkpoint from: {}:".format(args.resume))
        checkpoint = torch.load(args.resume,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        global_step = checkpoint["step"]
    else:
        global_step = 0

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(shift)])
    training_dataset = datasets.CIFAR10("./CIFAR10",
                                        train=True,
                                        download=True,
                                        transform=transform)

    test_dataset = datasets.CIFAR10("./CIFAR10",
                                    train=False,
                                    download=True,
                                    transform=transform)

    training_dataloader = DataLoader(training_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers,
                                     pin_memory=True)

    test_dataloader = DataLoader(test_dataset,
                                 batch_size=64,
                                 shuffle=True,
                                 drop_last=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    num_epochs = args.num_training_steps // len(training_dataloader) + 1
    start_epoch = global_step // len(training_dataloader) + 1

    N = 3 * 32 * 32  # dimensionality of a CIFAR-10 image (channels * height * width)
    KL = args.latent_dim * 8 * 8 * np.log(args.num_embeddings)  # KL to a uniform prior over the 8x8 code grid

    for epoch in range(start_epoch, num_epochs + 1):
        model.train()
        average_logp = average_vq_loss = average_elbo = average_bpd = average_perplexity = 0
        for i, (images, _) in enumerate(tqdm(training_dataloader), 1):
            images = images.to(device)

            dist, vq_loss, perplexity = model(images)
            targets = (images + 0.5) * 255  # map [-0.5, 0.5] inputs back to integer pixel values
            targets = targets.long()
            logp = dist.log_prob(targets).sum((1, 2, 3)).mean()
            loss = -logp / N + vq_loss  # per-dimension reconstruction NLL plus the VQ objective
            elbo = (KL - logp) / N
            bpd = elbo / np.log(2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            global_step += 1

            if global_step % 25000 == 0:
                save_checkpoint(model, optimizer, global_step, checkpoint_dir)

            # Incremental running means over the epoch.
            average_logp += (logp.item() - average_logp) / i
            average_vq_loss += (vq_loss.item() - average_vq_loss) / i
            average_elbo += (elbo.item() - average_elbo) / i
            average_bpd += (bpd.item() - average_bpd) / i
            average_perplexity += (perplexity.item() - average_perplexity) / i

        writer.add_scalar("logp/train", average_logp, epoch)
        writer.add_scalar("kl/train", KL, epoch)
        writer.add_scalar("vqloss/train", average_vq_loss, epoch)
        writer.add_scalar("elbo/train", average_elbo, epoch)
        writer.add_scalar("bpd/train", average_bpd, epoch)
        writer.add_scalar("perplexity/train", average_perplexity, epoch)

        model.eval()
        average_logp = average_vq_loss = average_elbo = average_bpd = average_perplexity = 0
        for i, (images, _) in enumerate(test_dataloader, 1):
            images = images.to(device)

            with torch.no_grad():
                dist, vq_loss, perplexity = model(images)

            targets = (images + 0.5) * 255
            targets = targets.long()
            logp = dist.log_prob(targets).sum((1, 2, 3)).mean()
            elbo = (KL - logp) / N
            bpd = elbo / np.log(2)

            average_logp += (logp.item() - average_logp) / i
            average_vq_loss += (vq_loss.item() - average_vq_loss) / i
            average_elbo += (elbo.item() - average_elbo) / i
            average_bpd += (bpd.item() - average_bpd) / i
            average_perplexity += (perplexity.item() - average_perplexity) / i

        writer.add_scalar("logp/test", average_logp, epoch)
        writer.add_scalar("kl/test", KL, epoch)
        writer.add_scalar("vqloss/test", average_vq_loss, epoch)
        writer.add_scalar("elbo/test", average_elbo, epoch)
        writer.add_scalar("bpd/test", average_bpd, epoch)
        writer.add_scalar("perplexity/test", average_perplexity, epoch)

        samples = torch.argmax(dist.logits, dim=-1)  # greedy per-pixel decode of the last test batch
        grid = utils.make_grid(samples.float() / 255)
        writer.add_image("reconstructions", grid, epoch)

        print(
            "epoch:{}, logp:{:.3E}, vq loss:{:.3E}, elbo:{:.3f}, bpd:{:.3f}, perplexity:{:.3f}"
            .format(epoch, average_logp, average_vq_loss, average_elbo,
                    average_bpd, average_perplexity))
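
Two helpers used above are not shown but are well constrained by the surrounding code: `shift` runs after `ToTensor`, and since the loop recovers pixel targets as `(images + 0.5) * 255`, it plausibly recenters inputs from [0, 1] to [-0.5, 0.5]; `save_checkpoint` must write the "model", "optimizer", and "step" keys that the `--resume` branch reads back. A sketch under those assumptions (the file naming is a placeholder):

import torch

def shift(x):
    # Recenter ToTensor output from [0, 1] to [-0.5, 0.5], consistent with
    # the target recovery `(images + 0.5) * 255` in the training loop.
    return x - 0.5

def save_checkpoint(model, optimizer, step, checkpoint_dir):
    # Mirror the keys consumed when resuming from --resume.
    path = checkpoint_dir / "model.ckpt-{}.pt".format(step)
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": step,
    }, path)
    print("Saved checkpoint: {}".format(path))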