예제 #1
0
    def test(self):
        """Decode 100 prior samples and save them as ``image_grid.png``.

        Restores the decoder from entry 1 of ``checkpoint.pth`` in
        ``self.args.log``, draws 100 standard-normal latent codes, decodes
        them and writes a grid with 10 images per row to the current working
        directory.

        Raises:
            ValueError: if ``self.config.data.dataset`` is neither "CELEBA"
                nor "MNIST" (no sampling path is defined for other datasets).
        """
        dataset_name = self.config.data.dataset
        # Fail before touching the checkpoint: previously an unsupported
        # dataset only surfaced as an unbound ``image_grid`` at save time.
        if dataset_name not in ("CELEBA", "MNIST"):
            raise ValueError(
                "test() only supports CELEBA and MNIST, got {!r}".format(
                    dataset_name))

        device = self.config.device
        states = torch.load(
            os.path.join(self.args.log, "checkpoint.pth"),
            map_location=device,
        )
        decoder_cls = MLPDecoder if dataset_name == "MNIST" else Decoder
        decoder = decoder_cls(self.config).to(device)
        decoder.eval()
        decoder.load_state_dict(states[1])

        n_samples = 100
        image_shape = (
            n_samples,
            self.config.data.channels,
            self.config.data.image_size,
            self.config.data.image_size,
        )

        # Pure inference: no autograd graph is needed for sampling.
        with torch.no_grad():
            z = torch.randn(n_samples,
                            self.config.model.z_dim,
                            device=device)
            if dataset_name == "CELEBA":
                # The decoder returns a pair here; only the first element
                # is visualised.
                samples, _ = decoder(z)
                image_grid = make_grid(samples.view(*image_shape), 10)
                # Map from [-1, 1] back to [0, 1] for saving.
                image_grid = torch.clamp(image_grid / 2.0 + 0.5, 0.0, 1.0)
            else:  # MNIST: the decoder emits logits.
                samples = torch.sigmoid(decoder(z))
                image_grid = make_grid(samples.view(*image_shape), 10)

        save_image(image_grid, "image_grid.png")
