예제 #1
0
    def _create_vae_model(self):
        """Build a variational autoencoder sized from the loaded data.

        The input width is taken from ``self.all_data``; the hidden and
        latent widths come from the technique parameters.
        """
        technique_params = self.parameters['TechniqueParameters']
        hidden_size = technique_params['IntermediateDimension']
        latent_size = technique_params['LatentDimension']
        input_size = self.all_data.shape[1]

        enc = Encoder(intermediate_dim=hidden_size, latent_dim=latent_size)
        enc.build((None, input_size))
        dec = Decoder(intermediate_dim=hidden_size, original_dim=input_size)
        dec.build((None, latent_size))

        return VariationalAutoencoder(enc, dec, latent_dim=latent_size)
예제 #2
0
    def test(self):
        """Load the decoder from the latest checkpoint and write a 10x10
        grid of prior samples to ``image_grid.png``.

        Raises:
            ValueError: if the configured dataset is not CELEBA or MNIST.
                (Previously an unknown dataset fell through and crashed
                with an UnboundLocalError on ``image_grid``.)
        """
        states = torch.load(
            os.path.join(self.args.log, "checkpoint.pth"),
            map_location=self.config.device,
        )
        # MNIST uses the MLP decoder; other datasets use the conv decoder.
        decoder = (MLPDecoder(self.config).to(self.config.device)
                   if self.config.data.dataset == "MNIST" else Decoder(
                       self.config).to(self.config.device))
        decoder.eval()
        decoder.load_state_dict(states[1])
        # 100 samples from the prior -> 10x10 grid.
        z = torch.randn(100,
                        self.config.model.z_dim,
                        device=self.config.device)
        if self.config.data.dataset == "CELEBA":
            samples, _ = decoder(z)
            samples = samples.view(
                100,
                self.config.data.channels,
                self.config.data.image_size,
                self.config.data.image_size,
            )
            image_grid = make_grid(samples, 10)
            # CELEBA inputs are normalised to [-1, 1]; map back to [0, 1].
            image_grid = torch.clamp(image_grid / 2.0 + 0.5, 0.0, 1.0)
        elif self.config.data.dataset == "MNIST":
            # The MNIST decoder emits Bernoulli logits.
            samples_logits = decoder(z)
            samples = torch.sigmoid(samples_logits)
            samples = samples.view(
                100,
                self.config.data.channels,
                self.config.data.image_size,
                self.config.data.image_size,
            )
            image_grid = make_grid(samples, 10)
        else:
            raise ValueError("Unsupported dataset for sampling: {}".format(
                self.config.data.dataset))

        save_image(image_grid, "image_grid.png")
예제 #3
0
    def __init__(self, dims):
        """
        M2 code replication from the paper
        'Semi-Supervised Learning with Deep Generative Models'
        (Kingma 2014) in PyTorch.

        The "Generative semi-supervised model" is a probabilistic
        model that incorporates label information in both
        inference and generation.

        Initialise a new generative model
        :param dims: dimensions of x, y, z and hidden layers.
        """
        x_dim, self.y_dim, z_dim, h_dim = dims
        super(DeepGenerativeModel, self).__init__([x_dim, z_dim, h_dim])

        # q(z|x,y): the label is concatenated onto the input.
        self.encoder = Encoder([x_dim + self.y_dim, h_dim, z_dim])
        # p(x|z,y): hidden layers mirror the encoder's, reversed.
        self.decoder = Decoder(
            [z_dim + self.y_dim, list(reversed(h_dim)), x_dim])
        # q(y|x)
        self.classifier = Classifier([x_dim, h_dim[0], self.y_dim])

        # Xavier-initialise every linear layer and zero its bias.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_normal_(module.weight.data)
                if module.bias is not None:
                    module.bias.data.zero_()
예제 #4
0
    def __init__(self, dims):
        """
        Auxiliary Deep Generative Models [Maaløe 2016]
        code replication. The ADGM introduces an additional
        latent variable 'a', which enables the model to fit
        more complex variational distributions.

        :param dims: dimensions of x, y, z, a and hidden layers.
        """
        x_dim, y_dim, z_dim, a_dim, h_dim = dims
        super(AuxiliaryDeepGenerativeModel,
              self).__init__([x_dim, y_dim, z_dim, h_dim])

        # q(a|x)
        self.aux_encoder = Encoder([x_dim, h_dim, a_dim])
        # p(a|x,y,z) — built from Encoder since it parameterises a
        # distribution over 'a'; hidden layers are reversed.
        self.aux_decoder = Encoder(
            [x_dim + z_dim + y_dim, list(reversed(h_dim)), a_dim])

        # q(y|a,x)
        self.classifier = Classifier([x_dim + a_dim, h_dim[0], y_dim])

        # q(z|a,y,x)
        self.encoder = Encoder([a_dim + y_dim + x_dim, h_dim, z_dim])
        # p(x|y,z)
        self.decoder = Decoder(
            [y_dim + z_dim, list(reversed(h_dim)), x_dim])
예제 #5
0
    def __init__(self, dims):
        """
        Ladder version of the Deep Generative Model.
        Uses a hierarchical representation that is
        trained end-to-end to give very nice disentangled
        representations.

        :param dims: dimensions of x, y, z layers and h layers
            note that len(z) == len(h).
        """
        x_dim, y_dim, z_dim, h_dim = dims
        super(LadderDeepGenerativeModel,
              self).__init__([x_dim, y_dim, z_dim[0], h_dim])

        # One ladder encoder per hidden layer: neurons[i-1] -> neurons[i]
        # with a stochastic layer of width z_dim[i-1] at each rung.
        neurons = [x_dim, *h_dim]
        encoder_layers = [
            LadderEncoder([neurons[i - 1], neurons[i], z_dim[i - 1]])
            for i in range(1, len(neurons))
        ]

        # Rebuild the topmost encoder so it also receives the label y.
        top = encoder_layers[-1]
        encoder_layers[-1] = LadderEncoder(
            [top.in_features + y_dim, top.out_features, top.z_dim])

        # Decoders mirror the encoder rungs, applied top-down.
        decoder_layers = [
            LadderDecoder([z_dim[i - 1], h_dim[i - 1], z_dim[i]])
            for i in range(1, len(h_dim))
        ][::-1]

        self.classifier = Classifier([x_dim, h_dim[0], y_dim])

        self.encoder = nn.ModuleList(encoder_layers)
        self.decoder = nn.ModuleList(decoder_layers)
        # Final reconstruction p(x|z1,y) from the bottom latent + label.
        self.reconstruction = Decoder([z_dim[0] + y_dim, h_dim, x_dim])

        # Xavier-initialise every linear layer and zero its bias.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_normal_(module.weight.data)
                if module.bias is not None:
                    module.bias.data.zero_()
