def get_models(config):
    # Build the conditional GAN pair and move both networks to the target device.
    # (The trailing 64 is presumably the discriminator's base feature-map width.)
    disc = Discriminator(config['channels'], config['load_size'],
                         config['num_classes'], 64).to(config['device'])
    gen = Generator(config['z_dim'], config['channels'], config['load_size'],
                    config['num_classes'],
                    config['embed_size']).to(config['device'])
    initialise_weights(gen)
    initialise_weights(disc)

    return gen, disc
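A minimal usage sketch, assuming a config dict that carries the keys the function reads (all values here are illustrative, not from the original):

config = {'channels': 1, 'load_size': 64, 'num_classes': 10,
          'z_dim': 100, 'embed_size': 100, 'device': 'cuda'}
gen, disc = get_models(config)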
Example #3
train_loader_1 = DataLoader(mnist_train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
train_loader_2 = DataLoader(kannada_train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
val_loader = DataLoader(kannada_val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

num_batch = min(len(train_loader_1), len(train_loader_2))

dim_z = 512

# Modules
shared_latent = GaussianVAE2D(dim_z, dim_z, kernel_size=1, stride=1).to(device)
encoder_1 = Encoder().to(device)
encoder_2 = Encoder().to(device)
decoder_1 = Decoder().to(device)
decoder_2 = Decoder().to(device)
discriminator_1 = Discriminator().to(device)
discriminator_2 = Discriminator().to(device)

criterion_discr = nn.BCELoss(reduction='sum')
criterion_class = nn.CrossEntropyLoss(reduction='sum')

fixed_noise = torch.randn(args.batch_size, dim_z, 1, 1, device=device)
real_label = 1
fake_label = 0
label_noise = 0.1

# setup optimizer
dis_params = list(discriminator_1.parameters()) + list(discriminator_2.parameters())
optimizerD = optim.AdamW(dis_params, lr=2e-4, betas=(0.5, 0.999), weight_decay=5e-4)

gen_params = (list(shared_latent.parameters()) + list(encoder_1.parameters()) +
              list(encoder_2.parameters()) + list(decoder_1.parameters()) +
              list(decoder_2.parameters()))
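# The snippet is truncated here; the matching generator optimizer presumably
# mirrors optimizerD above (hyperparameters assumed, not from the original):
optimizerG = optim.AdamW(gen_params, lr=2e-4, betas=(0.5, 0.999), weight_decay=5e-4)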
Example #4
# val_dataset = CustomTensorDataset(torch.from_numpy(x_val).float(), torch.from_numpy(y_val).long())
# test_dataset = TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(y_test).long())

train_loader = DataLoader(train_dataset,
                          batch_size=args.batch_size,
                          shuffle=True,
                          num_workers=args.num_workers)
# val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
# test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

num_batch = len(train_loader)
num_epoch = args.iters // num_batch + 1
args.__dict__.update({'epochs': num_epoch})

netG = Decoder().to(device)
netD = Discriminator().to(device)

criterion = nn.BCELoss()

fixed_noise = torch.randn(args.batch_size, 512, 1, 1, device=device)
real_label = 1 - 0.1  # one-sided label smoothing: real targets of 0.9 stabilise D
fake_label = 0

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.999))

for epoch in range(1, args.epochs + 1):
    for i, data in enumerate(train_loader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
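        ############################
        # The snippet is truncated here; below is a minimal sketch of the
        # standard DCGAN discriminator update this comment describes, reusing
        # the names defined above. The (images, labels) batch layout and a
        # sigmoid-terminated netD are assumptions, not the original code.
        netD.zero_grad()
        real = data[0].to(device)
        b_size = real.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        errD_real = criterion(netD(real).view(-1), label)
        errD_real.backward()

        noise = torch.randn(b_size, 512, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        errD_fake = criterion(netD(fake.detach()).view(-1), label)
        errD_fake.backward()
        optimizerD.step()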
Example #5
mnist_loader = DataLoader(mnist_train_dataset,
                          batch_size=args.batch_size,
                          shuffle=False,
                          num_workers=args.num_workers)
kannada_loader = DataLoader(kannada_train_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.num_workers)

# Modules
dim_z = 512
shared_latent = GaussianVAE2D(dim_z, dim_z, kernel_size=1, stride=1).to(device)
encoder_1 = Encoder().to(device)
encoder_2 = Encoder().to(device)
decoder_1 = Decoder().to(device)
decoder_2 = Decoder().to(device)
discriminator_1 = Discriminator().to(device)
discriminator_2 = Discriminator().to(device)

if os.path.isfile(args.model):
    print("===> Loading Checkpoint to Evaluate '{}'".format(args.model))
    checkpoint = torch.load(args.model)
    shared_latent.load_state_dict(checkpoint['shared_latent'])
    encoder_1.load_state_dict(checkpoint['encoder_1'])
    encoder_2.load_state_dict(checkpoint['encoder_2'])
    decoder_1.load_state_dict(checkpoint['decoder_1'])
    decoder_2.load_state_dict(checkpoint['decoder_2'])
    discriminator_1.load_state_dict(checkpoint['discriminator_1'])
    discriminator_2.load_state_dict(checkpoint['discriminator_2'])
    print("\t===> Loaded Checkpoint '{}' (epoch {})".format(
        args.model, checkpoint['epoch']))
else:
    print("===> No Checkpoint Found at '{}'".format(args.model))
Example #6
print("CUDA =", torch.cuda.is_available())

# Named 'transform' to avoid shadowing the torchvision.transforms module used below
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, ), (0.5, ))
])

train_data = get_dataset(config, transforms=transform)

gen = Generator(config['z_dim'],
                config['channels'],
                img_dim=config['img_dim'],
                embed_size=config['embed_size'],
                num_classes=config['num_classes']).to(config['device'])
disc = Discriminator(config['channels'],
                     num_classes=config['num_classes'],
                     img_size=config['img_dim']).to(config['device'])

loss_function = torch.nn.BCELoss()

optimiser_G = torch.optim.Adam(gen.parameters(),
                               lr=config['lr'],
                               betas=(0.0, 0.9))
optimiser_D = torch.optim.Adam(disc.parameters(),
                               lr=config['lr'],
                               betas=(0.0, 0.9))

step = 0
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
print("Training using", config['device'])

transformations = transforms.Compose([
    transforms.RandomCrop(256),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train = get_image_data_loader(os.path.join("datasets", "img2ai", "trainA"),
                              os.path.join("datasets", "img2ai", "trainB"),
                              transformations, transformations, batch=config['batch'])

G = Generator(config['channels']).to(config['device'])
F = Generator(config['channels']).to(config['device'])
Dx = Discriminator(config['channels']).to(config['device'])
Dy = Discriminator(config['channels']).to(config['device'])

try:
    G.load_state_dict(torch.load(os.path.join(config['checkpoint_path'], f'G_{config["load_state"]}.pth')))
    F.load_state_dict(torch.load(os.path.join(config['checkpoint_path'], f'F_{config["load_state"]}.pth')))
    Dx.load_state_dict(torch.load(os.path.join(config['checkpoint_path'], f'Dx_{config["load_state"]}.pth')))
    Dy.load_state_dict(torch.load(os.path.join(config['checkpoint_path'], f'Dy_{config["load_state"]}.pth')))
    print("Loaded weights from checkpoints.")
except FileNotFoundError:  # no checkpoint on disk yet
    print("No checkpoints found, initialised new weights.")

identity_loss = torch.nn.MSELoss()

step = 0
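
The snippet stops before the training loop; below is a minimal sketch of the CycleGAN generator update these four modules imply. The optimiser, the adversarial criterion, the batch unpacking, and the loss weights are all assumptions, not the original code:

opt_gen = torch.optim.Adam(list(G.parameters()) + list(F.parameters()),
                           lr=config['lr'], betas=(0.5, 0.999))
adv_loss = torch.nn.MSELoss()   # LSGAN-style criterion (assumed)
cycle_loss = torch.nn.L1Loss()  # CycleGAN conventionally uses L1 here (assumed)

for real_x, real_y in train:
    real_x = real_x.to(config['device'])
    real_y = real_y.to(config['device'])
    fake_y, fake_x = G(real_x), F(real_y)

    # Adversarial terms: fool Dy on G(x) and Dx on F(y)
    pred_y, pred_x = Dy(fake_y), Dx(fake_x)
    loss_adv = adv_loss(pred_y, torch.ones_like(pred_y)) + \
               adv_loss(pred_x, torch.ones_like(pred_x))

    # Cycle consistency: F(G(x)) ~ x and G(F(y)) ~ y
    loss_cyc = cycle_loss(F(fake_y), real_x) + cycle_loss(G(fake_x), real_y)

    # Identity terms, using the identity_loss defined above
    loss_idt = identity_loss(G(real_y), real_y) + identity_loss(F(real_x), real_x)

    loss_g = loss_adv + 10 * loss_cyc + 5 * loss_idt  # weights assumed
    opt_gen.zero_grad()
    loss_g.backward()
    opt_gen.step()
    step += 1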
class GAN(LightningModule):
    def __init__(self, config: GANConfig, num_workers: int, improved=False):
        super().__init__()
        self.save_hyperparameters()
        self.config = config
        self.num_workers = num_workers
        self.improved = improved

        generator_config = self.config.generator_config
        discriminator_config = self.config.discriminator_config
        if self.improved:
            self.generator = SN_Generator(**asdict(generator_config))
            self.discriminator = SN_Discriminator(**asdict(discriminator_config))
        else:
            self.generator = Generator(**asdict(generator_config))
            self.discriminator = Discriminator(**asdict(discriminator_config))

    def forward(self, z) -> torch.Tensor:
        return self.generator(z)

    def configure_optimizers(self) -> Tuple[List[Optimizer], List[_LRScheduler]]:
        g_optimizer = Adam(self.generator.parameters(), lr=self.config.g_lr, betas=self.config.g_betas)
        d_optimizer = Adam(self.discriminator.parameters(), lr=self.config.d_lr, betas=self.config.d_betas)
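        # Listing d_optimizer n_cr times makes Lightning invoke the discriminator
        # step n_cr times per generator step (the WGAN critic-iteration schedule).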
        return [g_optimizer] + self.config.n_cr * [d_optimizer], []

    def generator_loss(self, real_data: torch.Tensor) -> torch.Tensor:
        """
        Wasserstein loss for the generator.
        """
        g_input = torch.randn(real_data.shape[0], self.config.generator_config.input_size).to(self.device)
        fake_data = self.generator(g_input)
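        # Wasserstein generator loss: maximise the critic score on fakes, i.e.
        # minimise its negation; no sigmoid/log as in the standard GAN loss.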
        return -self.discriminator(fake_data).mean()

    def discriminator_loss(self, real_data: torch.Tensor) -> torch.Tensor:
        """
        Wasserstein loss for the discriminator (critic).
        """
        real_data = real_data.to(self.device)
        input_noise = torch.randn(real_data.shape[0], self.config.generator_config.input_size, device=self.device)
        fake_data = self.generator(input_noise).detach()
        gradient_penalty = self._gradient_penalty(real_data, fake_data) if self.improved else 0
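        # Critic objective: E[D(fake)] - E[D(real)], plus the gradient penalty
        # term when the improved (WGAN-GP) variant is enabled.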
        return self.discriminator(fake_data).mean() - self.discriminator(real_data).mean() + gradient_penalty

    def _gradient_penalty(self, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
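        # WGAN-GP penalty: gp_weight * E[(||grad_xhat D(xhat)||_2 - 1)^2],
        # evaluated on random interpolates xhat between real and fake samples.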
        batch_size = real.shape[0]

        # Calculate interpolation
        alpha = torch.rand(batch_size, 1, device=self.device).expand_as(real)

        interpolated = alpha * real.detach() + (1 - alpha) * fake.detach()
        interpolated.requires_grad_(True)

        # Calculate probability of interpolated examples
        prob_interpolated = self.discriminator(interpolated)

        # Calculate gradients of probabilities with respect to examples
        gradients = torch_grad(
            outputs=prob_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones(prob_interpolated.size()).to(self.device),
            create_graph=True,
            retain_graph=True,
        )[0]

        # Gradients have shape (batch_size, num_channels, img_width, img_height),
        # so flatten to easily take norm per example in batch
        gradients = gradients.view(batch_size, -1)

        # Derivatives of the gradient close to 0 can cause problems because of
        # the square root, so manually calculate norm and add epsilon
        gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

        # Return gradient penalty
        return self.config.gp_weight * ((gradients_norm - 1) ** 2).mean()

    def _construct_dataset(self, dataset_size: int):
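        # Build a 1-D mixture of Gaussians: one equally sized mode per (loc, scale) pair.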
        n_modes = len(self.config.dataset_params)
        gaussians = [
            numpy.random.normal(loc=loc, scale=scale, size=(dataset_size // n_modes,))
            for loc, scale in self.config.dataset_params
        ]
        data = numpy.concatenate(gaussians).reshape(-1, 1)
        return data.astype(numpy.float32)

    # ===== TRAIN BLOCK =====

    def train_dataloader(self) -> DataLoader:
        dataset = self._construct_dataset(self.config.train_dataset_size)
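        # A float32 numpy array works as a map-style dataset here; the default
        # collate function converts each sampled slice into a torch.Tensor.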
        train_loader = DataLoader(
            dataset, batch_size=self.config.batch_size, shuffle=True, num_workers=self.num_workers, pin_memory=True
        )
        return train_loader

    def training_step(self, real_data: torch.Tensor, batch_idx: int, optimizer_idx: int) -> Dict:
        if optimizer_idx == 0:
            g_loss = self.generator_loss(real_data)
            progress_bar = {"g_loss": g_loss}
            output = {"loss": g_loss, "progress_bar": progress_bar, "log": {"g_loss": g_loss}}
        else:
            d_loss = self.discriminator_loss(real_data)
            progress_bar = {"d_loss": d_loss}
            output = {"loss": d_loss, "progress_bar": progress_bar, "log": {"d_loss": d_loss}}
        output["log"].update(self.compute_metrcis(real_data))
        return output

    def training_epoch_end(self, outputs: List[Dict]) -> Dict:
        logs = {
            "d_loss": torch.Tensor([out["log"]["d_loss"] for out in outputs if "d_loss" in out["log"]]).mean(),
            "g_loss": torch.Tensor([out["log"]["g_loss"] for out in outputs if "g_loss" in out["log"]]).mean(),
            "accuracy": torch.Tensor([out["log"]["acc"] for out in outputs if "acc" in out["log"]]).mean(),
            "pvalue": torch.Tensor([out["log"]["pvalue"] for out in outputs if "pvalue" in out["log"]]).mean(),
        }
        progress_bar = {k: v for k, v in logs.items() if k in ["d_loss", "g_loss", "accuracy", "pvalue"]}
        return {"d_loss": logs["d_loss"], "g_loss": logs["g_loss"], "log": logs, "progress_bar": progress_bar}

    def compute_metrics(self, real: torch.Tensor) -> Dict:
        with torch.no_grad():
            real = real.to(self.device)
            input_noise = torch.randn(real.shape[0], self.config.generator_config.input_size, device=self.device)
            fake = self.generator(input_noise)
            preds_for_fake = self.discriminator(fake).cpu().numpy()
            preds_for_real = self.discriminator(real).cpu().numpy()
            correct = numpy.sum(preds_for_fake <= 0.5) + numpy.sum(preds_for_real > 0.5)
            accuracy = correct / (real.shape[0] + fake.shape[0])
            pvalue = stat.ks_2samp(real.cpu().numpy().T[0], fake.cpu().numpy().T[0])[1]
            return {"pvalue": pvalue, "acc": accuracy}