예제 #2
0
    def test_svi(self):
        """Evaluate the decoder on the test split with per-batch SVI.

        For every test mini-batch a fresh ``Gaussian4SVI`` module is fitted
        with 300 SGD steps on the ``iwae`` objective (``k=10``,
        ``training=True``). The objective is then evaluated once more with
        ``training=False`` and a size-weighted running average is printed.

        Only the decoder is restored from ``checkpoint.pth``; it is entry 1
        of the checkpoint list for every algorithm that ``train`` saves.

        Raises:
            ValueError: if ``self.config.data.dataset`` is not one of
                "CIFAR10", "MNIST" or "CELEBA".
        """
        dataset_name = self.config.data.dataset
        device = self.config.device

        transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])

        if dataset_name == "CIFAR10":
            test_dataset = CIFAR10(
                os.path.join(self.args.run, "datasets", "cifar10"),
                train=False,
                download=True,
                transform=transform,
            )
        elif dataset_name == "MNIST":
            test_dataset = MNIST(
                os.path.join(self.args.run, "datasets", "mnist"),
                train=False,
                download=True,
                transform=transform,
            )
        elif dataset_name == "CELEBA":
            # NOTE(review): machine-specific absolute path; confirm before
            # running elsewhere.
            dataset = ImageFolder(
                root="/home/kunxu/tmp/",
                transform=transforms.Compose([
                    transforms.CenterCrop(140),
                    transforms.Resize(self.config.data.image_size),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
            )
            num_items = len(dataset)
            indices = list(range(num_items))
            # Shuffle with a fixed seed but leave the global NumPy RNG state
            # untouched for the caller.
            random_state = np.random.get_state()
            np.random.seed(2019)
            np.random.shuffle(indices)
            np.random.set_state(random_state)
            # The last 20% of the shuffled indices form the test split.
            test_dataset = Subset(dataset, indices[int(0.8 * num_items):])
        else:
            raise ValueError(
                "Unsupported dataset: {!r}".format(dataset_name))

        test_loader = DataLoader(
            test_dataset,
            batch_size=self.config.training.batch_size,
            shuffle=False,
            num_workers=2,
        )

        self.config.input_dim = (self.config.data.image_size**2 *
                                 self.config.data.channels)

        states = torch.load(
            os.path.join(self.args.log, "checkpoint.pth"),
            map_location=device,
        )
        # SVI fits its own variational posterior, so the trained encoder and
        # score network are not needed. Loading the decoder unconditionally
        # also covers the "spectral"/"stein" checkpoints, which previously
        # left the decoder at its random initialisation.
        decoder = MLPDecoder(self.config).to(device)
        decoder.load_state_dict(states[1])

        total_l = 0.0
        total_n = 0
        recon_type = "bernoulli" if dataset_name == "MNIST" else "gaussian"
        from models.gmm import Gaussian4SVI

        for batch, (X, _) in enumerate(test_loader, start=1):
            X = X.to(device)
            if dataset_name == "CELEBA":
                # Add uniform noise of width 1/128 centred on zero.
                X = X + (torch.rand_like(X) - 0.5) / 128.0
            elif dataset_name == "MNIST":
                # Stochastic binarisation: a pixel is 1 when a uniform draw
                # does not exceed its intensity.
                eps = torch.rand_like(X)
                X = (eps <= X).float()

            gaussian = Gaussian4SVI(X.shape[0],
                                    self.config.model.z_dim).to(device)
            optimizer = optim.SGD(gaussian.parameters(), lr=0.01, momentum=0.5)
            lr_scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=[100, 200], gamma=0.3)
            for i in range(300):
                loss = iwae(gaussian,
                            decoder,
                            X,
                            type=recon_type,
                            k=10,
                            training=True)
                # The decoder is not optimised here; clear its gradients so
                # they do not accumulate across steps.
                decoder.zero_grad()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Step the scheduler after the optimizer; stepping first
                # skips the first value of the schedule on PyTorch >= 1.1.
                lr_scheduler.step()
                print(i, loss.item())

            loss = iwae(gaussian,
                        decoder,
                        X,
                        type=recon_type,
                        k=10,
                        training=False)
            total_l += loss.item() * X.shape[0]
            total_n += X.shape[0]
            print("mini-batch: {}, current iwae-10: {}, average iwae-10: {}".
                  format(batch, loss.item(), total_l / total_n))
예제 #3
0
    def test_ais(self):
        """Run AIS on the MNIST test set with the decoder from ``checkpoint_100k.pth``.

        Builds an ``AISLatentVariableModels`` object (``n_Ts=1000``) around
        the decoder's reconstruction energy, evaluates ``ais.ais`` on every
        test mini-batch of 32 images and prints the batch mean together with
        the size-weighted running mean.

        Raises:
            AssertionError: if ``self.config.data.dataset`` is not "MNIST".
            ValueError: if the dataset is unknown (only reachable when
                assertions are disabled).
        """
        dataset_name = self.config.data.dataset
        device = self.config.device
        assert dataset_name == "MNIST", "test_ais only supports MNIST"

        transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])

        if dataset_name == "CIFAR10":
            test_dataset = CIFAR10(
                os.path.join(self.args.run, "datasets", "cifar10"),
                train=False,
                download=True,
                transform=transform,
            )
        elif dataset_name == "MNIST":
            test_dataset = MNIST(
                os.path.join(self.args.run, "datasets", "mnist"),
                train=False,
                download=True,
                transform=transform,
            )
        elif dataset_name == "CELEBA":
            # NOTE(review): machine-specific absolute path; confirm before
            # running elsewhere.
            dataset = ImageFolder(
                root="/home/kunxu/tmp/",
                transform=transforms.Compose([
                    transforms.CenterCrop(140),
                    transforms.Resize(self.config.data.image_size),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
            )
            num_items = len(dataset)
            indices = list(range(num_items))
            # Fixed-seed shuffle that restores the global NumPy RNG state.
            random_state = np.random.get_state()
            np.random.seed(2019)
            np.random.shuffle(indices)
            np.random.set_state(random_state)
            test_dataset = Subset(dataset, indices[int(0.8 * num_items):])
        else:
            raise ValueError(
                "Unsupported dataset: {!r}".format(dataset_name))

        # A fixed batch size of 32 is used instead of
        # config.training.batch_size.
        test_loader = DataLoader(test_dataset,
                                 batch_size=32,
                                 shuffle=False,
                                 num_workers=2)

        self.config.input_dim = (self.config.data.image_size**2 *
                                 self.config.data.channels)

        states = torch.load(
            os.path.join(self.args.log, "checkpoint_100k.pth"),
            map_location=device,
        )
        # Only the decoder takes part in the energy below; it is entry 1 of
        # the checkpoint for every algorithm, so the encoder and score
        # network are not rebuilt.
        decoder = MLPDecoder(self.config).to(device)
        decoder.load_state_dict(states[1])

        recon_type = "bernoulli" if dataset_name == "MNIST" else "gaussian"

        def recon_energy(X, z):
            """Negative log-likelihood of X under decoder(z), summed over dims 1-3."""
            # Compare strings by value; ``is`` on a literal tests identity
            # and only works by accident of interning.
            if recon_type == "gaussian":
                mean_x, logstd_x = decoder(z)
                recon = ((X - mean_x)**2 / (2.0 * (2 * logstd_x).exp()) +
                         np.log(2.0 * np.pi) / 2.0 + logstd_x)
                recon = recon.sum(dim=(1, 2, 3))
            elif recon_type == "bernoulli":
                x_logits = decoder(z)
                recon = F.binary_cross_entropy_with_logits(input=x_logits,
                                                           target=X,
                                                           reduction="none")
                recon = recon.sum(dim=[1, 2, 3])
            return recon

        from evaluations.ais import AISLatentVariableModels

        ais = AISLatentVariableModels(recon_energy,
                                      self.config.model.z_dim,
                                      device,
                                      n_Ts=1000)

        total_l = 0.0
        total_n = 0
        for X, _ in test_loader:
            X = X.to(device)
            if dataset_name == "CELEBA":
                X = X + (torch.rand_like(X) - 0.5) / 128.0
            elif dataset_name == "MNIST":
                # Stochastic binarisation of the grey-scale pixels.
                eps = torch.rand_like(X)
                X = (eps <= X).float()

            ais_lb = ais.ais(X).mean().item()
            total_l += ais_lb * X.shape[0]
            total_n += X.shape[0]
            print("current ais lb: {}, mean ais lb: {}".format(
                ais_lb, total_l / total_n))
예제 #4
0
    def train(self):
        """Train the decoder with the encoder selected by ``config.training.algo``.

        Supported algorithms are "vae", "ssm", "ssm_fd", "spectral" and
        "stein". While training the method:

        * every 10 steps evaluates one held-out batch and logs the losses to
          ``logging`` and TensorBoard;
        * every 500 steps logs decoded prior samples and the current data
          batch as images (CELEBA and MNIST only) and updates the
          validation bookkeeping;
        * every 10000 steps writes ``checkpoint_<k>0k.pth`` and
          ``checkpoint.pth`` to ``self.args.log``.

        Returns:
            ``0`` once ``config.training.n_iters`` steps have run; ``None``
            if the epoch loop is exhausted first.

        Raises:
            ValueError: for an unsupported dataset or training algorithm.
        """
        dataset_name = self.config.data.dataset
        algo = self.config.training.algo
        device = self.config.device
        is_mnist = dataset_name == "MNIST"

        transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])

        if dataset_name == "CIFAR10":
            dataset = CIFAR10(
                os.path.join(self.args.run, "datasets", "cifar10"),
                train=True,
                download=True,
                transform=transform,
            )
            test_dataset = CIFAR10(
                os.path.join(self.args.run, "datasets", "cifar10"),
                train=False,
                download=True,
                transform=transform,
            )
        elif dataset_name == "MNIST":
            dataset = MNIST(
                os.path.join(self.args.run, "datasets", "mnist"),
                train=True,
                download=True,
                transform=transform,
            )
            num_items = len(dataset)
            indices = list(range(num_items))
            # Fixed-seed shuffle that restores the global NumPy RNG state.
            random_state = np.random.get_state()
            np.random.seed(2019)
            np.random.shuffle(indices)
            np.random.set_state(random_state)
            # 80% of the official training set for training, the rest for
            # validation.
            split = int(num_items * 0.8)
            train_indices, test_indices = indices[:split], indices[split:]
            test_dataset = Subset(dataset, test_indices)
            dataset = Subset(dataset, train_indices)
        elif dataset_name == "CELEBA":
            # NOTE(review): machine-specific absolute path; confirm before
            # running elsewhere.
            dataset = ImageFolder(
                root="/home/kunxu/tmp/",
                transform=transforms.Compose([
                    transforms.CenterCrop(140),
                    transforms.Resize(self.config.data.image_size),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
            )
            num_items = len(dataset)
            indices = list(range(num_items))
            random_state = np.random.get_state()
            np.random.seed(2019)
            np.random.shuffle(indices)
            np.random.set_state(random_state)
            # First 70% for training, the next 10% for validation; the
            # remaining 20% is not used here.
            train_indices, test_indices = (
                indices[:int(num_items * 0.7)],
                indices[int(num_items * 0.7):int(num_items * 0.8)],
            )
            test_dataset = Subset(dataset, test_indices)
            dataset = Subset(dataset, train_indices)
        else:
            raise ValueError(
                "Unsupported dataset: {!r}".format(dataset_name))

        dataloader = DataLoader(
            dataset,
            batch_size=self.config.training.batch_size,
            shuffle=True,
            num_workers=4,
        )
        test_loader = DataLoader(
            test_dataset,
            batch_size=self.config.training.batch_size,
            shuffle=True,
            num_workers=2,
        )
        test_iter = iter(test_loader)
        self.config.input_dim = (self.config.data.image_size**2 *
                                 self.config.data.channels)

        # Start every run with an empty TensorBoard directory.
        tb_path = os.path.join(self.args.run, "tensorboard", self.args.doc)
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)
        tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)

        def load_checkpoint():
            # map_location keeps resume working when the checkpoint was
            # written from a different device.
            return torch.load(os.path.join(self.args.log, "checkpoint.pth"),
                              map_location=device)

        def preprocess(batch):
            """Apply the dataset-specific noise/binarisation to a batch."""
            if dataset_name == "CELEBA":
                # Uniform noise of width 1/128 centred on zero.
                return batch + (torch.rand_like(batch) - 0.5) / 128.0
            if dataset_name == "MNIST":
                # Stochastic binarisation of the grey-scale pixels.
                noise = torch.rand_like(batch)
                return (noise <= batch).float()
            return batch

        # MNIST uses the MLP architectures, every other dataset the
        # non-MLP ones.
        decoder = (MLPDecoder if is_mnist else Decoder)(
            self.config).to(device)
        if algo == "vae":
            encoder = (MLPEncoder if is_mnist else Encoder)(
                self.config).to(device)
            optimizer = self.get_optimizer(
                itertools.chain(encoder.parameters(), decoder.parameters()))
            if self.args.resume_training:
                states = load_checkpoint()
                encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
                optimizer.load_state_dict(states[2])
        elif algo in ["ssm", "ssm_fd"]:
            score = (MLPScore if is_mnist else Score)(self.config).to(device)
            imp_encoder = (MLPImplicitEncoder if is_mnist else
                           ImplicitEncoder)(self.config).to(device)
            # Both elbo variants take the same arguments.
            elbo_score_fn = elbo_ssm if algo == "ssm" else elbo_ssm_fd

            opt_ae = optim.RMSprop(
                itertools.chain(decoder.parameters(),
                                imp_encoder.parameters()),
                lr=self.config.optim.lr,
            )
            opt_score = optim.RMSprop(score.parameters(),
                                      lr=self.config.optim.lr)
            if self.args.resume_training:
                states = load_checkpoint()
                imp_encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
                score.load_state_dict(states[2])
                opt_ae.load_state_dict(states[3])
                opt_score.load_state_dict(states[4])
        elif algo in ["spectral", "stein"]:
            from models.kernel_score_estimators import (
                SpectralScoreEstimator,
                SteinScoreEstimator,
            )

            imp_encoder = (MLPImplicitEncoder if is_mnist else
                           ImplicitEncoder)(self.config).to(device)
            estimator = (SpectralScoreEstimator()
                         if algo == "spectral" else SteinScoreEstimator())
            optimizer = self.get_optimizer(
                itertools.chain(imp_encoder.parameters(),
                                decoder.parameters()))
            if self.args.resume_training:
                states = load_checkpoint()
                imp_encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
                optimizer.load_state_dict(states[2])
        else:
            raise ValueError(
                "Unsupported training algorithm: {!r}".format(algo))

        step = 0
        best_validation_loss = np.inf
        validation_losses = []
        recon_type = "bernoulli" if is_mnist else "gaussian"
        time_dur = 0.0
        for _ in range(self.config.training.n_epochs):
            for X, _ in dataloader:
                decoder.train()
                X = preprocess(X.to(device))

                # Only the ssm variants report a run time; default to zero
                # so the accumulation below is defined for every algorithm.
                run_time = 0.0
                if algo == "vae":
                    encoder.train()
                    loss, *_ = elbo(encoder, decoder, X, recon_type)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                elif algo in ["ssm", "ssm_fd"]:
                    imp_encoder.train()
                    loss, run_time, ssm_loss, *_ = elbo_score_fn(
                        imp_encoder,
                        decoder,
                        score,
                        opt_score,
                        X,
                        recon_type,
                        training=True,
                        n_particles=self.config.model.n_particles,
                    )
                    opt_ae.zero_grad()
                    loss.backward()
                    opt_ae.step()
                else:  # "spectral" or "stein"
                    imp_encoder.train()
                    loss = elbo_kernel(
                        imp_encoder,
                        decoder,
                        estimator,
                        X,
                        recon_type,
                        n_particles=self.config.model.n_particles,
                    )
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                time_dur += run_time

                if step % 10 == 0:
                    try:
                        test_X, _ = next(test_iter)
                    except StopIteration:
                        # Held-out loader exhausted: start a new pass.
                        test_iter = iter(test_loader)
                        test_X, _ = next(test_iter)

                    test_X = preprocess(test_X.to(device))

                    decoder.eval()
                    if algo == "vae":
                        encoder.eval()
                        with torch.no_grad():
                            test_loss, *_ = elbo(encoder, decoder, test_X,
                                                 recon_type)
                        logging.info("loss: {}, test_loss: {}".format(
                            loss.item(), test_loss.item()))
                    elif algo in ["ssm", "ssm_fd"]:
                        imp_encoder.eval()
                        # Deliberately not wrapped in no_grad, as in the
                        # original code.
                        test_loss, *_ = elbo_score_fn(
                            imp_encoder,
                            decoder,
                            score,
                            None,
                            test_X,
                            recon_type,
                            training=False,
                        )
                        logging.info(
                            "loss: {}, ssm_loss: {}, test_loss: {}".format(
                                loss.item(), ssm_loss.item(),
                                test_loss.item()))
                        z = imp_encoder(test_X)
                        tb_logger.add_histogram("z_X", z, global_step=step)
                    else:  # "spectral" or "stein"
                        imp_encoder.eval()
                        with torch.no_grad():
                            test_loss = elbo_kernel(imp_encoder, decoder,
                                                    estimator, test_X,
                                                    recon_type, 10)
                        logging.info("loss: {}, test_loss: {}".format(
                            loss.item(), test_loss.item()))

                    validation_losses.append(test_loss.item())
                    tb_logger.add_scalar("loss", loss, global_step=step)
                    tb_logger.add_scalar("test_loss",
                                         test_loss,
                                         global_step=step)

                    if algo in ["ssm", "ssm_fd"]:
                        tb_logger.add_scalar("ssm_loss",
                                             ssm_loss,
                                             global_step=step)

                if step % 500 == 0:
                    logging.info(
                        "Time Dur in this 500 iters: {}".format(time_dur))
                    time_dur = 0.0
                    with torch.no_grad():
                        z = torch.randn(100,
                                        self.config.model.z_dim,
                                        device=X.device)
                        decoder.eval()
                        image_shape = (
                            100,
                            self.config.data.channels,
                            self.config.data.image_size,
                            self.config.data.image_size,
                        )
                        # Stays None for datasets without a visualisation
                        # path (e.g. CIFAR10), so image logging is skipped
                        # instead of failing on an unbound name.
                        image_grid = data_grid = None
                        if dataset_name == "CELEBA":
                            samples, _ = decoder(z)
                            image_grid = make_grid(
                                samples.view(*image_shape), 10)
                            # Map from [-1, 1] back to [0, 1].
                            image_grid = torch.clamp(image_grid / 2.0 + 0.5,
                                                     0.0, 1.0)
                            data_grid = make_grid(X[:100], 10)
                            data_grid = torch.clamp(data_grid / 2.0 + 0.5,
                                                    0.0, 1.0)
                        elif dataset_name == "MNIST":
                            samples = torch.sigmoid(decoder(z))
                            image_grid = make_grid(
                                samples.view(*image_shape), 10)
                            data_grid = make_grid(X[:100], 10)

                        if image_grid is not None:
                            tb_logger.add_image("samples",
                                                image_grid,
                                                global_step=step)
                            tb_logger.add_image("data",
                                                data_grid,
                                                global_step=step)

                        if validation_losses:
                            validation_loss = sum(validation_losses) / len(
                                validation_losses)
                            # The window is only reset on improvement;
                            # otherwise it keeps growing.
                            if validation_loss < best_validation_loss:
                                best_validation_loss = validation_loss
                                validation_losses = []

                if (step + 1) % 10000 == 0:
                    # The decoder is always entry 1 of the checkpoint list.
                    if algo == "vae":
                        states = [
                            encoder.state_dict(),
                            decoder.state_dict(),
                            optimizer.state_dict(),
                        ]
                    elif algo in ["ssm", "ssm_fd"]:
                        states = [
                            imp_encoder.state_dict(),
                            decoder.state_dict(),
                            score.state_dict(),
                            opt_ae.state_dict(),
                            opt_score.state_dict(),
                        ]
                    else:  # "spectral" or "stein"
                        states = [
                            imp_encoder.state_dict(),
                            decoder.state_dict(),
                            optimizer.state_dict(),
                        ]
                    # Keep a numbered snapshot and refresh the rolling
                    # "latest" checkpoint.
                    torch.save(
                        states,
                        os.path.join(
                            self.args.log,
                            "checkpoint_{}0k.pth".format((step + 1) // 10000),
                        ),
                    )
                    torch.save(states,
                               os.path.join(self.args.log, "checkpoint.pth"))

                step += 1
                if step >= self.config.training.n_iters:
                    return 0
예제 #5
0
    def test_ais(self):
        """Run AIS on the MNIST test set with the decoder from ``checkpoint.pth``.

        Builds an ``AISLatentVariableModels`` object (``n_Ts=1000``) around
        the decoder's reconstruction energy, evaluates ``ais.ais`` on every
        test mini-batch of 32 images and prints the batch mean together with
        the size-weighted running mean.

        Raises:
            AssertionError: if ``self.config.data.dataset`` is not 'MNIST'.
            ValueError: if the dataset is unknown (only reachable when
                assertions are disabled).
        """
        dataset_name = self.config.data.dataset
        device = self.config.device
        assert dataset_name == 'MNIST', 'test_ais only supports MNIST'

        transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])

        if dataset_name == 'CIFAR10':
            test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=False, download=True,
                                   transform=transform)
        elif dataset_name == 'MNIST':
            test_dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'), train=False, download=True,
                                 transform=transform)
        elif dataset_name == 'CELEBA':
            dataset = ImageFolder(root=os.path.join(self.args.run, 'datasets', 'celeba'),
                                  transform=transforms.Compose([
                                      transforms.CenterCrop(140),
                                      transforms.Resize(self.config.data.image_size),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                  ]))
            num_items = len(dataset)
            indices = list(range(num_items))
            # Fixed-seed shuffle that restores the global NumPy RNG state.
            random_state = np.random.get_state()
            np.random.seed(2019)
            np.random.shuffle(indices)
            np.random.set_state(random_state)
            test_dataset = Subset(dataset, indices[int(0.8 * num_items):])
        else:
            raise ValueError('Unsupported dataset: {!r}'.format(dataset_name))

        # A fixed batch size of 32 is used instead of config.training.batch_size.
        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,
                                 num_workers=2)

        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'), map_location=device)
        # Only the decoder takes part in the energy below, and it is entry 1
        # of the checkpoint for every algorithm. Restoring it unconditionally
        # also covers algorithms the old if/elif chain did not list (e.g.
        # 'ssm_fd'), which left the decoder at its random initialisation.
        decoder = MLPDecoder(self.config).to(device)
        decoder.load_state_dict(states[1])

        recon_type = 'bernoulli' if dataset_name == 'MNIST' else 'gaussian'

        def recon_energy(X, z):
            """Negative log-likelihood of X under decoder(z), summed over dims 1-3."""
            # Compare strings by value; ``is`` on a literal tests identity
            # and only works by accident of interning.
            if recon_type == 'gaussian':
                mean_x, logstd_x = decoder(z)
                recon = (X - mean_x) ** 2 / (2. * (2 * logstd_x).exp()) + np.log(2. * np.pi) / 2. + logstd_x
                recon = recon.sum(dim=(1, 2, 3))
            elif recon_type == 'bernoulli':
                x_logits = decoder(z)
                recon = F.binary_cross_entropy_with_logits(input=x_logits, target=X, reduction='none')
                recon = recon.sum(dim=[1, 2, 3])
            return recon

        from evaluations.ais import AISLatentVariableModels
        ais = AISLatentVariableModels(recon_energy,
                                      self.config.model.z_dim,
                                      device, n_Ts=1000)

        total_l = 0.
        total_n = 0
        for X, _ in test_loader:
            X = X.to(device)
            if dataset_name == 'CELEBA':
                X = X + (torch.rand_like(X) - 0.5) / 128.
            elif dataset_name == 'MNIST':
                # Stochastic binarisation of the grey-scale pixels.
                eps = torch.rand_like(X)
                X = (eps <= X).float()

            ais_lb = ais.ais(X).mean().item()
            total_l += ais_lb * X.shape[0]
            total_n += X.shape[0]
            print('current ais lb: {}, mean ais lb: {}'.format(ais_lb, total_l / total_n))