예제 #6
0
    def train(self):
        """Run the training loop for the configured algorithm.

        Supports the "vae", "ssm"/"ssm_fd" and "spectral"/"stein"
        algorithms on the CIFAR10, MNIST or CELEBA datasets. Every 10
        steps a validation batch is evaluated and logged to tensorboard;
        every 500 steps sample/data image grids are logged; every 10000
        steps a checkpoint is written.

        Returns:
            0 once ``self.config.training.n_iters`` steps have run.

        Raises:
            ValueError: on an unsupported dataset or training algo.
                (Previously an unknown value fell through and crashed
                later with a NameError on e.g. ``dataset`` or ``loss``.)
        """
        transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])

        if self.config.data.dataset == "CIFAR10":
            dataset = CIFAR10(
                os.path.join(self.args.run, "datasets", "cifar10"),
                train=True,
                download=True,
                transform=transform,
            )
            test_dataset = CIFAR10(
                os.path.join(self.args.run, "datasets", "cifar10"),
                train=False,
                download=True,
                transform=transform,
            )
        elif self.config.data.dataset == "MNIST":
            dataset = MNIST(
                os.path.join(self.args.run, "datasets", "mnist"),
                train=True,
                download=True,
                transform=transform,
            )
            # Deterministic 80/20 split: shuffle with a fixed seed, then
            # restore the global numpy RNG state so callers are unaffected.
            num_items = len(dataset)
            indices = list(range(num_items))
            random_state = np.random.get_state()
            np.random.seed(2019)
            np.random.shuffle(indices)
            np.random.set_state(random_state)
            train_indices, test_indices = (
                indices[:int(num_items * 0.8)],
                indices[int(num_items * 0.8):],
            )
            test_dataset = Subset(dataset, test_indices)
            dataset = Subset(dataset, train_indices)

        elif self.config.data.dataset == "CELEBA":
            dataset = ImageFolder(
                # root="/raid/tianyu/ncsn/run/datasets/celeba/celeba",
                # NOTE(review): hard-coded machine-specific path — confirm.
                root="/home/kunxu/tmp/",
                transform=transforms.Compose([
                    transforms.CenterCrop(140),
                    transforms.Resize(self.config.data.image_size),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
            )
            # Deterministic split: first 70% train, 70-80% validation;
            # the last 20% is left untouched (used by test_fid).
            num_items = len(dataset)
            indices = list(range(num_items))
            random_state = np.random.get_state()
            np.random.seed(2019)
            np.random.shuffle(indices)
            np.random.set_state(random_state)
            train_indices, test_indices = (
                indices[:int(num_items * 0.7)],
                indices[int(num_items * 0.7):int(num_items * 0.8)],
            )
            test_dataset = Subset(dataset, test_indices)
            dataset = Subset(dataset, train_indices)
        else:
            raise ValueError(
                "Unsupported dataset: {}".format(self.config.data.dataset))

        dataloader = DataLoader(
            dataset,
            batch_size=self.config.training.batch_size,
            shuffle=True,
            num_workers=4,
        )
        test_loader = DataLoader(
            test_dataset,
            batch_size=self.config.training.batch_size,
            shuffle=True,
            num_workers=2,
        )
        test_iter = iter(test_loader)
        self.config.input_dim = (self.config.data.image_size**2 *
                                 self.config.data.channels)

        # Start with a clean tensorboard directory for this run.
        tb_path = os.path.join(self.args.run, "tensorboard", self.args.doc)
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)

        tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)
        # MNIST uses MLP networks; other datasets use conv networks.
        decoder = (MLPDecoder(self.config).to(self.config.device)
                   if self.config.data.dataset == "MNIST" else Decoder(
                       self.config).to(self.config.device))
        if self.config.training.algo == "vae":
            encoder = (MLPEncoder(self.config).to(self.config.device)
                       if self.config.data.dataset == "MNIST" else Encoder(
                           self.config).to(self.config.device))
            optimizer = self.get_optimizer(
                itertools.chain(encoder.parameters(), decoder.parameters()))
            if self.args.resume_training:
                states = torch.load(
                    os.path.join(self.args.log, "checkpoint.pth"))
                encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
                optimizer.load_state_dict(states[2])
        elif self.config.training.algo in ["ssm", "ssm_fd"]:
            score = (MLPScore(self.config).to(self.config.device)
                     if self.config.data.dataset == "MNIST" else Score(
                         self.config).to(self.config.device))
            imp_encoder = (MLPImplicitEncoder(self.config).to(
                self.config.device) if self.config.data.dataset == "MNIST" else
                           ImplicitEncoder(self.config).to(self.config.device))

            # Autoencoder and score network get separate optimizers.
            opt_ae = optim.RMSprop(
                itertools.chain(decoder.parameters(),
                                imp_encoder.parameters()),
                lr=self.config.optim.lr,
            )
            opt_score = optim.RMSprop(score.parameters(),
                                      lr=self.config.optim.lr)
            if self.args.resume_training:
                states = torch.load(
                    os.path.join(self.args.log, "checkpoint.pth"))
                imp_encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
                score.load_state_dict(states[2])
                opt_ae.load_state_dict(states[3])
                opt_score.load_state_dict(states[4])
        elif self.config.training.algo in ["spectral", "stein"]:
            from models.kernel_score_estimators import (
                SpectralScoreEstimator,
                SteinScoreEstimator,
            )

            imp_encoder = (MLPImplicitEncoder(self.config).to(
                self.config.device) if self.config.data.dataset == "MNIST" else
                           ImplicitEncoder(self.config).to(self.config.device))
            estimator = (SpectralScoreEstimator() if self.config.training.algo
                         == "spectral" else SteinScoreEstimator())
            optimizer = self.get_optimizer(
                itertools.chain(imp_encoder.parameters(),
                                decoder.parameters()))
            if self.args.resume_training:
                states = torch.load(
                    os.path.join(self.args.log, "checkpoint.pth"))
                imp_encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
                optimizer.load_state_dict(states[2])
        else:
            raise ValueError("Unsupported training algo: {}".format(
                self.config.training.algo))

        step = 0
        best_validation_loss = np.inf
        validation_losses = []
        recon_type = "bernoulli" if self.config.data.dataset == "MNIST" else "gaussian"
        time_dur = 0.0
        for _ in range(self.config.training.n_epochs):
            for _, (X, y) in enumerate(dataloader):
                decoder.train()
                X = X.to(self.config.device)
                if self.config.data.dataset == "CELEBA":
                    # Uniform dequantization noise for continuous likelihood.
                    X = X + (torch.rand_like(X) - 0.5) / 128.0
                elif self.config.data.dataset == "MNIST":
                    # Dynamic binarization.
                    eps = torch.rand_like(X)
                    X = (eps <= X).float()

                # Only elbo_ssm/elbo_ssm_fd report a run time; other algos
                # contribute nothing to the timer. (Fixes a NameError on
                # ``run_time`` for the vae/spectral/stein branches.)
                run_time = 0.0
                if self.config.training.algo == "vae":
                    encoder.train()
                    loss, *_ = elbo(encoder, decoder, X, recon_type)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                elif self.config.training.algo == "ssm":
                    imp_encoder.train()
                    loss, run_time, ssm_loss, *_ = elbo_ssm(
                        imp_encoder,
                        decoder,
                        score,
                        opt_score,
                        X,
                        recon_type,
                        training=True,
                        n_particles=self.config.model.n_particles,
                    )
                    opt_ae.zero_grad()
                    loss.backward()
                    opt_ae.step()
                elif self.config.training.algo == "ssm_fd":
                    imp_encoder.train()
                    loss, run_time, ssm_loss, *_ = elbo_ssm_fd(
                        imp_encoder,
                        decoder,
                        score,
                        opt_score,
                        X,
                        recon_type,
                        training=True,
                        n_particles=self.config.model.n_particles,
                    )
                    opt_ae.zero_grad()
                    loss.backward()
                    opt_ae.step()
                elif self.config.training.algo in ["spectral", "stein"]:
                    imp_encoder.train()
                    loss = elbo_kernel(
                        imp_encoder,
                        decoder,
                        estimator,
                        X,
                        recon_type,
                        n_particles=self.config.model.n_particles,
                    )
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                time_dur += run_time

                if step % 10 == 0:
                    # Draw the next validation batch, recycling the loader
                    # when it is exhausted.
                    try:
                        test_X, _ = next(test_iter)
                    except StopIteration:
                        test_iter = iter(test_loader)
                        test_X, _ = next(test_iter)

                    test_X = test_X.to(self.config.device)
                    if self.config.data.dataset == "CELEBA":
                        test_X = test_X + (torch.rand_like(test_X) -
                                           0.5) / 128.0
                    elif self.config.data.dataset == "MNIST":
                        test_eps = torch.rand_like(test_X)
                        test_X = (test_eps <= test_X).float()

                    decoder.eval()
                    if self.config.training.algo == "vae":
                        encoder.eval()
                        with torch.no_grad():
                            test_loss, *_ = elbo(encoder, decoder, test_X,
                                                 recon_type)
                            logging.info("loss: {}, test_loss: {}".format(
                                loss.item(), test_loss.item()))
                    elif self.config.training.algo == "ssm":
                        imp_encoder.eval()
                        test_loss, *_ = elbo_ssm(
                            imp_encoder,
                            decoder,
                            score,
                            None,
                            test_X,
                            recon_type,
                            training=False,
                        )
                        logging.info(
                            "loss: {}, ssm_loss: {}, test_loss: {}".format(
                                loss.item(), ssm_loss.item(),
                                test_loss.item()))
                        z = imp_encoder(test_X)
                        tb_logger.add_histogram("z_X", z, global_step=step)
                    elif self.config.training.algo == "ssm_fd":
                        imp_encoder.eval()
                        test_loss, *_ = elbo_ssm_fd(
                            imp_encoder,
                            decoder,
                            score,
                            None,
                            test_X,
                            recon_type,
                            training=False,
                        )
                        logging.info(
                            "loss: {}, ssm_loss: {}, test_loss: {}".format(
                                loss.item(), ssm_loss.item(),
                                test_loss.item()))
                        z = imp_encoder(test_X)
                        tb_logger.add_histogram("z_X", z, global_step=step)
                    elif self.config.training.algo in ["spectral", "stein"]:
                        imp_encoder.eval()
                        with torch.no_grad():
                            test_loss = elbo_kernel(imp_encoder, decoder,
                                                    estimator, test_X,
                                                    recon_type, 10)

                            logging.info("loss: {}, test_loss: {}".format(
                                loss.item(), test_loss.item()))

                    validation_losses.append(test_loss.item())
                    tb_logger.add_scalar("loss", loss, global_step=step)
                    tb_logger.add_scalar("test_loss",
                                         test_loss,
                                         global_step=step)

                    if self.config.training.algo in ["ssm", "ssm_fd"]:
                        tb_logger.add_scalar("ssm_loss",
                                             ssm_loss,
                                             global_step=step)

                if step % 500 == 0:
                    logging.info(
                        "Time Dur in this 500 iters: {}".format(time_dur))
                    time_dur = 0.0
                    with torch.no_grad():
                        # Log a 10x10 grid of prior samples next to data.
                        z = torch.randn(100,
                                        self.config.model.z_dim,
                                        device=X.device)
                        decoder.eval()
                        if self.config.data.dataset == "CELEBA":
                            samples, _ = decoder(z)
                            samples = samples.view(
                                100,
                                self.config.data.channels,
                                self.config.data.image_size,
                                self.config.data.image_size,
                            )
                            image_grid = make_grid(samples, 10)
                            # Map normalized [-1, 1] values back to [0, 1].
                            image_grid = torch.clamp(image_grid / 2.0 + 0.5,
                                                     0.0, 1.0)
                            data_grid = make_grid(X[:100], 10)
                            data_grid = torch.clamp(data_grid / 2.0 + 0.5, 0.0,
                                                    1.0)
                        elif self.config.data.dataset == "MNIST":
                            samples_logits = decoder(z)
                            samples = torch.sigmoid(samples_logits)
                            samples = samples.view(
                                100,
                                self.config.data.channels,
                                self.config.data.image_size,
                                self.config.data.image_size,
                            )
                            image_grid = make_grid(samples, 10)
                            data_grid = make_grid(X[:100], 10)

                        tb_logger.add_image("samples",
                                            image_grid,
                                            global_step=step)
                        tb_logger.add_image("data",
                                            data_grid,
                                            global_step=step)

                        if len(validation_losses) != 0:
                            validation_loss = sum(validation_losses) / len(
                                validation_losses)
                            if validation_loss < best_validation_loss:
                                best_validation_loss = validation_loss
                                validation_losses = []
                            # else:
                            #     return 0

                if (step + 1) % 10000 == 0:
                    if self.config.training.algo == "vae":
                        states = [
                            encoder.state_dict(),
                            decoder.state_dict(),
                            optimizer.state_dict(),
                        ]
                    elif self.config.training.algo in ["ssm", "ssm_fd"]:
                        states = [
                            imp_encoder.state_dict(),
                            decoder.state_dict(),
                            score.state_dict(),
                            opt_ae.state_dict(),
                            opt_score.state_dict(),
                        ]
                    elif self.config.training.algo in ["spectral", "stein"]:
                        states = [
                            imp_encoder.state_dict(),
                            decoder.state_dict(),
                            optimizer.state_dict(),
                        ]
                    # Keep both a per-milestone and a "latest" checkpoint.
                    torch.save(
                        states,
                        os.path.join(
                            self.args.log,
                            "checkpoint_{}0k.pth".format((step + 1) // 10000),
                        ),
                    )
                    torch.save(states,
                               os.path.join(self.args.log, "checkpoint.pth"))

                step += 1
                if step >= self.config.training.n_iters:
                    return 0
예제 #7
0
    def test_fid(self):
        """Compute FID statistics for CELEBA.

        Two modes, selected by the hard-coded flags below:
        - ``get_data_stats``: dump held-out data images and compute the
          reference FID statistics for them.
        - otherwise: load checkpoints, generate 10k decoder samples per
          checkpoint, compute their statistics and report the FID against
          the precomputed reference statistics.

        NOTE(review): ``get_data_stats`` / ``manual`` are hard-coded
        switches, and several paths are machine-specific — confirm.
        """
        assert self.config.data.dataset == "CELEBA"

        def _fresh_dir(path):
            # Recreate `path` empty, clearing stale results if present.
            if os.path.exists(path):
                shutil.rmtree(path)
            os.makedirs(path)

        transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])

        if self.config.data.dataset == "CIFAR10":
            test_dataset = CIFAR10(
                os.path.join(self.args.run, "datasets", "cifar10"),
                train=False,
                download=True,
                transform=transform,
            )
        elif self.config.data.dataset == "MNIST":
            test_dataset = MNIST(
                os.path.join(self.args.run, "datasets", "mnist"),
                train=False,
                download=True,
                transform=transform,
            )
        elif self.config.data.dataset == "CELEBA":
            dataset = ImageFolder(
                # root="/raid/tianyu/ncsn/run/datasets/celeba/celeba",
                root="/home/kunxu/tmp/",
                transform=transforms.Compose([
                    transforms.CenterCrop(140),
                    transforms.Resize(self.config.data.image_size),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
            )
            # Deterministic split matching train(): the last 20% is the
            # held-out FID test set.
            num_items = len(dataset)
            indices = list(range(num_items))
            random_state = np.random.get_state()
            np.random.seed(2019)
            np.random.shuffle(indices)
            np.random.set_state(random_state)
            test_indices = indices[int(0.8 * num_items):]
            test_dataset = Subset(dataset, test_indices)

        test_loader = DataLoader(
            test_dataset,
            batch_size=self.config.training.batch_size,
            shuffle=False,
            num_workers=2,
        )

        self.config.input_dim = (self.config.data.image_size**2 *
                                 self.config.data.channels)

        get_data_stats = False
        manual = False
        if get_data_stats:
            # Collect ~10k dequantized test images in [0, 1].
            data_images = []
            for _, (X, y) in enumerate(test_loader):
                X = X.to(self.config.device)
                X = X + (torch.rand_like(X) - 0.5) / 128.0
                data_images.extend(X / 2.0 + 0.5)
                if len(data_images) > 10000:
                    break

            if not os.path.exists(
                    os.path.join(self.args.run, "datasets", "celeba140_fid",
                                 "raw_images")):
                os.makedirs(
                    os.path.join(self.args.run, "datasets", "celeba140_fid",
                                 "raw_images"))
            logging.info("Saving data images")
            for i, image in enumerate(data_images):
                save_image(
                    image,
                    os.path.join(
                        self.args.run,
                        "datasets",
                        "celeba140_fid",
                        "raw_images",
                        "{}.png".format(i),
                    ),
                )
            logging.info("Images saved. Calculating fid statistics now")
            fid.calculate_data_statics(
                os.path.join(self.args.run, "datasets", "celeba140_fid",
                             "raw_images"),
                os.path.join(self.args.run, "datasets", "celeba140_fid"),
                50,
                True,
                2048,
            )

        else:
            if manual:
                states = torch.load(
                    os.path.join(self.args.log, "checkpoint_100k.pth"),
                    map_location=self.config.device,
                )
                decoder = Decoder(self.config).to(self.config.device)
                decoder.eval()
                # Only states[1] (the decoder) is needed for sampling; the
                # other networks are restored for completeness.
                if self.config.training.algo == "vae":
                    encoder = Encoder(self.config).to(self.config.device)
                    encoder.load_state_dict(states[0])
                    decoder.load_state_dict(states[1])
                elif self.config.training.algo in ["ssm", "ssm_fd"]:
                    score = Score(self.config).to(self.config.device)
                    imp_encoder = ImplicitEncoder(self.config).to(
                        self.config.device)
                    imp_encoder.load_state_dict(states[0])
                    decoder.load_state_dict(states[1])
                    score.load_state_dict(states[2])
                elif self.config.training.algo in ["spectral", "stein"]:
                    from models.kernel_score_estimators import (
                        SpectralScoreEstimator,
                        SteinScoreEstimator,
                    )

                    imp_encoder = ImplicitEncoder(self.config).to(
                        self.config.device)
                    imp_encoder.load_state_dict(states[0])
                    decoder.load_state_dict(states[1])

                # 100 batches of 100 prior samples = 10k images in [0, 1].
                all_samples = []
                logging.info("Generating samples")
                for _ in range(100):
                    with torch.no_grad():
                        z = torch.randn(100,
                                        self.config.model.z_dim,
                                        device=self.config.device)
                        samples, _ = decoder(z)
                        samples = samples.view(
                            100,
                            self.config.data.channels,
                            self.config.data.image_size,
                            self.config.data.image_size,
                        )
                        all_samples.extend(samples / 2.0 + 0.5)

                if not os.path.exists(
                        os.path.join(self.args.log, "samples", "raw_images")):
                    os.makedirs(
                        os.path.join(self.args.log, "samples", "raw_images"))
                logging.info("Images generated. Saving images")
                for i, image in enumerate(all_samples):
                    save_image(
                        image,
                        os.path.join(self.args.log, "samples", "raw_images",
                                     "{}.png".format(i)),
                    )
                logging.info("Generating fid statistics")
                fid.calculate_data_statics(
                    os.path.join(self.args.log, "samples", "raw_images"),
                    os.path.join(self.args.log, "samples"),
                    50,
                    True,
                    2048,
                )
                logging.info("Statistics generated.")
            else:
                # `ckpt_iter` indexes checkpoints in 10k-step units
                # (renamed from `iter`, which shadowed the builtin).
                for ckpt_iter in range(10, 11):
                    states = torch.load(
                        os.path.join(self.args.log,
                                     "checkpoint_{}0k.pth".format(ckpt_iter)),
                        map_location=self.config.device,
                    )
                    decoder = Decoder(self.config).to(self.config.device)
                    decoder.eval()
                    if self.config.training.algo == "vae":
                        encoder = Encoder(self.config).to(self.config.device)
                        encoder.load_state_dict(states[0])
                        decoder.load_state_dict(states[1])
                    elif self.config.training.algo in ["ssm", "ssm_fd"]:
                        score = Score(self.config).to(self.config.device)
                        imp_encoder = ImplicitEncoder(self.config).to(
                            self.config.device)
                        imp_encoder.load_state_dict(states[0])
                        decoder.load_state_dict(states[1])
                        score.load_state_dict(states[2])
                    elif self.config.training.algo in ["spectral", "stein"]:
                        from models.kernel_score_estimators import (
                            SpectralScoreEstimator,
                            SteinScoreEstimator,
                        )

                        imp_encoder = ImplicitEncoder(self.config).to(
                            self.config.device)
                        imp_encoder.load_state_dict(states[0])
                        decoder.load_state_dict(states[1])

                    # 100 batches of 100 prior samples = 10k images.
                    all_samples = []
                    logging.info("Generating samples")
                    for _ in range(100):
                        with torch.no_grad():
                            z = torch.randn(100,
                                            self.config.model.z_dim,
                                            device=self.config.device)
                            samples, _ = decoder(z)
                            samples = samples.view(
                                100,
                                self.config.data.channels,
                                self.config.data.image_size,
                                self.config.data.image_size,
                            )
                            all_samples.extend(samples / 2.0 + 0.5)

                    _fresh_dir(
                        os.path.join(self.args.log, "samples",
                                     "raw_images_{}0k".format(ckpt_iter)))
                    _fresh_dir(
                        os.path.join(self.args.log, "samples",
                                     "statistics_{}0k".format(ckpt_iter)))

                    logging.info("Images generated. Saving images")
                    for i, image in enumerate(all_samples):
                        save_image(
                            image,
                            os.path.join(
                                self.args.log,
                                "samples",
                                "raw_images_{}0k".format(ckpt_iter),
                                "{}.png".format(i),
                            ),
                        )
                    logging.info("Generating fid statistics")
                    fid.calculate_data_statics(
                        os.path.join(self.args.log, "samples",
                                     "raw_images_{}0k".format(ckpt_iter)),
                        os.path.join(self.args.log, "samples",
                                     "statistics_{}0k".format(ckpt_iter)),
                        50,
                        True,
                        2048,
                    )
                    logging.info("Statistics generated.")
                    fid_number = fid.calculate_fid_given_paths(
                        [
                            "run/datasets/celeba140_fid/celeba_test.npz",
                            os.path.join(
                                self.args.log,
                                "samples",
                                "statistics_{}0k".format(ckpt_iter),
                                "celeba_test.npz",
                            ),
                        ],
                        50,
                        True,
                        2048,
                    )
                    logging.info("Number of iters: {}0k, FID: {}".format(
                        ckpt_iter, fid_number))
예제 #8
0
def main(data_directory, num_epochs, batch_size, device="cuda"):
    """Train a VAE on audio clips with a siamese feature-matching loss.

    Reconstruction error is measured in the feature space of a siamese
    encoder (same architecture as the VAE encoder) rather than on raw
    waveforms; the siamese network is optimized jointly with the
    encoder/decoder.

    Args:
        data_directory: Root directory scanned by ``get_audio_dataset``.
        num_epochs: Number of passes over the dataset.
        batch_size: Mini-batch size for the DataLoader.
        device: Torch device string. Defaults to ``"cuda"`` to preserve the
            original hard-coded behaviour.
    """
    import os  # local import: only needed to prepare output directories

    # 2 s at 16 kHz -> 32000 samples; keep in sync with the dataset settings.
    sample_rate = 16000
    max_seconds = 2
    num_samples = sample_rate * max_seconds

    # Create output directories up front so the first save cannot fail.
    os.makedirs("outputs", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    dataset = get_audio_dataset(data_directory,
                                max_length_in_seconds=max_seconds,
                                pad_and_truncate=True)

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             num_workers=8)

    train_dataloader_len = len(dataloader)

    encoder = Encoder(64, 1, 100).to(device)
    decoder = Decoder(64, 1, 100).to(device)

    # Reuse the Encoder architecture as a siamese feature extractor; only
    # its convolutional trunk (`main`) and the `mu` head are used.
    siamese = Encoder(64, 1, 100)
    siamese_main = siamese.main.to(device)
    siamese_output = siamese.mu.to(device)

    optimizer_encoder = torch.optim.Adam(encoder.parameters(), lr=1e-4)
    optimizer_decoder = torch.optim.Adam(decoder.parameters(), lr=1e-4)
    optimizer_siamese = torch.optim.Adam(siamese.parameters(), lr=1e-4)

    criterion = torch.nn.MSELoss()

    for epoch in range(num_epochs):
        for sample_idx, (audio, _) in enumerate(dataloader):
            decoder.zero_grad()
            encoder.zero_grad()
            siamese.zero_grad()

            audio = audio.to(device)

            z, mu, logvar = encoder(audio)
            decoded = decoder(z)
            # Trim decoder output to the fixed clip length.
            decoded = decoded.narrow(2, 0, num_samples)

            # Feature-matching loss: compare siamese embeddings of the
            # reconstruction against those of the real audio.
            hidden_fake_main = siamese_main(decoded)
            hidden_fake_output = siamese_output(hidden_fake_main)

            hidden_real_main = siamese_main(audio)
            hidden_real_output = siamese_output(hidden_real_main)

            err = criterion(hidden_fake_output, hidden_real_output)

            # Standard VAE KL divergence against a unit Gaussian prior.
            KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

            loss = err + KLD
            loss.backward()

            optimizer_encoder.step()
            optimizer_decoder.step()
            optimizer_siamese.step()

            print(
                f"{epoch:06d}-[{sample_idx + 1}/{train_dataloader_len}]: loss {loss.mean().item()}"
            )

            # Periodically sample from the prior and dump a waveform.
            if sample_idx % 100 == 0:
                with torch.no_grad():
                    fake_noise = torch.randn(1, 100, 1).to(device)
                    output_gen = decoder(fake_noise).narrow(2, 0,
                                                            num_samples).to("cpu")
                    torchaudio.save(
                        f"outputs/decoder_output_{epoch:06d}_{sample_idx:06d}.wav",
                        output_gen[0],
                        sample_rate,
                    )
        # NOTE: decoder keeps its legacy "netD" checkpoint filename so that
        # existing loading code continues to work.
        torch.save(encoder.state_dict(),
                   "%s/encoder_epoch_%d.pth" % ("checkpoints", epoch))
        torch.save(decoder.state_dict(),
                   "%s/netD_epoch_%d.pth" % ("checkpoints", epoch))
        torch.save(siamese.state_dict(),
                   "%s/siamese_epoch_%d.pth" % ("checkpoints", epoch))
예제 #9
0
    def train(self):
        """Run the training loop for the configured dataset and algorithm.

        Supported algorithms (``self.config.training.algo``): ``'vae'``
        (explicit encoder, ELBO), ``'ssm'`` (implicit encoder, sliced score
        matching), and ``'spectral'``/``'stein'`` (implicit encoder, kernel
        score estimators).  MNIST uses MLP networks; other datasets use the
        convolutional variants.  Evaluates on a held-out batch every 10 steps,
        logs TensorBoard sample grids every 500 steps, checkpoints every 10k
        steps, and returns 0 once ``n_iters`` steps have run.
        """
        transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])

        if self.config.data.dataset == 'CIFAR10':
            dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=True, download=True,
                              transform=transform)
            test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=False, download=True,
                                   transform=transform)
        elif self.config.data.dataset == 'MNIST':
            dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'), train=True, download=True,
                            transform=transform)
            # Deterministic 80/20 split; the global RNG state is saved and
            # restored so the fixed seed does not leak into later sampling.
            num_items = len(dataset)
            indices = list(range(num_items))
            random_state = np.random.get_state()
            np.random.seed(2019)
            np.random.shuffle(indices)
            np.random.set_state(random_state)
            train_indices, test_indices = indices[:int(num_items * 0.8)], indices[int(num_items * 0.8):]
            test_dataset = Subset(dataset, test_indices)
            dataset = Subset(dataset, train_indices)

        elif self.config.data.dataset == 'CELEBA':
            dataset = ImageFolder(root=os.path.join(self.args.run, 'datasets', 'celeba'),
                                  transform=transforms.Compose([
                                      transforms.CenterCrop(140),
                                      transforms.Resize(self.config.data.image_size),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                  ]))
            # Deterministic 70/10 split; the final 20% is reserved (used by
            # the FID evaluation elsewhere in this class).
            num_items = len(dataset)
            indices = list(range(num_items))
            random_state = np.random.get_state()
            np.random.seed(2019)
            np.random.shuffle(indices)
            np.random.set_state(random_state)
            train_indices, test_indices = indices[:int(num_items * 0.7)], indices[
                                                                          int(num_items * 0.7):int(num_items * 0.8)]
            test_dataset = Subset(dataset, test_indices)
            dataset = Subset(dataset, train_indices)

        dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True, num_workers=4)
        test_loader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                 num_workers=2)
        test_iter = iter(test_loader)
        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        # Start each run with a fresh TensorBoard directory.
        tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)

        tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)
        decoder = MLPDecoder(self.config).to(self.config.device) if self.config.data.dataset == 'MNIST' \
            else Decoder(self.config).to(self.config.device)
        if self.config.training.algo == 'vae':
            encoder = MLPEncoder(self.config).to(self.config.device) if self.config.data.dataset == 'MNIST' \
                else Encoder(self.config).to(self.config.device)
            optimizer = self.get_optimizer(itertools.chain(encoder.parameters(), decoder.parameters()))
            if self.args.resume_training:
                states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
                encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
                optimizer.load_state_dict(states[2])
        elif self.config.training.algo == 'ssm':
            score = MLPScore(self.config).to(self.config.device) if self.config.data.dataset == 'MNIST' else \
                Score(self.config).to(self.config.device)
            imp_encoder = MLPImplicitEncoder(self.config).to(self.config.device) if self.config.data.dataset == 'MNIST' \
                else ImplicitEncoder(self.config).to(self.config.device)

            opt_ae = optim.RMSprop(itertools.chain(decoder.parameters(), imp_encoder.parameters()),
                                   lr=self.config.optim.lr)
            opt_score = optim.RMSprop(score.parameters(), lr=self.config.optim.lr)
            if self.args.resume_training:
                states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
                imp_encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
                score.load_state_dict(states[2])
                opt_ae.load_state_dict(states[3])
                opt_score.load_state_dict(states[4])
        elif self.config.training.algo in ['spectral', 'stein']:
            from models.kernel_score_estimators import SpectralScoreEstimator, SteinScoreEstimator
            imp_encoder = MLPImplicitEncoder(self.config).to(self.config.device) if self.config.data.dataset == 'MNIST' \
                else ImplicitEncoder(self.config).to(self.config.device)
            estimator = SpectralScoreEstimator() if self.config.training.algo == 'spectral' else SteinScoreEstimator()
            optimizer = self.get_optimizer(itertools.chain(imp_encoder.parameters(), decoder.parameters()))
            if self.args.resume_training:
                states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
                imp_encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
                optimizer.load_state_dict(states[2])

        step = 0
        best_validation_loss = np.inf
        validation_losses = []
        # MNIST pixels get binarized below -> Bernoulli likelihood; otherwise Gaussian.
        recon_type = 'bernoulli' if self.config.data.dataset == 'MNIST' else 'gaussian'

        for _ in range(self.config.training.n_epochs):
            for X, y in dataloader:
                decoder.train()
                X = X.to(self.config.device)
                # Dequantization noise (CELEBA) / stochastic binarization (MNIST).
                if self.config.data.dataset == 'CELEBA':
                    X = X + (torch.rand_like(X) - 0.5) / 128.
                elif self.config.data.dataset == 'MNIST':
                    eps = torch.rand_like(X)
                    X = (eps <= X).float()

                if self.config.training.algo == 'vae':
                    encoder.train()
                    loss, *_ = elbo(encoder, decoder, X, recon_type)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                elif self.config.training.algo == 'ssm':
                    imp_encoder.train()
                    loss, ssm_loss, *_ = elbo_ssm(imp_encoder, decoder, score, opt_score, X, recon_type,
                                                  training=True, n_particles=self.config.model.n_particles)
                    opt_ae.zero_grad()
                    loss.backward()
                    opt_ae.step()
                elif self.config.training.algo in ['spectral', 'stein']:
                    imp_encoder.train()
                    loss = elbo_kernel(imp_encoder, decoder, estimator, X, recon_type,
                                       n_particles=self.config.model.n_particles)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # Periodic evaluation on one held-out batch.
                if step % 10 == 0:
                    try:
                        test_X, _ = next(test_iter)
                    except StopIteration:
                        # Test loader exhausted: restart it.  (Was a bare
                        # `except:`, which also swallowed KeyboardInterrupt.)
                        test_iter = iter(test_loader)
                        test_X, _ = next(test_iter)

                    test_X = test_X.to(self.config.device)
                    if self.config.data.dataset == 'CELEBA':
                        test_X = test_X + (torch.rand_like(test_X) - 0.5) / 128.
                    elif self.config.data.dataset == 'MNIST':
                        test_eps = torch.rand_like(test_X)
                        test_X = (test_eps <= test_X).float()

                    decoder.eval()
                    if self.config.training.algo == 'vae':
                        encoder.eval()
                        with torch.no_grad():
                            test_loss, *_ = elbo(encoder, decoder, test_X, recon_type)
                            logging.info("loss: {}, test_loss: {}".format(loss.item(), test_loss.item()))
                    elif self.config.training.algo == 'ssm':
                        imp_encoder.eval()
                        test_loss, *_ = elbo_ssm(imp_encoder, decoder, score, None, test_X, recon_type, training=False)
                        logging.info("loss: {}, ssm_loss: {}, test_loss: {}".format(loss.item(), ssm_loss.item(),
                                                                                    test_loss.item()))
                        z = imp_encoder(test_X)
                        tb_logger.add_histogram('z_X', z, global_step=step)
                    elif self.config.training.algo in ['spectral', 'stein']:
                        imp_encoder.eval()
                        with torch.no_grad():
                            test_loss = elbo_kernel(imp_encoder, decoder, estimator, test_X, recon_type, 10)

                            logging.info("loss: {}, test_loss: {}".format(loss.item(), test_loss.item()))

                    validation_losses.append(test_loss.item())
                    tb_logger.add_scalar('loss', loss, global_step=step)
                    tb_logger.add_scalar('test_loss', test_loss, global_step=step)

                    if self.config.training.algo == 'ssm':
                        tb_logger.add_scalar('ssm_loss', ssm_loss, global_step=step)

                # Periodic qualitative check: decode prior samples to a grid.
                if step % 500 == 0:
                    with torch.no_grad():
                        z = torch.randn(100, self.config.model.z_dim, device=X.device)
                        decoder.eval()
                        if self.config.data.dataset == 'CELEBA':
                            # Decoder returns a (mean, logvar) pair here;
                            # images are in [-1, 1], rescale to [0, 1].
                            samples, _ = decoder(z)
                            samples = samples.view(100, self.config.data.channels, self.config.data.image_size,
                                                   self.config.data.image_size)
                            image_grid = make_grid(samples, 10)
                            image_grid = torch.clamp(image_grid / 2. + 0.5, 0.0, 1.0)
                            data_grid = make_grid(X[:100], 10)
                            data_grid = torch.clamp(data_grid / 2. + 0.5, 0.0, 1.0)
                        elif self.config.data.dataset == 'MNIST':
                            # MNIST decoder emits Bernoulli logits.
                            samples_logits = decoder(z)
                            samples = torch.sigmoid(samples_logits)
                            samples = samples.view(100, self.config.data.channels, self.config.data.image_size,
                                                   self.config.data.image_size)
                            image_grid = make_grid(samples, 10)
                            data_grid = make_grid(X[:100], 10)

                        tb_logger.add_image('samples', image_grid, global_step=step)
                        tb_logger.add_image('data', data_grid, global_step=step)

                        # Track the best running-average validation loss.
                        if len(validation_losses) != 0:
                            validation_loss = sum(validation_losses) / len(validation_losses)
                            if validation_loss < best_validation_loss:
                                best_validation_loss = validation_loss
                                validation_losses = []
                            # else:
                            #     return 0

                # Checkpoint every 10k steps: one per milestone plus a
                # rolling "latest".
                if (step + 1) % 10000 == 0:
                    if self.config.training.algo == 'vae':
                        states = [
                            encoder.state_dict(),
                            decoder.state_dict(),
                            optimizer.state_dict()
                        ]
                    elif self.config.training.algo == 'ssm':
                        states = [
                            imp_encoder.state_dict(),
                            decoder.state_dict(),
                            score.state_dict(),
                            opt_ae.state_dict(),
                            opt_score.state_dict()
                        ]
                    elif self.config.training.algo in ['spectral', 'stein']:
                        states = [
                            imp_encoder.state_dict(),
                            decoder.state_dict(),
                            optimizer.state_dict()
                        ]
                    torch.save(states,
                               os.path.join(self.args.log, 'checkpoint_{}0k.pth'.format((step + 1) // 10000)))
                    torch.save(states, os.path.join(self.args.log, 'checkpoint.pth'))

                step += 1
                if step >= self.config.training.n_iters:
                    return 0
예제 #10
0
    def test_fid(self):
        """Compute FID scores for decoder samples against CELEBA test statistics.

        Behaviour is controlled by two local flags:
          * ``get_data_stats`` — one-off: dump >10k dequantized test images
            and compute the reference Inception statistics;
          * ``manual`` — evaluate only the rolling ``checkpoint.pth`` instead
            of the ten ``checkpoint_{n}0k.pth`` milestones.

        Only CELEBA is supported (asserted below); samples and statistics are
        written under ``<log>/samples``.
        """
        assert self.config.data.dataset == 'CELEBA'
        transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])

        if self.config.data.dataset == 'CIFAR10':
            test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=False, download=True,
                                   transform=transform)
        elif self.config.data.dataset == 'MNIST':
            test_dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'), train=False, download=True,
                                 transform=transform)
        elif self.config.data.dataset == 'CELEBA':
            dataset = ImageFolder(root=os.path.join(self.args.run, 'datasets', 'celeba'),
                                  transform=transforms.Compose([
                                      transforms.CenterCrop(140),
                                      transforms.Resize(self.config.data.image_size),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                  ]))
            # Same seeded shuffle as training: the last 20% is the test split.
            num_items = len(dataset)
            indices = list(range(num_items))
            random_state = np.random.get_state()
            np.random.seed(2019)
            np.random.shuffle(indices)
            np.random.set_state(random_state)
            test_indices = indices[int(0.8 * num_items):]
            test_dataset = Subset(dataset, test_indices)

        test_loader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=False,
                                 num_workers=2)

        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        get_data_stats = False
        manual = False
        if get_data_stats:
            # Collect dequantized test images rescaled from [-1, 1] to [0, 1].
            data_images = []
            for X, y in test_loader:
                X = X.to(self.config.device)
                X = X + (torch.rand_like(X) - 0.5) / 128.
                data_images.extend(X / 2. + 0.5)
                if len(data_images) > 10000:
                    break

            if not os.path.exists(os.path.join(self.args.run, 'datasets', 'celeba140_fid', 'raw_images')):
                os.makedirs(os.path.join(self.args.run, 'datasets', 'celeba140_fid', 'raw_images'))
            logging.info("Saving data images")
            for i, image in enumerate(data_images):
                save_image(image,
                           os.path.join(self.args.run, 'datasets', 'celeba140_fid', 'raw_images', '{}.png'.format(i)))
            logging.info("Images saved. Calculating fid statistics now")
            fid.calculate_data_statics(os.path.join(self.args.run, 'datasets', 'celeba140_fid', 'raw_images'),
                                       os.path.join(self.args.run, 'datasets', 'celeba140_fid'), 50, True, 2048)


        else:
            if manual:
                # Evaluate only the latest (rolling) checkpoint.
                states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'), map_location=self.config.device)
                decoder = Decoder(self.config).to(self.config.device)
                decoder.eval()
                if self.config.training.algo == 'vae':
                    encoder = Encoder(self.config).to(self.config.device)
                    encoder.load_state_dict(states[0])
                    decoder.load_state_dict(states[1])
                elif self.config.training.algo == 'ssm':
                    score = Score(self.config).to(self.config.device)
                    imp_encoder = ImplicitEncoder(self.config).to(self.config.device)
                    imp_encoder.load_state_dict(states[0])
                    decoder.load_state_dict(states[1])
                    score.load_state_dict(states[2])
                elif self.config.training.algo in ['spectral', 'stein']:
                    from models.kernel_score_estimators import SpectralScoreEstimator, SteinScoreEstimator
                    imp_encoder = ImplicitEncoder(self.config).to(self.config.device)
                    imp_encoder.load_state_dict(states[0])
                    decoder.load_state_dict(states[1])

                # 100 batches of 100 prior samples -> 10k images in [0, 1].
                all_samples = []
                logging.info("Generating samples")
                for i in range(100):
                    with torch.no_grad():
                        z = torch.randn(100, self.config.model.z_dim, device=self.config.device)
                        samples, _ = decoder(z)
                        samples = samples.view(100, self.config.data.channels, self.config.data.image_size,
                                               self.config.data.image_size)
                        all_samples.extend(samples / 2. + 0.5)

                if not os.path.exists(os.path.join(self.args.log, 'samples', 'raw_images')):
                    os.makedirs(os.path.join(self.args.log, 'samples', 'raw_images'))
                logging.info("Images generated. Saving images")
                for i, image in enumerate(all_samples):
                    save_image(image, os.path.join(self.args.log, 'samples', 'raw_images', '{}.png'.format(i)))
                logging.info("Generating fid statistics")
                fid.calculate_data_statics(os.path.join(self.args.log, 'samples', 'raw_images'),
                                           os.path.join(self.args.log, 'samples'), 50, True, 2048)
                logging.info("Statistics generated.")
            else:
                # Evaluate each 10k-step milestone checkpoint in turn.
                # (Loop variable renamed from `iter`, which shadowed the builtin.)
                for ckpt_iter in range(1, 11):
                    states = torch.load(os.path.join(self.args.log, 'checkpoint_{}0k.pth'.format(ckpt_iter)),
                                        map_location=self.config.device)
                    decoder = Decoder(self.config).to(self.config.device)
                    decoder.eval()
                    if self.config.training.algo == 'vae':
                        encoder = Encoder(self.config).to(self.config.device)
                        encoder.load_state_dict(states[0])
                        decoder.load_state_dict(states[1])
                    elif self.config.training.algo == 'ssm':
                        score = Score(self.config).to(self.config.device)
                        imp_encoder = ImplicitEncoder(self.config).to(self.config.device)
                        imp_encoder.load_state_dict(states[0])
                        decoder.load_state_dict(states[1])
                        score.load_state_dict(states[2])
                    elif self.config.training.algo in ['spectral', 'stein']:
                        from models.kernel_score_estimators import SpectralScoreEstimator, SteinScoreEstimator
                        imp_encoder = ImplicitEncoder(self.config).to(self.config.device)
                        imp_encoder.load_state_dict(states[0])
                        decoder.load_state_dict(states[1])

                    all_samples = []
                    logging.info("Generating samples")
                    for i in range(100):
                        with torch.no_grad():
                            z = torch.randn(100, self.config.model.z_dim, device=self.config.device)
                            samples, _ = decoder(z)
                            samples = samples.view(100, self.config.data.channels, self.config.data.image_size,
                                                   self.config.data.image_size)
                            all_samples.extend(samples / 2. + 0.5)

                    # Reset the per-milestone output directories to empty.
                    raw_dir = os.path.join(self.args.log, 'samples', 'raw_images_{}0k'.format(ckpt_iter))
                    stats_dir = os.path.join(self.args.log, 'samples', 'statistics_{}0k'.format(ckpt_iter))
                    for out_dir in (raw_dir, stats_dir):
                        if os.path.exists(out_dir):
                            shutil.rmtree(out_dir)
                        os.makedirs(out_dir)

                    logging.info("Images generated. Saving images")
                    for i, image in enumerate(all_samples):
                        save_image(image, os.path.join(raw_dir, '{}.png'.format(i)))
                    logging.info("Generating fid statistics")
                    fid.calculate_data_statics(raw_dir, stats_dir, 50, True, 2048)
                    logging.info("Statistics generated.")
                    fid_number = fid.calculate_fid_given_paths([
                        'run/datasets/celeba140_fid/celeba_test.npz',
                        os.path.join(stats_dir, 'celeba_test.npz')]
                        , 50, True, 2048)
                    logging.info("Number of iters: {}0k, FID: {}".format(ckpt_iter, fid_number))