예제 #6
0
    def train(self):
        """Run the training loop for the model selected by ``config.training.algo``.

        Builds the training and held-out splits for ``config.data.dataset``
        (``'CIFAR10'``, ``'MNIST'`` or ``'CELEBA'``), constructs the decoder plus
        the algorithm-specific networks and optimizers (``'vae'``, ``'ssm'``,
        ``'spectral'`` or ``'stein'``), optionally restores them from
        ``checkpoint.pth`` under ``args.log`` when ``args.resume_training`` is
        set, and then optimizes for at most ``config.training.n_epochs`` epochs
        or ``config.training.n_iters`` steps.

        Side effects:
            * Sets ``self.config.input_dim``.
            * Deletes (if present) and recreates the TensorBoard directory
              ``<args.run>/tensorboard/<args.doc>``.
            * Every 10 steps evaluates one held-out batch and logs the losses.
            * Every 500 steps logs a grid of decoder samples and a grid of data.
            * Every 10000 steps writes ``checkpoint_<k>0k.pth`` and
              ``checkpoint.pth`` into ``args.log``.

        Returns:
            ``0`` once ``n_iters`` steps have been taken; ``None`` if the epochs
            are exhausted first.

        Raises:
            ValueError: if the configured dataset or training algorithm is not
                one of the supported values listed above.
        """
        def shuffled_indices(num_items):
            # Shuffle with a fixed seed so the split is reproducible, then put the
            # global NumPy RNG back exactly as it was found.
            indices = list(range(num_items))
            random_state = np.random.get_state()
            np.random.seed(2019)
            np.random.shuffle(indices)
            np.random.set_state(random_state)
            return indices

        transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])

        # ------------------------------------------------------------------ data
        if self.config.data.dataset == 'CIFAR10':
            dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=True, download=True,
                              transform=transform)
            test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=False, download=True,
                                   transform=transform)
        elif self.config.data.dataset == 'MNIST':
            dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'), train=True, download=True,
                            transform=transform)
            num_items = len(dataset)
            indices = shuffled_indices(num_items)
            # First 80% for training, remaining 20% held out.
            split = int(num_items * 0.8)
            train_indices, test_indices = indices[:split], indices[split:]
            test_dataset = Subset(dataset, test_indices)
            dataset = Subset(dataset, train_indices)
        elif self.config.data.dataset == 'CELEBA':
            dataset = ImageFolder(root=os.path.join(self.args.run, 'datasets', 'celeba'),
                                  transform=transforms.Compose([
                                      transforms.CenterCrop(140),
                                      transforms.Resize(self.config.data.image_size),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                  ]))
            num_items = len(dataset)
            indices = shuffled_indices(num_items)
            # First 70% for training, the next 10% held out; the last 20% is unused here.
            train_end, test_end = int(num_items * 0.7), int(num_items * 0.8)
            train_indices, test_indices = indices[:train_end], indices[train_end:test_end]
            test_dataset = Subset(dataset, test_indices)
            dataset = Subset(dataset, train_indices)
        else:
            raise ValueError('unsupported dataset: {!r}'.format(self.config.data.dataset))

        dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True, num_workers=4)
        test_loader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                 num_workers=2)
        test_iter = iter(test_loader)
        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        # --------------------------------------------------------------- logging
        tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
        if os.path.exists(tb_path):
            # Start from a clean directory so old event files do not mix with this run.
            shutil.rmtree(tb_path)

        tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)

        # ---------------------------------------------------- models / optimizers
        is_mnist = self.config.data.dataset == 'MNIST'
        checkpoint_path = os.path.join(self.args.log, 'checkpoint.pth')

        decoder = MLPDecoder(self.config).to(self.config.device) if is_mnist \
            else Decoder(self.config).to(self.config.device)
        if self.config.training.algo == 'vae':
            encoder = MLPEncoder(self.config).to(self.config.device) if is_mnist \
                else Encoder(self.config).to(self.config.device)
            optimizer = self.get_optimizer(itertools.chain(encoder.parameters(), decoder.parameters()))
            if self.args.resume_training:
                # map_location lets a checkpoint saved on another device be restored here.
                states = torch.load(checkpoint_path, map_location=self.config.device)
                encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
                optimizer.load_state_dict(states[2])
        elif self.config.training.algo == 'ssm':
            score = MLPScore(self.config).to(self.config.device) if is_mnist else \
                Score(self.config).to(self.config.device)
            imp_encoder = MLPImplicitEncoder(self.config).to(self.config.device) if is_mnist \
                else ImplicitEncoder(self.config).to(self.config.device)

            opt_ae = optim.RMSprop(itertools.chain(decoder.parameters(), imp_encoder.parameters()),
                                   lr=self.config.optim.lr)
            opt_score = optim.RMSprop(score.parameters(), lr=self.config.optim.lr)
            if self.args.resume_training:
                states = torch.load(checkpoint_path, map_location=self.config.device)
                imp_encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
                score.load_state_dict(states[2])
                opt_ae.load_state_dict(states[3])
                opt_score.load_state_dict(states[4])
        elif self.config.training.algo in ['spectral', 'stein']:
            from models.kernel_score_estimators import SpectralScoreEstimator, SteinScoreEstimator
            imp_encoder = MLPImplicitEncoder(self.config).to(self.config.device) if is_mnist \
                else ImplicitEncoder(self.config).to(self.config.device)
            estimator = SpectralScoreEstimator() if self.config.training.algo == 'spectral' else SteinScoreEstimator()
            optimizer = self.get_optimizer(itertools.chain(imp_encoder.parameters(), decoder.parameters()))
            if self.args.resume_training:
                states = torch.load(checkpoint_path, map_location=self.config.device)
                imp_encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
                optimizer.load_state_dict(states[2])
        else:
            raise ValueError('unsupported training algo: {!r}'.format(self.config.training.algo))

        step = 0
        best_validation_loss = np.inf
        validation_losses = []
        recon_type = 'bernoulli' if is_mnist else 'gaussian'

        # ----------------------------------------------------------- training loop
        for _ in range(self.config.training.n_epochs):
            for X, _ in dataloader:
                decoder.train()
                X = X.to(self.config.device)
                if self.config.data.dataset == 'CELEBA':
                    # Add uniform noise in [-1/256, 1/256) to the images.
                    X = X + (torch.rand_like(X) - 0.5) / 128.
                elif self.config.data.dataset == 'MNIST':
                    # Stochastic binarization: pixel is 1 with probability equal to its intensity.
                    eps = torch.rand_like(X)
                    X = (eps <= X).float()

                if self.config.training.algo == 'vae':
                    encoder.train()
                    loss, *_ = elbo(encoder, decoder, X, recon_type)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                elif self.config.training.algo == 'ssm':
                    imp_encoder.train()
                    # opt_score is handed to elbo_ssm; presumably the score network is
                    # updated inside that call — confirm against elbo_ssm.
                    loss, ssm_loss, *_ = elbo_ssm(imp_encoder, decoder, score, opt_score, X, recon_type,
                                                  training=True, n_particles=self.config.model.n_particles)
                    opt_ae.zero_grad()
                    loss.backward()
                    opt_ae.step()
                elif self.config.training.algo in ['spectral', 'stein']:
                    imp_encoder.train()
                    loss = elbo_kernel(imp_encoder, decoder, estimator, X, recon_type,
                                       n_particles=self.config.model.n_particles)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # ------------------------------------------- held-out evaluation
                if step % 10 == 0:
                    try:
                        test_X, _ = next(test_iter)
                    except StopIteration:
                        # Held-out loader exhausted: restart it. Any other error
                        # (e.g. a failing worker) is allowed to propagate.
                        test_iter = iter(test_loader)
                        test_X, _ = next(test_iter)

                    test_X = test_X.to(self.config.device)
                    if self.config.data.dataset == 'CELEBA':
                        test_X = test_X + (torch.rand_like(test_X) - 0.5) / 128.
                    elif self.config.data.dataset == 'MNIST':
                        test_eps = torch.rand_like(test_X)
                        test_X = (test_eps <= test_X).float()

                    decoder.eval()
                    if self.config.training.algo == 'vae':
                        encoder.eval()
                        with torch.no_grad():
                            test_loss, *_ = elbo(encoder, decoder, test_X, recon_type)
                            logging.info("loss: {}, test_loss: {}".format(loss.item(), test_loss.item()))
                    elif self.config.training.algo == 'ssm':
                        imp_encoder.eval()
                        # NOTE(review): unlike the other branches this is not wrapped in
                        # no_grad — presumably elbo_ssm needs autograd even when
                        # training=False; confirm before changing.
                        test_loss, *_ = elbo_ssm(imp_encoder, decoder, score, None, test_X, recon_type, training=False)
                        logging.info("loss: {}, ssm_loss: {}, test_loss: {}".format(loss.item(), ssm_loss.item(),
                                                                                    test_loss.item()))
                        z = imp_encoder(test_X)
                        tb_logger.add_histogram('z_X', z, global_step=step)
                    elif self.config.training.algo in ['spectral', 'stein']:
                        imp_encoder.eval()
                        with torch.no_grad():
                            test_loss = elbo_kernel(imp_encoder, decoder, estimator, test_X, recon_type, 10)

                            logging.info("loss: {}, test_loss: {}".format(loss.item(), test_loss.item()))

                    validation_losses.append(test_loss.item())
                    tb_logger.add_scalar('loss', loss, global_step=step)
                    tb_logger.add_scalar('test_loss', test_loss, global_step=step)

                    if self.config.training.algo == 'ssm':
                        tb_logger.add_scalar('ssm_loss', ssm_loss, global_step=step)

                # ------------------------------------------------ sample logging
                if step % 500 == 0:
                    with torch.no_grad():
                        z = torch.randn(100, self.config.model.z_dim, device=X.device)
                        decoder.eval()
                        # NOTE(review): there is no CIFAR10 branch below, so image_grid and
                        # data_grid are unbound for that dataset — confirm whether CIFAR10
                        # training is meant to be supported.
                        if self.config.data.dataset == 'CELEBA':
                            samples, _ = decoder(z)
                            samples = samples.view(100, self.config.data.channels, self.config.data.image_size,
                                                   self.config.data.image_size)
                            image_grid = make_grid(samples, 10)
                            # Undo the Normalize((0.5,)*3, (0.5,)*3) applied to the data.
                            image_grid = torch.clamp(image_grid / 2. + 0.5, 0.0, 1.0)
                            data_grid = make_grid(X[:100], 10)
                            data_grid = torch.clamp(data_grid / 2. + 0.5, 0.0, 1.0)
                        elif self.config.data.dataset == 'MNIST':
                            samples_logits = decoder(z)
                            samples = torch.sigmoid(samples_logits)
                            samples = samples.view(100, self.config.data.channels, self.config.data.image_size,
                                                   self.config.data.image_size)
                            image_grid = make_grid(samples, 10)
                            data_grid = make_grid(X[:100], 10)

                        tb_logger.add_image('samples', image_grid, global_step=step)
                        tb_logger.add_image('data', data_grid, global_step=step)

                        if validation_losses:
                            validation_loss = sum(validation_losses) / len(validation_losses)
                            # The running window is only reset when the average improves;
                            # otherwise it keeps accumulating. No early stopping is done.
                            if validation_loss < best_validation_loss:
                                best_validation_loss = validation_loss
                                validation_losses = []

                # -------------------------------------------------- checkpointing
                if (step + 1) % 10000 == 0:
                    if self.config.training.algo == 'vae':
                        states = [
                            encoder.state_dict(),
                            decoder.state_dict(),
                            optimizer.state_dict()
                        ]
                    elif self.config.training.algo == 'ssm':
                        states = [
                            imp_encoder.state_dict(),
                            decoder.state_dict(),
                            score.state_dict(),
                            opt_ae.state_dict(),
                            opt_score.state_dict()
                        ]
                    elif self.config.training.algo in ['spectral', 'stein']:
                        states = [
                            imp_encoder.state_dict(),
                            decoder.state_dict(),
                            optimizer.state_dict()
                        ]
                    # One numbered snapshot plus the rolling checkpoint used for resuming.
                    torch.save(states,
                               os.path.join(self.args.log, 'checkpoint_{}0k.pth'.format((step + 1) // 10000)))
                    torch.save(states, checkpoint_path)

                step += 1
                if step >= self.config.training.n_iters:
                    return 0