Exemplo n.º 1
0
def getDiscriminatorModel(netDPath='', ngpu=1):
    """Build a Discriminator, initialise its weights, and optionally load a checkpoint.

    Args:
        netDPath: path to a saved state dict; '' means start from random weights.
        ngpu: number of GPUs for the Discriminator to use.

    Returns:
        The Discriminator moved to the module-level `device`.
    """
    # NOTE(review): the original mixed tabs and spaces for indentation;
    # normalized to 4-space indents. Logic is unchanged.
    netD = Discriminator(ngpu).to(device)
    netD.apply(weights_init)
    if netDPath != '':
        netD.load_state_dict(torch.load(netDPath))
    print(netD)
    return netD
Exemplo n.º 2
0
def create_discriminator():
    """Instantiate the Discriminator, wrap it for multi-GPU, and init its weights.

    Relies on module-level ``ngpu``, ``nc``, ``ndf``, ``device`` and
    ``weights_init``.
    """
    model = Discriminator(ngpu, nc, ndf).to(device)

    # Spread the model across GPUs when several are available.
    if (device.type == 'cuda') and (ngpu > 1):
        model = nn.DataParallel(model, list(range(ngpu)))

    model.apply(weights_init)
    return model
Exemplo n.º 3
0
    def __init__(self, device, num_steps, image_size):
        """
        Arguments:
            device: an instance of 'torch.device'.
            num_steps: an integer, total number of iterations.
            image_size: a tuple of integers (width, height).
        """
        generator = Generator(depth=128, num_blocks=16)

        width, height = image_size
        # VGG features have stride 16, so their spatial size is 16x smaller.
        features_size = (width // 16, height // 16)

        # Pixel-space discriminator (RGB input).
        pixel_disc = Discriminator(3, image_size, depth=64)

        # Feature-space discriminator (VGG feature input).
        feature_disc = Discriminator(512, features_size, depth=64)

        def initialize(m):
            # Normal init for conv/linear, constant init for batch norm.
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.normal_(m.weight, std=0.1)
                if m.bias is not None:
                    init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                init.ones_(m.weight)
                init.zeros_(m.bias)

        self.G = generator.apply(initialize).to(device).train()
        self.D1 = pixel_disc.apply(initialize).to(device).train()
        self.D2 = feature_disc.apply(initialize).to(device).train()

        def make_adam(net):
            # All three networks share the same optimizer settings.
            return optim.Adam(net.parameters(), lr=1e-4, betas=(0.5, 0.999))

        self.optimizer = {
            'G': make_adam(self.G),
            'D1': make_adam(self.D1),
            'D2': make_adam(self.D2),
        }

        def lr_multiplier(i):
            # Constant LR for the first third, then linear decay to a 1e-3 floor.
            decay = num_steps // 3
            m = 1.0 if i < decay else 1.0 - (i - decay) / (num_steps - decay)
            return max(m, 1e-3)

        self.schedulers = [
            LambdaLR(o, lr_lambda=lr_multiplier)
            for o in self.optimizer.values()
        ]

        self.gan_loss = GAN()
        self.vgg = Extractor().to(device)
        self.mse_loss = nn.MSELoss()
Exemplo n.º 4
0
def build_network(image_size, z_size, d_conv_dim, d_conv_depth, g_conv_dim, g_conv_depth):
    """
    Creates the discriminator and the generator.
    :param image_size: The size of input and target images.
    :param z_size: The length of the input latent vector, z.
    :param d_conv_dim: The depth of the first convolutional layer of the discriminator.
    :param d_conv_depth: The number of convolutional layers of the discriminator.
    :param g_conv_dim: The depth of the inputs to the *last* transpose convolutional layer of the generator.
    :param g_conv_depth: The number of convolutional layers of the generator.
    :return: A tuple of discriminator and generator instances.
    """
    discriminator = Discriminator(image_size=image_size,
                                  in_channels=3,
                                  conv_dim=d_conv_dim,
                                  depth=d_conv_depth)
    generator = Generator(target_size=image_size,
                          out_channels=3,
                          z_size=z_size,
                          conv_dim=g_conv_dim,
                          depth=g_conv_depth)

    # Give both networks the same normal weight initialisation.
    for network in (discriminator, generator):
        network.apply(weights_init_normal)

    return discriminator, generator
Exemplo n.º 5
0
def train_gan(args):
    """Train a DCGAN-style generator/discriminator pair.

    Expects `args` to provide: ngpu, netG/netD (optional checkpoint paths),
    lr, beta1, nz, num_epochs, save_every, outputG, outputD, plus whatever
    create_data_loader needs.

    Returns:
        (img_list, G_losses, D_losses): fixed-noise sample grids and the
        per-iteration generator/discriminator loss histories for analysis.
    """

    # prepare dataloader
    dataloader = create_data_loader(args)

    # set up device
    device = torch.device('cuda:0' if (
        torch.cuda.is_available() and args.ngpu > 0) else 'cpu')

    # Create & setup generator
    netG = Generator(args).to(device)

    # handle multiple gpus
    if (device.type == 'cuda' and args.ngpu > 1):
        netG = nn.DataParallel(netG, list(range(args.ngpu)))

    # load from checkpoint if available
    if args.netG:
        netG.load_state_dict(torch.load(args.netG))

    # initialize network with random weights
    else:
        netG.apply(weights_init)

    # Create & setup discriminator
    netD = Discriminator(args).to(device)

    # handle multiple gpus
    if (device.type == 'cuda' and args.ngpu > 1):
        netD = nn.DataParallel(netD, list(range(args.ngpu)))

    # load from checkpoint if available
    if args.netD:
        netD.load_state_dict(torch.load(args.netD))

    # initialize network with random weights
    else:
        netD.apply(weights_init)

    # setup up loss & optimizers
    criterion = nn.BCELoss()
    optimizerG = optim.Adam(netG.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, 0.999))
    optimizerD = optim.Adam(netD.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, 0.999))

    # For input of generator in testing
    fixed_noise = torch.randn(64, args.nz, 1, 1, device=device)

    # Label convention. Floats: BCELoss requires float targets, and
    # torch.full with an int fill value would produce a long tensor.
    real_label = 1.
    fake_label = 0.

    # training data for later analysis
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0

    print('Starting Training Loop....')
    # For each epoch
    for e in range(args.num_epochs):
        # for each batch in the dataloader
        for i, data in enumerate(dataloader, 0):
            ########## Training Discriminator ##########
            netD.zero_grad()

            # train with real data
            real_data = data[0].to(device)

            # make labels (explicit float dtype for BCELoss)
            batch_size = real_data.size(0)
            labels = torch.full((batch_size, ),
                                real_label,
                                dtype=torch.float,
                                device=device)

            # forward pass real data through D
            real_outputD = netD(real_data).view(-1)

            # calc error on real data
            errD_real = criterion(real_outputD, labels)

            # calc grad
            errD_real.backward()
            D_x = real_outputD.mean().item()

            # train with fake data
            noise = torch.randn(batch_size, args.nz, 1, 1, device=device)
            fake_data = netG(noise)
            labels.fill_(fake_label)

            # classify fake; detach so no gradients flow into the generator
            fake_outputD = netD(fake_data.detach()).view(-1)

            # calc error on fake data
            errD_fake = criterion(fake_outputD, labels)

            # calc grad
            errD_fake.backward()
            D_G_z1 = fake_outputD.mean().item()

            # add all grad and update D
            errD = errD_real + errD_fake
            optimizerD.step()

            ########################################
            ########## Training Generator ##########
            netG.zero_grad()

            # since aim is fooling the netD, labels should be flipped
            labels.fill_(real_label)

            # forward pass with updated netD
            fake_outputD = netD(fake_data).view(-1)

            # calc error
            errG = criterion(fake_outputD, labels)

            # calc grad
            errG.backward()

            D_G_z2 = fake_outputD.mean().item()

            # update G
            optimizerG.step()

            ########################################

            # output training stats
            if i % 500 == 0:
                # NOTE(review): the original f-string used backslash
                # continuations inside the literal, embedding source
                # indentation into the output; rewritten with adjacent
                # f-strings so only the intended tabs are printed.
                print(f'[{e+1}/{args.num_epochs}][{i+1}/{len(dataloader)}]'
                      f'\tLoss_D:{errD.item():.4f}'
                      f'\tLoss_G:{errG.item():.4f}'
                      f'\tD(x):{D_x:.4f}'
                      f'\tD(G(z)):{D_G_z1:.4f}/{D_G_z2:.4f}')

            # for later plot
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # generate fake image on fixed noise for comparison
            if ((iters % 500 == 0) or ((e == args.num_epochs - 1) and
                                       (i == len(dataloader) - 1))):
                with torch.no_grad():
                    fake = netG(fixed_noise).detach().cpu()
                    img_list.append(
                        vutils.make_grid(fake, padding=2, normalize=True))
            iters += 1

        if e % args.save_every == 0:
            # save at args.save_every epoch
            torch.save(netG.state_dict(), args.outputG)
            torch.save(netD.state_dict(), args.outputD)
            print(f'Made a New Checkpoint for {e+1}')
    # return training data for analysis
    return img_list, G_losses, D_losses
Exemplo n.º 6
0
class AttnGAN:
    def __init__(self, damsm, device=DEVICE):
        """Build generator/discriminator and their optimizers around a frozen DAMSM."""
        self.gen = Generator(device)
        self.disc = Discriminator(device)
        self.damsm = damsm.to(device)
        # The DAMSM text/image encoders stay fixed during GAN training.
        self.damsm.txt_enc.eval()
        self.damsm.img_enc.eval()
        freeze_params_(self.damsm.txt_enc)
        freeze_params_(self.damsm.img_enc)

        self.device = device
        self.gen.apply(init_weights)
        self.disc.apply(init_weights)

        self.gen_optimizer = torch.optim.Adam(self.gen.parameters(),
                                              lr=GENERATOR_LR,
                                              betas=(0.5, 0.999))

        # One optimizer per resolution-specific discriminator head.
        self.discriminators = [self.disc.d64, self.disc.d128, self.disc.d256]
        self.disc_optimizers = []
        for head in self.discriminators:
            self.disc_optimizers.append(
                torch.optim.Adam(head.parameters(),
                                 lr=DISCRIMINATOR_LR,
                                 betas=(0.5, 0.999)))

    #@torch.no_grad()
    def train(self,
              dataset,
              epoch,
              batch_size=GAN_BATCH,
              test_sample_every=5,
              hist_avg=False,
              evaluator=None):
        """Run the full GAN training loop for `epoch` epochs.

        Args:
            dataset: project dataset exposing .train, .vocab and .collate_fn.
            epoch: number of epochs to train for.
            batch_size: training batch size (drop_last is used, so every
                batch has exactly this size).
            test_sample_every: epochs between test-set sample dumps.
            hist_avg: if True, keep an exponential moving average of the
                generator's weights and copy it back every 1000 updates.
            evaluator: optional factory; called with (dataset, inception
                model, batch_size, device) to build an IS/FID evaluator.

        Returns:
            dict of per-epoch losses, accuracies and (if an evaluator was
            given) IS/FID scores.
        """

        start_time = time.strftime("%Y-%m-%d-%H-%M", time.gmtime())
        os.makedirs(f'{OUT_DIR}/{start_time}')

        # print('cun')
        # for e in tqdm(range(epoch), desc='Epochs', dynamic_ncols=True):
        #     self.gen.eval()
        #     generated_samples = [resolution.unsqueeze(0) for resolution in self.sample_test_set(dataset)]
        #     self._save_generated(generated_samples, e, f'{OUT_DIR}/{start_time}')
        #
        #     return

        if hist_avg:
            # Snapshot of generator weights for the moving average.
            avg_g_params = deepcopy(list(p.data
                                         for p in self.gen.parameters()))

        loader_config = {
            'batch_size': batch_size,
            'shuffle': True,
            'drop_last': True,
            'collate_fn': dataset.collate_fn
        }

        train_loader = DataLoader(dataset.train, **loader_config)

        metrics = {
            'IS': [],
            'FID': [],
            'loss': {
                'g': [],
                'd': []
            },
            'accuracy': {
                'real': [],
                'fake': [],
                'mismatched': [],
                'unconditional_real': [],
                'unconditional_fake': []
            }
        }

        if evaluator is not None:
            evaluator = evaluator(dataset, self.damsm.img_enc.inception_model,
                                  batch_size, self.device)

        # One noise buffer, refilled in place each iteration.
        noise = torch.FloatTensor(batch_size, D_Z).to(self.device)
        gen_updates = 0

        self.disc.train()

        for e in tqdm(range(epoch), desc='Epochs', dynamic_ncols=True):
            self.gen.train(), self.disc.train()
            g_loss = 0
            w_loss = 0
            s_loss = 0
            kl_loss = 0
            # Per-resolution accumulators (64/128/256px heads).
            g_stage_loss = np.zeros(3, dtype=float)
            d_loss = np.zeros(3, dtype=float)
            real_acc = np.zeros(3, dtype=float)
            fake_acc = np.zeros(3, dtype=float)
            mismatched_acc = np.zeros(3, dtype=float)
            uncond_real_acc = np.zeros(3, dtype=float)
            uncond_fake_acc = np.zeros(3, dtype=float)
            disc_skips = np.zeros(3, dtype=int)

            train_pbar = tqdm(train_loader,
                              desc='Training',
                              leave=False,
                              dynamic_ncols=True)
            for batch in train_pbar:
                real_imgs = [batch['img64'], batch['img128'], batch['img256']]

                # Caption encoding is frozen; no gradients needed.
                with torch.no_grad():
                    word_embs, sent_embs = self.damsm.txt_enc(batch['caption'])
                attn_mask = torch.tensor(batch['caption']).to(
                    self.device) == dataset.vocab[END_TOKEN]

                # Generate images
                noise.data.normal_(0, 1)
                generated, att, mu, logvar = self.gen(noise, sent_embs,
                                                      word_embs, attn_mask)

                # Discriminator loss (with label smoothing)
                batch_d_loss, batch_real_acc, batch_fake_acc, batch_mismatched_acc, batch_uncond_real_acc, batch_uncond_fake_acc, batch_disc_skips = self.discriminator_step(
                    real_imgs, generated, sent_embs, 0.1)

                d_grad_norm = [grad_norm(d) for d in self.discriminators]

                d_loss += batch_d_loss
                real_acc += batch_real_acc
                fake_acc += batch_fake_acc
                mismatched_acc += batch_mismatched_acc
                uncond_real_acc += batch_uncond_real_acc
                uncond_fake_acc += batch_uncond_fake_acc
                disc_skips += batch_disc_skips

                # Generator loss
                batch_g_losses = self.generator_step(generated, word_embs,
                                                     sent_embs, mu, logvar,
                                                     batch['label'])
                g_total, batch_g_stage_loss, batch_w_loss, batch_s_loss, batch_kl_loss = batch_g_losses
                g_stage_loss += batch_g_stage_loss
                w_loss += batch_w_loss
                s_loss += (batch_s_loss)
                kl_loss += (batch_kl_loss)
                gen_updates += 1

                avg_g_loss = g_total.item() / batch_size
                g_loss += float(avg_g_loss)

                if hist_avg:
                    # EMA update of the generator weights.
                    # NOTE(review): add_(0.001, p.data) is the deprecated
                    # signature; newer torch expects add_(p.data, alpha=0.001).
                    for p, avg_p in zip(self.gen.parameters(), avg_g_params):
                        avg_p.mul_(0.999).add_(0.001, p.data)

                    if gen_updates % 1000 == 0:
                        tqdm.write(
                            'Replacing generator weights with their moving average'
                        )
                        for p, avg_p in zip(self.gen.parameters(),
                                            avg_g_params):
                            p.data.copy_(avg_p)

                train_pbar.set_description(
                    f'Training (G: {grad_norm(self.gen):.2f}  '
                    f'D64: {d_grad_norm[0]:.2f}  '
                    f'D128: {d_grad_norm[1]:.2f}  '
                    f'D256: {d_grad_norm[2]:.2f})')

            batches = len(train_loader)

            # Convert accumulated sums into per-batch averages.
            g_loss /= batches
            g_stage_loss /= batches
            w_loss /= batches
            s_loss /= batches
            kl_loss /= batches
            d_loss /= batches
            real_acc /= batches
            fake_acc /= batches
            mismatched_acc /= batches
            uncond_real_acc /= batches
            uncond_fake_acc /= batches

            metrics['loss']['g'].append(g_loss)
            metrics['loss']['d'].append(d_loss)
            metrics['accuracy']['real'].append(real_acc)
            metrics['accuracy']['fake'].append(fake_acc)
            metrics['accuracy']['mismatched'].append(mismatched_acc)
            metrics['accuracy']['unconditional_real'].append(uncond_real_acc)
            metrics['accuracy']['unconditional_fake'].append(uncond_fake_acc)

            sep = '_' * 10
            tqdm.write(f'{sep}Epoch {e}{sep}')

            if e % test_sample_every == 0:
                # Periodically dump sample grids (and evaluate, if requested).
                self.gen.eval()
                generated_samples = [
                    resolution.unsqueeze(0)
                    for resolution in self.sample_test_set(dataset)
                ]
                self._save_generated(generated_samples, e,
                                     f'{OUT_DIR}/{start_time}')

                if evaluator is not None:
                    scores = evaluator.evaluate(self)
                    for k, v in scores.items():
                        metrics[k].append(v)
                        tqdm.write(f'{k}: {v:.2f}')

            tqdm.write(
                f'Generator avg loss: total({g_loss:.3f})  '
                f'stage0({g_stage_loss[0]:.3f})  stage1({g_stage_loss[1]:.3f})  stage2({g_stage_loss[2]:.3f})  '
                f'w({w_loss:.3f})  s({s_loss:.3f})  kl({kl_loss:.3f})')

            for i, _ in enumerate(self.discriminators):
                tqdm.write(f'Discriminator{i} avg: '
                           f'loss({d_loss[i]:.3f})  '
                           f'r-acc({real_acc[i]:.3f})  '
                           f'f-acc({fake_acc[i]:.3f})  '
                           f'm-acc({mismatched_acc[i]:.3f})  '
                           f'ur-acc({uncond_real_acc[i]:.3f})  '
                           f'uf-acc({uncond_fake_acc[i]:.3f})  '
                           f'skips({disc_skips[i]})')

        return metrics

    def sample_test_set(self,
                        dataset,
                        nb_samples=8,
                        nb_captions=2,
                        noise_variations=2):
        """Build comparison image grids from random test captions.

        Picks `nb_samples` random test entries and `nb_captions` random
        captions for each, then generates `noise_variations` images per
        caption with fresh noise.

        Returns:
            (img64, img128, img256) tensors; rows are samples, columns are
            caption x noise-variant combinations.
        """
        subset = dataset.test
        sample_indices = np.random.choice(len(subset),
                                          nb_samples,
                                          replace=False)
        # NOTE(review): assumes each entry has 10 captions (caption_0..9) —
        # confirm against the dataset schema.
        cap_indices = np.random.choice(10, nb_captions, replace=False)
        texts = [
            subset.data[f'caption_{cap_idx}'].iloc[sample_idx]
            for sample_idx in sample_indices for cap_idx in cap_indices
        ]

        # One generation run (all captions at once) per noise variant.
        generated_samples = [
            self.generate_from_text(texts, dataset)
            for _ in range(noise_variations)
        ]
        combined_img64 = torch.FloatTensor()
        combined_img128 = torch.FloatTensor()
        combined_img256 = torch.FloatTensor()

        for noise_variant in generated_samples:
            noise_var_img64 = torch.FloatTensor()
            noise_var_img128 = torch.FloatTensor()
            noise_var_img256 = torch.FloatTensor()
            for i in range(nb_samples):
                # rows: samples, columns: captions * noise variants
                row64 = torch.cat([
                    noise_variant[0][i * nb_captions + j]
                    for j in range(nb_captions)
                ],
                                  dim=-1).cpu()
                row128 = torch.cat([
                    noise_variant[1][i * nb_captions + j]
                    for j in range(nb_captions)
                ],
                                   dim=-1).cpu()
                row256 = torch.cat([
                    noise_variant[2][i * nb_captions + j]
                    for j in range(nb_captions)
                ],
                                   dim=-1).cpu()
                noise_var_img64 = torch.cat([noise_var_img64, row64], dim=-2)
                noise_var_img128 = torch.cat([noise_var_img128, row128],
                                             dim=-2)
                noise_var_img256 = torch.cat([noise_var_img256, row256],
                                             dim=-2)
            combined_img64 = torch.cat([combined_img64, noise_var_img64],
                                       dim=-1)
            combined_img128 = torch.cat([combined_img128, noise_var_img128],
                                        dim=-1)
            combined_img256 = torch.cat([combined_img256, noise_var_img256],
                                        dim=-1)

        return combined_img64, combined_img128, combined_img256

    @staticmethod
    def KL_loss(mu, logvar):
        """KL divergence of N(mu, exp(logvar)) from N(0, 1), averaged over elements.

        Equals -0.5 * mean(1 + logvar - mu^2 - exp(logvar)).
        """
        # Same operation order as the original (out-of-place variants),
        # so the floating-point result is identical.
        term = mu.pow(2).add(logvar.exp()).mul(-1).add(1).add(logvar)
        return term.mean().mul(-0.5)

    def generator_step(self, generated_imgs, word_embs, sent_embs, mu, logvar,
                       class_labels):
        """One generator update: DAMSM word/sentence losses + KL + adversarial terms.

        Args:
            generated_imgs: list of generated batches, one per resolution;
                the last (highest-resolution) one feeds the DAMSM encoder.
            word_embs, sent_embs: caption word/sentence embeddings.
            mu, logvar: conditioning-augmentation distribution parameters.
            class_labels: labels used by the DAMSM matching losses.

        Returns:
            (g_total, per-stage avg losses, w_loss/batch, s_loss/batch, kl_loss).
        """
        self.gen.zero_grad()
        avg_stage_g_loss = [0, 0, 0]

        local_features, global_features = self.damsm.img_enc(
            generated_imgs[-1])
        batch_size = sent_embs.size(0)
        # Diagonal matching: sample i should match caption i.
        match_labels = torch.LongTensor(range(batch_size)).to(self.device)

        w1_loss, w2_loss, _ = self.damsm.words_loss(local_features, word_embs,
                                                    class_labels, match_labels)
        w_loss = (w1_loss + w2_loss) * LAMBDA

        s1_loss, s2_loss = self.damsm.sentence_loss(global_features, sent_embs,
                                                    class_labels, match_labels)
        s_loss = (s1_loss + s2_loss) * LAMBDA

        kl_loss = self.KL_loss(mu, logvar)

        g_total = w_loss + s_loss + kl_loss

        # Adversarial term per resolution: the generator wants each head
        # (conditional and unconditional) to score its output as real.
        for i, d in enumerate(self.discriminators):
            features = d(generated_imgs[i])
            fake_logits = d.logit(features, sent_embs)

            real_labels = torch.ones_like(fake_logits).to(self.device)

            disc_error = F.binary_cross_entropy_with_logits(
                fake_logits, real_labels)

            uncond_fake_logits = d.logit(features)
            uncond_disc_error = F.binary_cross_entropy_with_logits(
                uncond_fake_logits, real_labels)

            stage_loss = disc_error + uncond_disc_error
            avg_stage_g_loss[i] = stage_loss.item() / batch_size
            g_total += stage_loss

        g_total.backward()
        self.gen_optimizer.step()

        return g_total, avg_stage_g_loss, w_loss.item(
        ) / batch_size, s_loss.item() / batch_size, kl_loss.item()

    def discriminator_step(self,
                           real_imgs,
                           generated_imgs,
                           sent_embs,
                           label_smoothing,
                           skip_acc_threshold=0.9,
                           p_flip=0.05,
                           halting=False):
        """One update of each per-resolution discriminator head.

        Each head scores real, generated (detached) and caption-mismatched
        images, both conditioned on sentence embeddings and unconditionally.

        Args:
            real_imgs: real image batches, one per resolution.
            generated_imgs: generated image batches (detached here so the
                generator receives no gradients).
            sent_embs: sentence embeddings for the matching captions.
            label_smoothing: real targets become 1 - label_smoothing.
            skip_acc_threshold: with halting=True, skip a head's update when
                its real+fake accuracy already exceeds 2x this value.
            p_flip: unused; kept for the commented-out label-flip experiment.
            halting: enable accuracy-based update skipping.

        Returns:
            Per-head lists: (avg loss, real/fake/mismatched accuracy,
            unconditional real/fake accuracy, skip flags).
        """
        self.disc.zero_grad()
        batch_size = sent_embs.size(0)

        avg_d_loss = [0, 0, 0]
        real_accuracy = [0, 0, 0]
        fake_accuracy = [0, 0, 0]
        mismatched_accuracy = [0, 0, 0]
        uncond_real_accuracy = [0, 0, 0]
        uncond_fake_accuracy = [0, 0, 0]
        skipped = [0, 0, 0]

        for i, d in enumerate(self.discriminators):
            real_features = d(real_imgs[i].to(self.device))
            fake_features = d(generated_imgs[i].detach())

            real_logits = d.logit(real_features, sent_embs)

            real_labels = torch.full_like(real_logits,
                                          1 - label_smoothing).to(self.device)
            fake_labels = torch.zeros_like(real_logits,
                                           dtype=torch.float).to(self.device)

            # flip_mask = torch.Tensor(real_labels.size()).bernoulli_(p_flip).type(torch.bool)
            # real_labels[flip_mask], fake_labels[flip_mask] = fake_labels[flip_mask], real_labels[flip_mask]

            real_error = F.binary_cross_entropy_with_logits(
                real_logits, real_labels)
            # Real images should be classified as real
            real_accuracy[i] = (real_logits >=
                                0).sum().item() / real_logits.numel()

            fake_logits = d.logit(fake_features, sent_embs)
            fake_error = F.binary_cross_entropy_with_logits(
                fake_logits, fake_labels)
            # Generated images should be classified as fake
            fake_accuracy[i] = (fake_logits <
                                0).sum().item() / fake_logits.numel()

            # rotate_tensor pairs each real image with a wrong caption.
            mismatched_logits = d.logit(real_features,
                                        rotate_tensor(sent_embs, 1))
            mismatched_error = F.binary_cross_entropy_with_logits(
                mismatched_logits, fake_labels)
            # Images with mismatched descriptions should be classified as fake
            mismatched_accuracy[i] = (mismatched_logits < 0).sum().item(
            ) / mismatched_logits.numel()

            uncond_real_logits = d.logit(real_features)
            uncond_real_error = F.binary_cross_entropy_with_logits(
                uncond_real_logits, real_labels)
            uncond_real_accuracy[i] = (uncond_real_logits >= 0).sum().item(
            ) / uncond_real_logits.numel()

            uncond_fake_logits = d.logit(fake_features)
            uncond_fake_error = F.binary_cross_entropy_with_logits(
                uncond_fake_logits, fake_labels)
            uncond_fake_accuracy[i] = (uncond_fake_logits < 0).sum().item(
            ) / uncond_fake_logits.numel()

            # Weighted combination: 2 real-ish terms, 3 fake-ish terms.
            error = (real_error + uncond_real_error) / 2 + (
                fake_error + uncond_fake_error + mismatched_error) / 3

            if not halting or fake_accuracy[i] + real_accuracy[
                    i] < skip_acc_threshold * 2:
                error.backward()
                self.disc_optimizers[i].step()
            else:
                skipped[i] = 1

            avg_d_loss[i] = error.item() / batch_size

        return avg_d_loss, real_accuracy, fake_accuracy, mismatched_accuracy, uncond_real_accuracy, uncond_fake_accuracy, skipped

    def generate_from_text(self, texts, dataset, noise=None):
        """Encode raw caption strings and generate images for them."""
        encoded = []
        for text in texts:
            encoded.append(dataset.train.encode_text(text))
        return self.generate_from_encoded_text(encoded, dataset, noise)

    def generate_from_encoded_text(self, encoded, dataset, noise=None):
        """Run the generator on already-encoded captions, without gradients."""
        with torch.no_grad():
            word_embs, sent_embs = self.damsm.txt_enc(encoded)
            # Mask positions holding the END token so attention ignores them.
            attn_mask = torch.tensor(encoded).to(
                self.device) == dataset.vocab[END_TOKEN]
            if noise is None:
                # Fresh standard-normal noise when none was supplied.
                noise = torch.FloatTensor(len(encoded), D_Z).to(self.device)
                noise.data.normal_(0, 1)
            images, att, mu, logvar = self.gen(noise, sent_embs, word_embs,
                                               attn_mask)
        return images

    def _save_generated(self, generated, epoch, out_dir=OUT_DIR):
        """Write each sample's 64/128/256px outputs as individual jpg files."""
        nb_samples = generated[0].size(0)
        save_dir = f'{out_dir}/epoch_{epoch:03}'
        os.makedirs(save_dir)

        resolutions = (64, 128, 256)
        for sample_idx in range(nb_samples):
            # Same file names and write order as saving the three
            # resolutions explicitly per sample.
            for res_idx, res in enumerate(resolutions):
                save_image(generated[res_idx][sample_idx],
                           f'{save_dir}/{sample_idx}_{res}.jpg',
                           normalize=True,
                           range=(-1, 1))

    def save(self, name, save_dir=GAN_MODEL_DIR, metrics=None):
        """Persist generator/discriminator weights and optional metrics JSON."""
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.gen.state_dict(), f'{save_dir}/{name}_generator.pt')
        torch.save(self.disc.state_dict(),
                   f'{save_dir}/{name}_discriminator.pt')
        if metrics is None:
            return
        with open(f'{save_dir}/{name}_metrics.json', 'w') as f:
            json.dump(pre_json_metrics(metrics), f)

    def load_(self, name, load_dir=GAN_MODEL_DIR):
        """Load saved weights into both networks in place; switch them to eval mode."""
        gen_path = f'{load_dir}/{name}_generator.pt'
        disc_path = f'{load_dir}/{name}_discriminator.pt'
        self.gen.load_state_dict(torch.load(gen_path))
        self.disc.load_state_dict(torch.load(disc_path))
        self.gen.eval()
        self.disc.eval()

    @staticmethod
    def load(name, damsm, load_dir=GAN_MODEL_DIR, device=DEVICE):
        """Construct an AttnGAN around `damsm` and load saved weights into it."""
        model = AttnGAN(damsm, device=device)
        model.load_(name, load_dir)
        return model

    def validate_test_set(self,
                          dataset,
                          batch_size=GAN_BATCH,
                          save_dir=f'{OUT_DIR}/test_samples'):
        """Generate one highest-resolution image per test caption and save to disk.

        Images are written sequentially as `{save_dir}/{i}.jpg`.
        """
        os.makedirs(save_dir, exist_ok=True)

        loader = DataLoader(dataset.test,
                            batch_size=batch_size,
                            shuffle=True,
                            drop_last=False,
                            collate_fn=dataset.collate_fn)
        loader = tqdm(loader,
                      dynamic_ncols=True,
                      leave=True,
                      desc='Generating samples for test set')

        self.gen.eval()
        with torch.no_grad():
            i = 0
            for batch in loader:
                word_embs, sent_embs = self.damsm.txt_enc(batch['caption'])
                # Mask END-token positions for the attention module.
                attn_mask = torch.tensor(batch['caption']).to(
                    self.device) == dataset.vocab[END_TOKEN]
                noise = torch.FloatTensor(len(batch['caption']),
                                          D_Z).to(self.device)
                noise.data.normal_(0, 1)
                generated, att, mu, logvar = self.gen(noise, sent_embs,
                                                      word_embs, attn_mask)

                # Only the last (256px) stage is saved.
                for img in generated[-1]:
                    save_image(img,
                               f'{save_dir}/{i}.jpg',
                               normalize=True,
                               range=(-1, 1))
                    i += 1

    def get_d_score(self, imgs, sent_embs):
        """Conditional logits of the 256px discriminator for `imgs` and captions."""
        disc256 = self.disc.d256
        features = disc256(imgs.to(self.device))
        return disc256.logit(features, sent_embs)

    def accept_prob(self, score1, score2):
        """Metropolis–Hastings acceptance probability from two discriminator scores."""
        ratio = (1 / score1 - 1) / (1 / score2 - 1)
        return min(1, ratio)

    def d_scores_test(self, dataset):
        """Unconditional sigmoid scores of the 256px head on every real test image."""
        with torch.no_grad():
            loader = DataLoader(dataset.test,
                                batch_size=20,
                                shuffle=False,
                                drop_last=False,
                                collate_fn=dataset.collate_fn)
            disc256 = self.disc.d256
            batched = []
            for batch in loader:
                images = batch['img256'].to(self.device)
                logits = disc256.logit(disc256(images))
                batched.append(torch.sigmoid(logits))
            # Flatten batch-wise score tensors into one Python list.
            flat = [
                p.item() for scores in batched for p in scores.reshape(-1)
            ]
        return flat

    def z_test(self, scores, labels):
        """Z statistic comparing binary labels against predicted probabilities."""
        labels = np.asarray(labels)
        scores = np.asarray(scores)
        # Sum of residuals over the pooled standard deviation of the scores.
        variance = scores * (1 - scores)
        return np.sum(labels - scores) / np.sqrt(np.sum(variance))

    def d_scores_gen(self, dataset):
        """Unconditional sigmoid scores of the 256px head on generated test images."""
        with torch.no_grad():
            loader = DataLoader(dataset.test,
                                batch_size=20,
                                shuffle=False,
                                drop_last=False,
                                collate_fn=dataset.collate_fn)
            disc256 = self.disc.d256
            batched = []
            for batch in loader:
                noise = torch.FloatTensor(len(batch['caption']),
                                          D_Z).to(self.device)
                noise.data.normal_(0, 1)
                word_embs, sent_embs = self.damsm.txt_enc(batch['caption'])
                attn_mask = torch.tensor(batch['caption']).to(
                    self.device) == dataset.vocab[END_TOKEN]
                generated, _, _, _ = self.gen(noise, sent_embs, word_embs,
                                              attn_mask)

                # Score only the highest-resolution stage.
                logits = disc256.logit(disc256(generated[-1]))
                batched.append(torch.sigmoid(logits))
            # Flatten batch-wise score tensors into one Python list.
            flat = [
                p.item() for scores in batched for p in scores.reshape(-1)
            ]
        return flat

    def mh_sample(self, dataset, k, save_dir='test_samples', batch=GAN_BATCH):
        """MH-GAN sampling: calibrate discriminator scores, then run a
        Metropolis-Hastings chain of length ``k`` per test caption.

        Fits a logistic-regression calibrator on a held-out slice of
        real/generated discriminator scores, walks each chain with the
        MH acceptance rule, collects the selected image per caption and
        prints diagnostics plus the FID of the selection.

        Returns the list of selected images as CPU tensors.

        Bug fixed: the accept/reject test was dedented out of the chain
        loop, so only the final proposal's acceptance probability was
        ever used (and ``alpha`` was unbound when k < 2).
        """
        evaluator = IS_FID_Evaluator(dataset,
                                     self.damsm.img_enc.inception_model, batch,
                                     self.device)
        # self.disc.d256.train()
        with torch.no_grad():
            l = len(dataset.test)
            score_real = self.d_scores_test(dataset)
            score_gen = self.d_scores_gen(dataset)
            print(np.mean(score_real))
            print(np.mean(score_gen))
            # hold out the last fifth of the scores for calibration
            portion = -l // 5
            score_test = score_real[:portion] + score_gen[:portion]
            label_test = [1] * (len(score_test) //
                                2) + [0] * (len(score_test) // 2)

            print('Z test before calibration: ',
                  self.z_test(torch.tensor(score_test), label_test))

            score_real_calib = score_real[portion:]
            score_gen_calib = score_gen[portion:]
            # score_calib = score_real_calib + score_gen_calib
            score_calib = score_gen_calib + score_real_calib
            label_calib = len(score_gen_calib) * [0] + len(
                score_real_calib) * [1]

            # Platt-style recalibration of the raw discriminator scores
            cal_clf = LogisticRegression()
            cal_clf.fit(np.array(score_calib).reshape(-1, 1), label_calib)

            score_pred = cal_clf.predict_proba(
                np.array(score_test).reshape(-1, 1))[:, 1]
            print('Score pred avg: ', np.mean(score_pred))
            test_pred = cal_clf.predict(np.array(score_test).reshape(-1, 1))

            print('Z test after calibration: ',
                  self.z_test(score_pred, label_test))
            print('Accuracy: ',
                  sum((test_pred == label_test)) / len(test_pred))

            os.makedirs(save_dir, exist_ok=True)
            loader = DataLoader(dataset.test,
                                batch_size=1,
                                shuffle=False,
                                drop_last=False,
                                collate_fn=dataset.collate_fn)
            loader = tqdm(loader,
                          dynamic_ncols=True,
                          leave=True,
                          desc='Generating samples for test set')

            imgs = []
            true_probs = 0
            noaccept = 0
            for i, sample in enumerate(loader):
                # skip the tail of the test set
                if i > l - (l // 10):
                    continue
                word_embs, sent_embs = self.damsm.txt_enc(sample['caption'])
                attn_mask = torch.tensor(sample['caption']).to(
                    self.device) == dataset.vocab[END_TOKEN]

                # generate at least k proposal images for this caption
                img_chain = []
                while len(img_chain) < k:
                    noise = torch.FloatTensor(batch, D_Z).to(self.device)
                    noise.data.normal_(0, 1)
                    generated, _, _, _ = self.gen(
                        noise, sent_embs.repeat(batch, 1),
                        word_embs.repeat(batch, 1, 1),
                        attn_mask.repeat(batch, 1))

                    for img in generated[-1]:
                        img_chain.append(img)

                img_chain = img_chain[:k]
                img_chain = torch.stack(img_chain).to(self.device)

                # calibrated probability for every proposal in the chain
                score_chain = []
                d_loader = DataLoader(img_chain,
                                      batch_size=batch,
                                      shuffle=False,
                                      drop_last=False)
                for d_batch in d_loader:
                    scores = self.get_d_score(d_batch,
                                              sent_embs.repeat(batch, 1))
                    scores = scores.reshape(-1, 1).cpu().numpy()
                    scores = cal_clf.predict_proba(scores)[:, 1]
                    for s in scores:
                        score_chain.append(s)

                # Metropolis-Hastings walk over the chain
                chosen = 0
                for j, s in enumerate(score_chain[1:], 1):
                    alpha = self.accept_prob(score_chain[chosen], s)
                    # FIX: the acceptance test must run once per proposal,
                    # inside the loop (it was previously dedented out).
                    if np.random.rand() < alpha:
                        chosen = j

                if chosen == 0:
                    # nothing accepted: fall back to the best-scoring
                    # proposal.
                    # NOTE(review): argmax indexes score_chain[1:], so
                    # img_chain[idx + 1] may be the intended image —
                    # confirm the offset.
                    imgs.append(img_chain[torch.tensor(
                        score_chain[1:]).argmax()].cpu())
                    noaccept += 1
                else:
                    imgs.append(img_chain[chosen].cpu())
                # accumulate the initial-state probability for reporting
                true_probs += score_chain[0]

            print(noaccept)
            # NOTE(review): some captions are skipped above, but the mean
            # still divides by the full test-set size — confirm intended.
            print(true_probs / len(dataset.test))
            mu_real, sig_real = evaluator.mu_real, evaluator.sig_real
            mu_fake, sig_fake = activation_statistics(
                self.damsm.img_enc.inception_model, imgs)
            print('FID: ', frechet_dist(mu_real, sig_real, mu_fake, sig_fake))
            return imgs
Exemplo n.º 7
0
# --- training configuration for the segmentation-to-image run ---
epochs = 200
save_imgs = 30
lambda_p = 50  # presumably the pixel-loss weight — confirm in the training loop
lambda_gp = 10  # presumably a gradient-penalty weight — confirm in the training loop
n_critic = 5  # presumably discriminator steps per generator step — confirm
dataset = "CMP_facade_DB_base"
out_dir = "results8-seg2img"
# NOTE(review): h, w, c, cuda, batch_size, Generator2, Discriminator,
# weight_init and ImageDataset must be defined earlier in this file.
generator = Generator2(n_filters=32, kernel_size=3, l=4)
discriminator = Discriminator(h, w, c)
pix_loss = nn.MSELoss()  # pixel-space reconstruction loss
if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    pix_loss = pix_loss.cuda()
# apply the custom weight initialisation to both networks
generator.apply(weight_init)
discriminator.apply(weight_init)

os.makedirs(out_dir, exist_ok=True)

# resize to (h, w), then normalise each channel to [-1, 1]
transforms_ = [
    transforms.Resize((h, w), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]

train_dataloader = torch.utils.data.DataLoader(
    ImageDataset("../data/{}".format(dataset), transforms_=transforms_),
    batch_size=batch_size,
    shuffle=True,
)
val_dataloader = torch.utils.data.DataLoader(
Exemplo n.º 8
0
class GAIL:
    """Deterministic GAIL: trains a deterministic actor to imitate
    expert trajectories by fooling a state-action discriminator."""

    def __init__(self,
                 exp_dir,
                 exp_thresh,
                 state_dim,
                 action_dim,
                 learn_rate,
                 betas,
                 _device,
                 _gamma,
                 load_weights=False):
        """
            exp_dir : directory containing the expert episodes
         exp_thresh : parameter to control number of episodes to load
                      as expert based on returns (lower means more episodes)
          state_dim : dimension of state
         action_dim : dimension of action
         learn_rate : learning rate for optimizer
              betas : Adam (beta1, beta2) coefficients
            _device : GPU or cpu
            _gamma  : discount factor
       load_weights : load weights from directory instead of re-initializing
        """

        # storing runtime device
        self.device = _device

        # discount factor
        self.gamma = _gamma

        # Expert trajectory buffer
        self.expert = ExpertTrajectories(exp_dir, exp_thresh, gamma=self.gamma)

        # Defining the actor and its optimizer
        self.actor = ActorNetwork(state_dim).to(self.device)
        self.optim_actor = torch.optim.Adam(self.actor.parameters(),
                                            lr=learn_rate,
                                            betas=betas)

        # Defining the discriminator and its optimizer
        self.disc = Discriminator(state_dim, action_dim).to(self.device)
        self.optim_disc = torch.optim.Adam(self.disc.parameters(),
                                           lr=learn_rate,
                                           betas=betas)

        if not load_weights:
            self.actor.apply(init_weights)
            self.disc.apply(init_weights)
        else:
            self.load()

        # Loss function criterion
        self.criterion = torch.nn.BCELoss()

    def get_action(self, state):
        """Return the deterministic action for ``state`` as a flat
        numpy array (computed on the configured device)."""
        state = torch.tensor(state, dtype=torch.float,
                             device=self.device).view(1, -1)
        return self.actor(state).cpu().data.numpy().flatten()

    def update(self, n_iter, batch_size=100):
        """Train discriminator and actor for ``n_iter`` mini-batches.

        Returns ``(actor_losses, disc_losses)`` as float arrays of
        length ``n_iter``.
        """
        # FIX: np.float was removed in NumPy 1.24 — use builtin float
        disc_losses = np.zeros(n_iter, dtype=float)
        act_losses = np.zeros(n_iter, dtype=float)

        for i in range(n_iter):

            # Get expert state and actions batch
            exp_states, exp_actions = self.expert.sample(batch_size)
            exp_states = torch.FloatTensor(exp_states).to(self.device)
            exp_actions = torch.FloatTensor(exp_actions).to(self.device)

            # Sample states, then act with the current policy actor
            states, _ = self.expert.sample(batch_size)
            states = torch.FloatTensor(states).to(self.device)
            actions = self.actor(states)

            # ---- train the discriminator ----
            self.optim_disc.zero_grad()

            # FIX: float fill values — BCELoss requires floating-point
            # targets, and torch.full with an integer fill value would
            # produce a LongTensor on recent PyTorch.
            exp_labels = torch.full((batch_size, 1), 1.0, device=self.device)
            policy_labels = torch.full((batch_size, 1),
                                       0.0,
                                       device=self.device)

            # with expert transitions
            prob_exp = self.disc(exp_states, exp_actions)
            exp_loss = self.criterion(prob_exp, exp_labels)

            # with policy actor transitions (detached: no actor grads here)
            prob_policy = self.disc(states, actions.detach())
            policy_loss = self.criterion(prob_policy, policy_labels)

            # backprop through the summed discriminator loss
            disc_loss = exp_loss + policy_loss
            disc_losses[i] = disc_loss.mean().item()

            disc_loss.backward()
            self.optim_disc.step()

            # ---- train the actor ----
            # maximize the (updated) discriminator's output on policy
            # transitions by minimizing its negation
            self.optim_actor.zero_grad()
            loss_actor = -self.disc(states, actions)
            act_losses[i] = loss_actor.mean().detach().item()

            loss_actor.mean().backward()
            self.optim_actor.step()

        print("Finished training minibatch")

        return act_losses, disc_losses

    def save(
            self,
            directory='/home/aman/Programming/RL-Project/Deterministic-GAIL/weights',
            name='GAIL'):
        """Save actor and discriminator state dicts under ``directory``."""
        torch.save(self.actor.state_dict(),
                   '{}/{}_actor.pth'.format(directory, name))
        torch.save(self.disc.state_dict(),
                   '{}/{}_discriminator.pth'.format(directory, name))

    def load(
            self,
            directory='/home/aman/Programming/RL-Project/Deterministic-GAIL/weights',
            name='GAIL'):
        """Load actor and discriminator state dicts from ``directory``."""
        print(os.getcwd())
        self.actor.load_state_dict(
            torch.load('{}/{}_actor.pth'.format(directory, name)))
        self.disc.load_state_dict(
            torch.load('{}/{}_discriminator.pth'.format(directory, name)))

    def set_mode(self, mode="train"):
        """Switch both networks between train and eval mode."""
        if mode == "train":
            self.actor.train()
            self.disc.train()
        else:
            self.actor.eval()
            self.disc.eval()
Exemplo n.º 9
0

def weights_init(m):
    """DCGAN-style initialisation, applied via ``Module.apply``.

    Conv-like layers get weights ~ N(0, 0.02); batch-norm layers get
    weights ~ N(1, 0.02) and zero bias. Other modules are untouched.
    """
    layer_name = m.__class__.__name__
    if 'Conv' in layer_name:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in layer_name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


# Build both networks and apply the DCGAN weight initialisation.
G = Generator().to(device)
G.apply(weights_init)

D = Discriminator().to(device)
D.apply(weights_init)

# Training the DCGANs

criterion = nn.BCELoss()  # adversarial loss for both networks
# Adam with beta1=0.5, the usual DCGAN setting
optimizerD = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))

# per-iteration loss histories
Dis_loss = []
gen_loss = []
for epoch in range(25):
    print("***************Epoch is *******************", epoch + 1)
    for i, data in enumerate(dataloader, 0):
        D.zero_grad()
        real, _ = data
        input = Variable(real).to(device)
Exemplo n.º 10
0
    # Apply the weights_init function to randomly initialize all weights
    #  to mean=0, stdev=0.2.
    netAE.apply(weights_init)

    if plot_machines == True:
        # Print the model
        print(netAE)

    # Create the Discriminators
    netD = Discriminator(ngpu, activation).to(device)

    netSD = Discriminator(ngpu, activation).to(device)

    # Apply the weights_init function to randomly initialize all weights
    #  to mean=0, stdev=0.2.
    netD.apply(weights_init)

    if plot_machines == True:
        # Print the model
        print(netD)

    # Plot some training images
    if plot_some_images == True:
        real_batch = next(iter(dataloader))
        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title("Training Images")
        plt.imshow(np.transpose(vutils.make_grid(real_batch[0][0].to(device)[:64],\
                                                 padding=2, normalize=True).cpu(),\
                                                 (1,2,0)))
        plt.figure(figsize=(8, 8))
Exemplo n.º 11
0
def main(cfg):
    """Train a DCGAN following the standard PyTorch tutorial recipe.

    ``cfg`` must provide: dataroot, image_size, batch_size, workers,
    ngpu, nz, ngf, ndf, nc, lr, beta1, num_epochs. Saves the training
    image grid, the loss curves and the final real/fake comparison to
    image files in the working directory.
    """

    # Set the seed for reproducibility
    manualSeed = 999

    # manualSeed = random.randint(1, 10000)  # use when fresh results are wanted

    print("Random Seed: ", manualSeed)
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)

    # Build the dataset
    dataset = make_dataset(cfg.dataroot, cfg.image_size)

    # Build the data loader
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=cfg.batch_size,
                                             shuffle=True,
                                             num_workers=cfg.workers)

    # Decide which device to run on
    device = torch.device("cuda:0" if (
        torch.cuda.is_available() and cfg.ngpu > 0) else "cpu")

    # Plot a grid of training images
    real_batch = next(iter(dataloader))
    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title("Training Images")
    plt.imshow(
        np.transpose(
            vutils.make_grid(real_batch[0].to(device)[:64],
                             padding=2,
                             normalize=True).cpu(),
            (1, 2, 0),
        ))
    plt.savefig("training_img")

    # Create the Generator
    netG = Generator(cfg.ngpu, cfg.nz, cfg.ngf, cfg.nc).to(device)

    # Multi-GPU if desired
    if (device.type == "cuda") and (cfg.ngpu > 1):
        netG = nn.DataParallel(netG, list(range(cfg.ngpu)))

    # Apply the weight-initialisation function
    netG.apply(weights_init)

    # Print the model
    print(netG)

    # Create the Discriminator
    netD = Discriminator(cfg.ngpu, cfg.ndf, cfg.nc).to(device)

    # Multi-GPU if desired
    if (device.type == "cuda") and (cfg.ngpu > 1):
        netD = nn.DataParallel(netD, list(range(cfg.ngpu)))

    # Apply the weight-initialisation function
    netD.apply(weights_init)

    # Print the model
    print(netD)

    # Loss function
    criterion = nn.BCELoss()

    # Fixed latent vectors used to visualise the generator's progress
    fixed_noise = torch.randn(64, cfg.nz, 1, 1, device=device)

    # Real/fake labels used during training.
    # FIX: float values — BCELoss requires floating-point targets, and
    # torch.full with an integer fill value creates a LongTensor on
    # recent PyTorch, which makes criterion(output, label) fail.
    real_label = 1.
    fake_label = 0.

    # Adam optimizers for both networks
    optimizerD = optim.Adam(netD.parameters(),
                            lr=cfg.lr,
                            betas=(cfg.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=cfg.lr,
                            betas=(cfg.beta1, 0.999))

    # Training loop

    # Lists to track progress
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0

    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(cfg.num_epochs):
        # For each batch in the dataloader
        for i, data in enumerate(dataloader, 0):

            ###############
            # 1. Update the Discriminator
            ###############

            # Train with an all-real batch
            netD.zero_grad()
            # Format batch
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size, ), real_label, device=device)

            # Forward the real batch through D
            output = netD(real_cpu).view(-1)

            # Loss on the real batch
            errD_real = criterion(output, label)

            # Gradients via backprop
            errD_real.backward()
            D_x = output.mean().item()

            # Train with an all-fake batch
            # Sample latent vectors
            noise = torch.randn(b_size, cfg.nz, 1, 1, device=device)

            # Generate fake images
            fake = netG(noise)
            label.fill_(fake_label)

            # Classify the fakes (detached: no generator grads here)
            output = netD(fake.detach()).view(-1)

            # Loss on the fake batch
            errD_fake = criterion(output, label)

            # Gradients for the fake batch
            errD_fake.backward()
            D_G_z1 = output.mean().item()

            # Total discriminator loss (gradients already accumulated)
            errD = errD_real + errD_fake

            # Update the Discriminator
            optimizerD.step()

            ##########
            # 2. Update the Generator
            ##########
            netG.zero_grad()
            # the generator wants its fakes labelled as real
            label.fill_(real_label)

            output = netD(fake).view(-1)
            # Generator loss
            errG = criterion(output, label)

            # Generator gradients
            errG.backward()
            D_G_z2 = output.mean().item()

            # Update the Generator
            optimizerG.step()

            # Report training status

            if i % 50 == 0:
                print(
                    "[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f"
                    % (
                        epoch,
                        cfg.num_epochs,
                        i,
                        len(dataloader),
                        errD.item(),
                        errG.item(),
                        D_x,
                        D_G_z1,
                        D_G_z2,
                    ))

            # Save the losses
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Periodically snapshot the generator's output on fixed noise
            if (iters % 500 == 0) or ((epoch == cfg.num_epochs - 1) and
                                      (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake = netG(fixed_noise).detach().cpu()

                img_list.append(
                    vutils.make_grid(fake, padding=2, normalize=True))

            iters += 1

    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    # FIX: save before show — plt.show() can leave the current figure
    # empty, so calling savefig afterwards wrote a blank file.
    plt.savefig("Genarator_Discriminator_Loss.png")
    plt.show()

    # Grab a batch of real images from the dataloader
    real_batch = next(iter(dataloader))

    # Plot the real images
    plt.figure(figsize=(15, 15))
    plt.subplot(1, 2, 1)
    plt.axis("off")
    plt.title("Real Images")
    plt.imshow(
        np.transpose(
            vutils.make_grid(real_batch[0].to(device)[:64],
                             padding=5,
                             normalize=True).cpu(),
            (1, 2, 0),
        ))

    # Show fake images from the last epoch
    plt.subplot(1, 2, 2)
    plt.axis("off")
    plt.title("Fake Images")
    plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
    # FIX: save before show, for the same reason as above.
    plt.savefig("result_img.png")
    plt.show()
Exemplo n.º 12
0
        np.transpose(
            torchvision.utils.make_grid(real_batch[0].to(device)[:64],
                                        padding=2,
                                        normalize=True).cpu(), (1, 2, 0)))

if opt.train:

    # FIX: build the networks BEFORE their optimizers. Previously the
    # Adam optimizers were constructed first, so they tracked the
    # parameters of whatever `gen`/`disc` referred to at that point —
    # not the freshly created models below — and the new models were
    # never optimized.
    gen = Generator(num_z, num_feat_gen, num_channels).to(device)
    disc = Discriminator(num_channels, num_feat_disc).to(device)

    # DCGAN weight initialisation
    gen.apply(weights_init)
    disc.apply(weights_init)

    print(gen)
    print(disc)

    # Setup Adam optimizers for both G and D
    optimizerD = optim.Adam(disc.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(gen.parameters(), lr=lr, betas=(beta1, 0.999))

    # adversarial loss
    criterion = nn.BCELoss()

    G_losses, D_losses, img_list = train(5, num_z, dataloader, \
        gen, disc, device, criterion, optimizerG, optimizerD)

    # plot the loss curves
    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
Exemplo n.º 13
0
def main():
    """CLI entry point: train a vanilla GAN on MNIST or XCAD images,
    periodically saving sample grids and checkpoints under ``logs/``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, choices=['xcad', 'mnist'], help="type of real images")
    parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=500, help="interval between image sampling")
    parser.add_argument("--ckp_interval", type=int, default=5, help="interval between model saving")
    opt = parser.parse_args()
    print(opt)

    # per-dataset output directories for sample grids and checkpoints
    images_path = "logs/images/{}".format(opt.data)
    os.makedirs(images_path, exist_ok=True)
    ckps_path = "logs/ckps/{}".format(opt.data)
    os.makedirs(ckps_path, exist_ok=True)

    cuda = True if torch.cuda.is_available() else False

    # Loss function
    adversarial_loss = torch.nn.BCELoss()

    # Initialize generator and discriminator
    generator = Generator(opt)
    discriminator = Discriminator(opt)

    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    # Configure data loader
    os.makedirs("../data", exist_ok=True)
    if opt.data == 'mnist':
        transform = transforms.Compose([
            transforms.Resize(opt.img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])
        dataset = datasets.MNIST(
            "../data",
            train=True,
            download=True,
            transform=transform,
        )
    elif opt.data == 'xcad':
        transform = transforms.Compose([
            transforms.Resize(opt.img_size),
            # # there is a bug in older PIL, thus, to deal with grayscale images
            # # that have only one channel, add fill=(0, )
            # transforms.RandomRotation(180, fill=(0, )),
            transforms.RandomRotation(180),  # add diversity
            transforms.RandomHorizontalFlip(0.5),  # add diversity
            transforms.RandomVerticalFlip(0.5),  # add diversity
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])
        dataset = XCADImageDataset(
            "../data",
            transform=transform
        )
    else:
        raise ValueError("invalid args --data")
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True)

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    # ----------
    #  Training
    # ----------

    for epoch in range(opt.n_epochs):
        for i, (imgs, _) in enumerate(dataloader):

            # Adversarial ground truths (float tensors, as BCELoss requires)
            valid = torch.ones(imgs.shape[0], 1)
            fake = torch.zeros(imgs.shape[0], 1)
            if cuda:
                imgs = imgs.cuda()
                valid = valid.cuda()
                fake = fake.cuda()

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Sample noise as generator input
            z = torch.from_numpy(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))).float()
            if cuda:
                z = z.cuda()

            # Generate a batch of images
            gen_imgs = generator(z)

            # Loss measures generator's ability to fool the discriminator
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            print("[Epoch {}/{}] [Batch {}/{}] [D loss: {}] [G loss: {}]".format(
                epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))

            batches_done = epoch * len(dataloader) + i
            if batches_done % opt.sample_interval == 0:
                save_image(gen_imgs.data[:25], "{}/{}.png".format(images_path, batches_done), nrow=5, normalize=True)

        # NOTE(review): the interval check uses epoch + 1 but the file is
        # named after the 0-based epoch — confirm this is intentional.
        if (epoch + 1) % opt.ckp_interval == 0:
            torch.save(generator.state_dict(), "{}/G_{}.pth".format(ckps_path, epoch))
            torch.save(discriminator.state_dict(), "{}/D_{}.pth".format(ckps_path, epoch))

    torch.save(generator.state_dict(), "{}/G_last.pth".format(ckps_path))
    torch.save(discriminator.state_dict(), "{}/D_last.pth".format(ckps_path))
Exemplo n.º 14
0
def weights_init(m):
    """Gaussian(0, 0.02) initialisation for conv / transposed-conv
    weights and batch-norm weights; batch-norm biases are zeroed.

    NOTE(review): batch-norm weights are drawn around 0.0 here, whereas
    the common DCGAN recipe centres them at 1.0 — confirm this is
    intentional.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
        
# CycleGAN-style setup: one generator and one discriminator per domain.
monet_generator = Generator().to(device)
photo_generator = Generator().to(device)
monet_discriminator = Discriminator().to(device)
photo_discriminator = Discriminator().to(device)

# Module.apply returns the module itself, so these rebinds are no-ops
# beyond running the initialiser on every submodule.
monet_generator = monet_generator.apply(weights_init)
photo_generator = photo_generator.apply(weights_init)
monet_discriminator = monet_discriminator.apply(weights_init)
photo_discriminator = photo_discriminator.apply(weights_init)


# training configuration
n_epoch = 50
BATCH_SIZE = 8
LAMBDA=10  # presumably the cycle-consistency weight — confirm in the loss code
lr = 2e-4
save = True

# paired Monet/photo dataset (ImageLoader, MONET_PATH, PHOTO_PATH are
# defined elsewhere in this file)
trainset = ImageLoader(MONET_PATH,PHOTO_PATH)
loader = DataLoader(trainset,BATCH_SIZE,shuffle=True)
len_loader = len(loader)
i=0

monet_generator.train()
class GAN:
    """DCGAN trainer over an ImageFolder dataset with TensorBoard
    logging and per-epoch checkpointing."""

    def __init__(self):
        # NOTE(review): CUDA availability is assumed — confirm.
        self.device = torch.device('cuda')

        self.generator = Generator(100, 1, 64).to(self.device)
        self.discriminator = Discriminator(1, 64).to(self.device)
        self.generator.apply(self.weights_init)
        self.discriminator.apply(self.weights_init)
        print(self.generator)
        print(self.discriminator)

        self.gen_optimizer = optim.Adam(self.generator.parameters(),
                                        lr=0.0002,
                                        betas=(0.5, 0.999))
        self.dis_optimizer = optim.Adam(self.discriminator.parameters(),
                                        lr=0.0002,
                                        betas=(0.5, 0.999))

        # adversarial loss for both networks
        self.loss_criterion = nn.BCELoss()

        train_dir = os.path.join(os.getcwd(), 'train', 'Imagenet')

        # 64x64 centre-cropped images, each channel normalised to [-1, 1]
        self.dataset = datasets.ImageFolder(root=train_dir,
                                            transform=transforms.Compose([
                                                transforms.Resize(64),
                                                transforms.CenterCrop(64),
                                                transforms.ToTensor(),
                                                transforms.Normalize(
                                                    (0.5, 0.5, 0.5),
                                                    (0.5, 0.5, 0.5)),
                                            ]))
        self.data_loader = DataLoader(self.dataset,
                                      batch_size=128,
                                      shuffle=True,
                                      num_workers=2)
        self.writer = SummaryWriter(comment='dcgan_imagenet')

    def weights_init(self, m):
        """DCGAN init: conv weights ~ N(0, 0.02); batch-norm weights
        ~ N(1, 0.02) with zero bias (applied via Module.apply)."""
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
            nn.init.constant_(m.bias.data, val=0.0)

    def train(self):
        """Run 100 epochs of adversarial training, logging losses and a
        fixed-noise image grid to TensorBoard every epoch."""
        fixed_noise = torch.randn(size=(128, 100, 1, 1), device=self.device)
        # FIX: float labels — BCELoss requires floating-point targets,
        # and torch.full with an integer fill value would produce a
        # LongTensor on recent PyTorch.
        real_label = 1.
        fake_label = 0.
        for epoch in range(100):
            real_losses = []
            fake_losses = []
            gen_losses = []
            dis_losses = []
            for i, data in enumerate(self.data_loader, 0):
                self.generator.zero_grad()
                self.discriminator.zero_grad()
                self.gen_optimizer.zero_grad()
                self.dis_optimizer.zero_grad()

                real_images = data[0].to(self.device)
                batch_size = real_images.size(0)

                noise = torch.randn(size=(batch_size, 100, 1, 1),
                                    device=self.device)
                fake_images = self.generator(noise)

                real_labels = torch.full(size=(batch_size, ),
                                         fill_value=real_label,
                                         device=self.device)
                fake_labels = torch.full(size=(batch_size, ),
                                         fill_value=fake_label,
                                         device=self.device)

                # discriminator step: real batch + detached fake batch
                real_output = self.discriminator(real_images)
                fake_output = self.discriminator(fake_images.detach())

                real_loss = self.loss_criterion(real_output, real_labels)
                fake_loss = self.loss_criterion(fake_output, fake_labels)
                real_loss.backward()
                fake_loss.backward()

                self.dis_optimizer.step()
                dis_loss = real_loss + fake_loss
                real_losses.append(real_loss.item())
                fake_losses.append(fake_loss.item())
                dis_losses.append(dis_loss.item())

                # generator step: fakes should be classified as real
                fake_output = self.discriminator(fake_images)
                gen_loss = self.loss_criterion(fake_output, real_labels)
                gen_loss.backward()
                self.gen_optimizer.step()
                gen_losses.append(gen_loss.item())

                print(
                    "Epoch : %d, Iteration : %d, Generator loss : %0.3f, Discriminator loss : %0.3f, Real label loss : %0.3f, Fake label loss : %0.3f"
                    % (epoch, i, gen_loss.item(), dis_loss.item(),
                       real_loss.item(), fake_loss.item()))

            # per-epoch averages for TensorBoard
            real_loss = np.mean(np.array(real_losses))
            fake_loss = np.mean(np.array(fake_losses))
            dis_loss = np.mean(np.array(dis_losses))
            gen_loss = np.mean(np.array(gen_losses))

            fake_images = self.generator(fixed_noise).detach()
            self.writer.add_image("images",
                                  vutils.make_grid(fake_images.data[:128]),
                                  epoch)

            self.writer.add_scalar('Real Loss', real_loss, epoch)
            self.writer.add_scalar('Fake Loss', fake_loss, epoch)
            self.writer.add_scalar('Discriminator Loss', dis_loss, epoch)
            self.writer.add_scalar('Generator Loss', gen_loss, epoch)

            self.save_network(epoch)

    def save_network(self, epoch):
        """Checkpoint both networks under ``./models`` for ``epoch``."""
        path = os.path.join(os.getcwd(), 'models')
        # robustness: create the checkpoint directory if it is missing
        # (torch.save raises otherwise)
        os.makedirs(path, exist_ok=True)
        torch.save(self.generator.state_dict(),
                   '%s/generator_epoch_%d.pth' % (path, epoch))
        torch.save(self.discriminator.state_dict(),
                   '%s/discriminator_epoch_%d.pth' % (path, epoch))
Exemplo n.º 16
0
def train_gan(data: DataLoader,
              latent_dim: int = 10,
              n_dis_hn: int = 25,
              n_gen_hn: int = 15,
              n_epochs: int = 10,
              lr: float = 0.01,
              early_stop: bool = False):
    """Train a GAN (generator vs. discriminator) on the given data.

    Args:
        data: DataLoader yielding (x, y) batches of real samples and labels.
        latent_dim: size of the generator's input noise vector.
        n_dis_hn: hidden-unit count for the discriminator.
        n_gen_hn: hidden-unit count for the generator.
        n_epochs: number of training epochs.
        lr: Adam learning rate shared by both optimizers.
        early_stop: enable best-model checkpointing, divergence detection,
            and convergence-based early stopping.

    Returns:
        (gen_model, d_losses, g_losses): the trained generator (best saved
        checkpoint if early stopping triggered) plus per-iteration losses.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    gen_model = Generator(latent_dim, n_gen_hn).to(device)
    gen_model.apply(init_weights)
    dis_model = Discriminator(n_dis_hn).to(device)
    dis_model.apply(init_weights)

    gen_optim = optim.Adam(gen_model.parameters(), lr=lr)
    dis_optim = optim.Adam(dis_model.parameters(), lr=lr)

    criterion = nn.BCELoss()

    d_losses = []
    g_losses = []
    # Per iteration: [D(x), D(G(z)) before G update, D(G(z)) after G update].
    D_vals = []

    prev_loss = 0
    prev_loss_counter = 0
    min_gen_loss = 100  # a checkpoint counts as "best" only below this loss
    best_epoch = 0
    stopped = False
    # Pre-initialize so the plotting code below cannot hit a NameError when
    # n_epochs == 0 (renamed from `iter`, which shadowed the builtin).
    num_batches = 0
    for epoch in range(n_epochs):
        d_epoch_loss = 0
        g_epoch_loss = 0
        num_batches = 0  # count batches per epoch
        for x, y in data:
            # Forward pass through the discriminator on only-real data.
            # NOTE(review): x/label are not moved to `device` here — assumes
            # update_model/make_noise handle placement; confirm against them.
            dis_model.zero_grad()
            output, loss_real = update_model(dis_model, criterion, x, y)
            D_x = output.mean().item()

            # Forward pass with fake data (detached: only D is updated here).
            noise = make_noise(x.shape[0], latent_dim)
            fake_x = gen_model(noise)
            label = torch.Tensor([FAKE_LABEL] * x.shape[0])
            output, loss_fake = update_model(dis_model, criterion,
                                             fake_x.detach(), label)

            # Combine real and fake losses.
            loss = loss_real + loss_fake
            d_epoch_loss += loss.item()
            D_g_z1 = output.mean().item()

            # Update discriminator.
            dis_optim.step()

            # Now update the generator based on discriminator performance:
            # the discriminator was just updated, so re-score the fakes.
            gen_model.zero_grad()
            label = torch.Tensor([REAL_LABEL] * x.shape[0])
            output, gen_loss = update_model(dis_model, criterion, fake_x,
                                            label)
            gen_optim.step()
            g_epoch_loss += gen_loss.item()
            D_g_z2 = output.mean().item()

            d_losses.append(loss.item())
            g_losses.append(gen_loss.item())
            D_vals.append([D_x, D_g_z1, D_g_z2])

            num_batches += 1

        if epoch % 100 == 0:
            print(
                f'Epoch {epoch} DLoss: {d_epoch_loss}, GLoss: {g_epoch_loss}')

        # Early stopping.
        if early_stop:
            if g_epoch_loss < min_gen_loss and epoch > 0:
                min_gen_loss = g_epoch_loss
                best_epoch = epoch
                torch.save(gen_model, 'best_gen_model.torch')
            elif epoch - best_epoch > 1000:
                print(f'Model not learning, {epoch - best_epoch} epochs '
                      f'since best epoch {best_epoch} ({min_gen_loss})')
                stopped = True
                break
            elif g_epoch_loss > 500:
                print(f'Sudden loss explosion: E{epoch}: {g_epoch_loss}')
                stopped = True
                break

            # Handle when discriminator loss is not changing.
            if round(d_epoch_loss, 3) != prev_loss:
                prev_loss = round(d_epoch_loss, 3)
                prev_loss_counter = 0
            else:
                prev_loss_counter += 1
            if prev_loss_counter > 10:
                print(f'Model converged, epoch {epoch}')
                break

    plt.plot(d_losses, label='discriminator', alpha=0.6)
    plt.plot(g_losses, label='generator', alpha=0.6)
    # Convert the best epoch to an iteration index for the vertical marker.
    plt.vlines(best_epoch * num_batches, 0, plt.ylim()[1], color='red')
    plt.legend()
    plt.xlabel('iterations')
    plt.ylabel('loss')
    plt.savefig('../MyHomeFolder/training_losses.png', bbox_inches='tight')
    plt.clf()

    plt.plot([x[0] for x in D_vals], label='D_x', alpha=0.6)
    plt.plot([x[1] for x in D_vals], label='D_g_z1', alpha=0.6)
    plt.plot([x[2] for x in D_vals], label='D_g_z2', alpha=0.6)
    plt.ylim(-0.1, 1.1)
    plt.hlines(0.5, 0, plt.xlim()[1], color='red')  # equilibrium reference
    plt.legend()
    plt.xlabel('iterations')
    plt.ylabel('D(x)')
    plt.savefig('../MyHomeFolder/d_values.png', bbox_inches='tight')
    plt.clf()

    # If training was interrupted, fall back to the best saved generator.
    if stopped:
        gen_model = torch.load('best_gen_model.torch')

    return gen_model, d_losses, g_losses
Exemplo n.º 17
0
class SRGAN():
    """GAN trainer with a tag-conditioned generator and a DRAGAN-style
    gradient penalty; resumes automatically from the newest checkpoint."""

    def __init__(self):
        self.dataset = DataLoader(data_path=data_path,
                                  transform=transforms.Compose([ToTensor()]))
        self.data_loader = torch.utils.data.DataLoader(self.dataset,
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       num_workers=num_workers,
                                                       drop_last=True)
        checkpoint, checkpoint_name = self.load_checkpoint(model_dump_path)
        if checkpoint is None:
            # Fresh start: initialize weights and optimizers from scratch.
            self.G = Generator().to(device)
            self.D = Discriminator().to(device)
            self.G.apply(initital_network_weights).to(device)
            self.D.apply(initital_network_weights).to(device)
            self.optimizer_G = torch.optim.Adam(self.G.parameters(),
                                                lr=learning_rate,
                                                betas=(beta_1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.D.parameters(),
                                                lr=learning_rate,
                                                betas=(beta_1, 0.999))
            self.epoch = 0
        else:
            # Resume: restore model and optimizer state from the checkpoint.
            self.G = Generator().to(device)
            self.D = Discriminator().to(device)
            self.G.load_state_dict(checkpoint['G'])
            self.D.load_state_dict(checkpoint['D'])
            self.optimizer_G = torch.optim.Adam(self.G.parameters(),
                                                lr=learning_rate,
                                                betas=(beta_1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.D.parameters(),
                                                lr=learning_rate,
                                                betas=(beta_1, 0.999))
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])
            self.epoch = checkpoint['epoch']

        self.label_criterion = nn.BCEWithLogitsLoss().to(device)
        self.tag_criterion = nn.MultiLabelSoftMarginLoss().to(device)

    def load_checkpoint(self, model_dir):
        """Return (checkpoint, path) for the newest saved model, or
        (None, None) when no checkpoint exists in `model_dir`."""
        models_path = utils.read_newest_model(model_dir)
        if len(models_path) == 0:
            return None, None
        models_path.sort()
        new_model_path = os.path.join(model_dump_path, models_path[-1])
        # Map tensors onto the CPU when no GPU is present; otherwise let
        # torch.load keep the tensors where they were saved.
        map_location = None if torch.cuda.is_available() else 'cpu'
        checkpoint = torch.load(new_model_path, map_location=map_location)
        return checkpoint, new_model_path

    def train(self):
        """Run the training loop, alternating D and G updates, saving sample
        images every `verbose_T` iterations and a checkpoint every epoch."""
        iteration = -1
        label = Variable(torch.FloatTensor(batch_size, 1)).to(device)
        while self.epoch <= max_epoch:
            adjust_learning_rate(self.optimizer_G, iteration)
            adjust_learning_rate(self.optimizer_D, iteration)
            for i, (anime_tag, anime_img) in enumerate(self.data_loader):
                iteration += 1
                if anime_img.shape[0] != batch_size:
                    continue
                anime_img = Variable(anime_img).to(device)
                anime_tag = Variable(torch.FloatTensor(anime_tag)).to(device)
                # D : G = 2 : 1
                # 1. Training D
                # 1.1. use real image for discriminating
                self.D.zero_grad()
                label_p, tag_p = self.D(anime_img)
                label.data.fill_(1.0)

                # 1.2. real image's loss
                real_label_loss = self.label_criterion(label_p, label)
                real_tag_loss = self.tag_criterion(tag_p, anime_tag)
                real_loss_sum = real_label_loss * lambda_adv / 2.0 + real_tag_loss * lambda_adv / 2.0
                real_loss_sum.backward()

                # 1.3. use fake image for discriminating
                g_noise, fake_tag = utils.fake_generator(
                    batch_size, noise_size, device)
                fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                fake_img = self.G(fake_feat).detach()
                fake_label_p, fake_tag_p = self.D(fake_img)
                label.data.fill_(.0)

                # 1.4. fake image's loss
                fake_label_loss = self.label_criterion(fake_label_p, label)
                fake_tag_loss = self.tag_criterion(fake_tag_p, fake_tag)
                fake_loss_sum = fake_label_loss * lambda_adv / 2.0 + fake_tag_loss * lambda_adv / 2.0
                fake_loss_sum.backward()

                # 1.5. gradient penalty (DRAGAN-style, on perturbed reals)
                # https://github.com/jfsantos/dragan-pytorch/blob/master/dragan.py
                alpha_size = [1] * anime_img.dim()
                alpha_size[0] = anime_img.size(0)
                alpha = torch.rand(alpha_size).to(device)
                x_hat = Variable(alpha * anime_img.data + (1 - alpha) * \
                                 (anime_img.data + 0.5 * anime_img.data.std() * Variable(torch.rand(anime_img.size())).to(device)),
                                 requires_grad=True).to(device)
                pred_hat, pred_tag = self.D(x_hat)
                gradients = grad(outputs=pred_hat,
                                 inputs=x_hat,
                                 grad_outputs=torch.ones(
                                     pred_hat.size()).to(device),
                                 create_graph=True,
                                 retain_graph=True,
                                 only_inputs=True)[0].view(x_hat.size(0), -1)
                gradient_penalty = lambda_gp * (
                    (gradients.norm(2, dim=1) - 1)**2).mean()
                # Backprop the penalty directly. The original wrapped it in a
                # fresh Variable(..., requires_grad=True) first, which detaches
                # it from the graph and silently turns the penalty into a
                # no-op (no gradients ever reached D).
                gradient_penalty.backward()

                # 1.6. update optimizer
                self.optimizer_D.step()

                # 2. Training G
                # 2.1. generate fake image
                self.G.zero_grad()
                g_noise, fake_tag = utils.fake_generator(
                    batch_size, noise_size, device)
                fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                fake_img = self.G(fake_feat)
                fake_label_p, fake_tag_p = self.D(fake_img)
                label.data.fill_(1.0)

                # 2.2. calc loss (fake labels are "real" for the G objective)
                label_loss_g = self.label_criterion(fake_label_p, label)
                tag_loss_g = self.tag_criterion(fake_tag_p, fake_tag)
                loss_g = label_loss_g * lambda_adv / 2.0 + tag_loss_g * lambda_adv / 2.0
                loss_g.backward()

                # 2.2. update optimizer
                self.optimizer_G.step()

                if iteration % verbose_T == 0:
                    print('The iteration is now %d' % iteration)
                    print('The loss is %.4f, %.4f, %.4f, %.4f' %
                          (real_loss_sum, fake_loss_sum, gradient_penalty,
                           loss_g))
                    vutils.save_image(
                        anime_img.data.view(batch_size, 3, anime_img.size(2),
                                            anime_img.size(3)),
                        os.path.join(
                            tmp_path, 'real_image_{}.png'.format(
                                str(iteration).zfill(8))))
                    g_noise, fake_tag = utils.fake_generator(
                        batch_size, noise_size, device)
                    fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                    fake_img = self.G(fake_feat)
                    vutils.save_image(
                        fake_img.data.view(batch_size, 3, anime_img.size(2),
                                           anime_img.size(3)),
                        os.path.join(
                            tmp_path, 'fake_image_{}.png'.format(
                                str(iteration).zfill(8))))
            # dump checkpoint
            torch.save(
                {
                    'epoch': self.epoch,
                    'D': self.D.state_dict(),
                    'G': self.G.state_dict(),
                    'optimizer_D': self.optimizer_D.state_dict(),
                    'optimizer_G': self.optimizer_G.state_dict(),
                }, '{}/checkpoint_{}.tar'.format(model_dump_path,
                                                 str(self.epoch).zfill(4)))
            self.epoch += 1
#  to mean=0, stdev=0.2 (tail of the weight-init note for the generator).
netG.apply(il.weights_init)

# Print the generator architecture for inspection
print(netG)

# Create the Discriminator
netD = Discriminator(cf.ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (cf.ngpu > 1):
    netD = nn.DataParallel(netD, list(range(cf.ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.2.
netD.apply(il.weights_init)

# Print the discriminator architecture for inspection
print(netD)

# Initialize BCELoss function (expects probabilities in [0, 1])
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, cf.nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1
fake_label = 0
Exemplo n.º 19
0
import torch.optim as optim

from dataloader import loader
from discriminator import Discriminator
from generator import Generator
from parameters import d_conv_dim, z_size, g_conv_dim, beta1, beta2, lr_d, lr_g
from train import train
from utils import display_images, weights_init_normal, gpu_check, load_samples, view_samples

# Show a sample batch from the data loader as a sanity check.
display_images(loader)

# Build the two networks from the configured layer sizes.
D = Discriminator(d_conv_dim)
G = Generator(z_size=z_size, conv_dim=g_conv_dim)

# Randomly initialize all weights.
D.apply(weights_init_normal)
G.apply(weights_init_normal)

train_on_gpu = gpu_check()

if not train_on_gpu:
    print('No GPU found. Please use a GPU to train your neural network.')
else:
    print('Training on GPU!')

# Separate Adam optimizers for the discriminator and the generator.
d_optimizer = optim.Adam(D.parameters(), lr_d, [beta1, beta2])
g_optimizer = optim.Adam(G.parameters(), lr_g, [beta1, beta2])

n_epochs = 30

losses = train(D, d_optimizer, G, g_optimizer, n_epochs=n_epochs)
Exemplo n.º 20
0
def train(opt: Options):
    """Train a CPPN-style GAN, alternating one discriminator update and one
    generator update per batch.

    NOTE: the nested helpers below share and mutate the outer pre-allocated
    tensors (`noise`, `label`, `input_`) in place, so their call order
    matters — do not reorder.
    """
    # Label conventions for the BCE objective.
    real_label = 1
    fake_label = 0

    netG = Generator(opt)
    netD = Discriminator(opt)
    print(netG)
    print(netD)

    # Custom per-network weight initialization.
    netG.apply(weights_init_g)
    netD.apply(weights_init_d)

    # summary(netD, (opt.c_dim, opt.x_dim, opt.y_dim))

    dataloader = load_data(opt.data_root, opt.x_dim, opt.y_dim, opt.batch_size, opt.workers)

    # Coordinate inputs for the CPPN generator (presumably x/y grids and a
    # radius channel — TODO confirm against get_coordinates).
    x, y, r = get_coordinates(x_dim=opt.x_dim, y_dim=opt.y_dim, scale=opt.scale, batch_size=opt.batch_size)

    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))

    criterion = nn.BCELoss()
    # criterion = nn.L1Loss()

    # Pre-allocated buffers reused (and mutated in place) every iteration.
    noise = torch.FloatTensor(opt.batch_size, opt.z_dim)
    ones = torch.ones(opt.batch_size, opt.x_dim * opt.y_dim, 1)
    input_ = torch.FloatTensor(opt.batch_size, opt.c_dim, opt.x_dim, opt.y_dim)
    label = torch.FloatTensor(opt.batch_size, 1)

    # NOTE(review): Variable is a no-op since PyTorch 0.4; kept as-is.
    input_ = Variable(input_)
    label = Variable(label)
    noise = Variable(noise)

    if opt.use_cuda:
        netG = netG.cuda()
        netD = netD.cuda()
        x = x.cuda()
        y = y.cuda()
        r = r.cuda()
        ones = ones.cuda()
        criterion = criterion.cuda()
        input_ = input_.cuda()
        label = label.cuda()
        noise = noise.cuda()

    # Fixed latent seed, broadcast per pixel, used for periodic sample images.
    noise.data.normal_()
    fixed_seed = torch.bmm(ones, noise.unsqueeze(1))

    def _update_discriminator(data):
        # One D step: real batch with smoothed labels, then a fresh fake batch.
        # for p in netD.parameters():
        #     p.requires_grad = True  # to avoid computation
        netD.zero_grad()
        real_cpu, _ = data
        input_.data.copy_(real_cpu)
        label.data.fill_(real_label-0.1)  # use smooth label for discriminator

        output = netD(input_)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.data.mean()

        # train with fake
        noise.data.normal_()
        seed = torch.bmm(ones, noise.unsqueeze(1))

        fake = netG(x, y, r, seed)
        label.data.fill_(fake_label)
        output = netD(fake.detach())  # add ".detach()" to avoid backprop through G
        errD_fake = criterion(output, label)
        errD_fake.backward()  # gradients for fake/real will be accumulated
        D_G_z1 = output.data.mean()
        errD = errD_real + errD_fake
        optimizerD.step()  # .step() can be called once the gradients are computed

        return fake, D_G_z1, errD, D_x

    def _update_generator(fake):
        # One G step: re-score the (non-detached) fakes against real labels.
        # for p in netD.parameters():
        #     p.requires_grad = False  # to avoid computation
        netG.zero_grad()

        label.data.fill_(real_label)  # fake labels are real for generator cost

        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()  # True if backward through the graph for the second time
        D_G_z2 = output.data.mean()
        optimizerG.step()

        return D_G_z2, errG

    def _save_model(epoch):
        # Checkpoint both networks. NOTE(review): `epoch % 1 == 0` is always
        # true, so this saves every epoch.
        os.makedirs(opt.models_root, exist_ok=True)
        if epoch % 1 == 0:
            torch.save(netG.state_dict(), os.path.join(opt.models_root, "G-cppn-wgan-anime_{}.pth".format(epoch)))
            torch.save(netD.state_dict(), os.path.join(opt.models_root, "D-cppn-wgan-anime_{}.pth".format(epoch)))

    def _log(i, epoch, errD, errG, D_x, D_G_z1, D_G_z2, delta_time):
        # Per-iteration console status line.
        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f Elapsed %.2f s'
              % (epoch, opt.iterations, i, len(dataloader), errD.data.item(), errG.data.item(), D_x, D_G_z1, D_G_z2,
                 delta_time))

    def _save_images(i, epoch):
        # Every 100 iterations, render samples from the fixed seed.
        os.makedirs(opt.images_root, exist_ok=True)
        if i % 100 == 0:
            fake = netG(x, y, r, fixed_seed)
            fname = os.path.join(opt.images_root, "fake_samples_{:02}-{:04}.png".format(epoch, i))
            vutils.save_image(fake.data[0:64, :, :, :], fname, nrow=8)

    def _start():
        # Main loop: D step then G step per batch; checkpoint each epoch.
        print("Start training")
        for epoch in range(opt.iterations):
            for i, data in enumerate(dataloader, 0):
                start_iter = time.time()

                fake, D_G_z1, errD, D_x = _update_discriminator(data)
                D_G_z2, errG = _update_generator(fake)

                end_iter = time.time()

                _log(i, epoch, errD, errG, D_x, D_G_z1, D_G_z2, end_iter - start_iter)
                _save_images(i, epoch)
            _save_model(epoch)

    _start()
Exemplo n.º 21
0
def main():
    """Run DCGAN training end-to-end: build the networks, train for
    opt.n_epochs epochs, then plot losses and discriminator outputs."""
    # Check whether a GPU is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("\n==================================")
    print("DCGAN is avtivating at", device, "!!")
    print("==================================")
    start = time.time()


    # Create the generator and discriminator
    netG = Generator(opt).to(device)
    netD = Discriminator(opt).to(device)


    print(netD)
    print(netG)

    # Randomly initialize the weights (mean=0, stdev=0.2)
    netG.apply(weights_init)
    netD.apply(weights_init)

    # Loss function
    criterion = nn.BCELoss()

    # Use Adam for both optimizers
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.b1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.b1, 0.999))
    optimizers = {"optimizerG" : optimizerG, "optimizerD" : optimizerD}
    
    # Load the dataset (alternative loaders kept for reference)
    # data = load_MNIST(img_size=opt.img_size, batch_size=opt.batch_size, img_path=opt.img_path)     # data["train"] / data["test"] each hold a DataLoader
    # data = load_CIFAR10(img_size=opt.img_size, batch_size=opt.batch_size, img_path=opt.img_path)
    # data = load_CelebA(img_size=opt.img_size, batch_size=opt.batch_size, img_path=opt.img_path)
    data = load_local_CelebA(img_size=opt.img_size, batch_size=opt.batch_size, img_path="F:/datasets/GAN/data/celebA")
    
    

    with tqdm(range(opt.n_epochs), leave=False) as pbar:
        # Loss history
        result = {}
        result["log_loss_G"] = []
        result["log_loss_D"] = []
        result["log_d_out_real"] = []
        # Monitor the discriminator's raw output values
        result["log_d_out_fake1"] = []  # D's output on fakes right after the generator update (typically near 0.5)
        result["log_d_out_fake2"] = []  # D's output on fakes right after the discriminator update (typically near 0)
        for epoch in pbar:
            log = train(
                    loader_train=data["train"], generator=netG,
                    discriminator=netD, optimizer=optimizers, 
                    loss_fn=criterion, device=device, opt=opt,
                    epoch=epoch, result=result)

            # After each epoch, append that epoch's per-iteration logs
            # (NOTE(review): the original comment said "average", but the
            # code extends the lists with raw per-iteration values)
            result["log_loss_G"].extend(log['loss_g'])
            result["log_loss_D"].extend(log['loss_d'])
            result["log_d_out_real"].extend(log['real'])
            result["log_d_out_fake1"].extend(log['fake1'])
            result["log_d_out_fake2"].extend(log['fake2'])

            # Save generated images and network checkpoints every epoch
            save(epoch=epoch, generate_img=log['fake_img_tensor'], opt=opt)

            # Update the progress-bar postfix
            if opt.plot_epoch:
                pbar.set_postfix(OrderedDict(
                    D_out_REAL="{:.4f}".format(result["log_d_out_real"][-1]),
                    D_out_FAKE1="{:.4f}".format(result["log_d_out_fake1"][-1]),
                    D_out_FAKE2="{:.4f}".format(result["log_d_out_fake2"][-1])
                    ))
            else:
                pbar.set_postfix(OrderedDict(
                    D_out_REAL="{:.4f}".format(sum(log['real'])/len(log['real'])),
                    D_out_FAKE1="{:.4f}".format(sum(log['fake1'])/len(log['fake1'])),
                    D_out_FAKE2="{:.4f}".format(sum(log['fake2'])/len(log['fake2']))
                    ))

    elapsed_time = time.time() - start
    print ("time:{0}".format(datetime.timedelta(seconds=elapsed_time)))
    plot_loss(result, opt)
    plot_d_out(result, opt)

    print("--END--")
Exemplo n.º 22
0
def semi_main(options):
    """Semi-supervised Bayesian GAN training on CIFAR10.

    Runs SGHMC-style sampling over a chain of generators against a single
    discriminator with a (10 + 1)-way output head, mixing the adversarial
    objective with a supervised loss on a small labeled subset.

    Args:
        options: argparse.Namespace with the CLI configuration.
    """
    import math  # hoisted from mid-function; used for the SGHMC noise scales

    print('\nSemi-Supervised Learning!\n')

    # 1. Make sure the options are valid argparse CLI options indeed
    assert isinstance(options, argparse.Namespace)

    # 2. Set up the logger
    logging.basicConfig(level=str(options.loglevel).upper())

    # 3. Make sure the output dir `outf` exists
    _check_out_dir(options)

    # 4. Set the random state
    _set_random_state(options)

    # 5. Configure CUDA and Cudnn, set the global `device` for PyTorch
    device = _configure_cuda(options)

    # 6. Prepare the datasets and split it for semi-supervised learning
    if options.dataset != 'cifar10':
        raise NotImplementedError(
            'Semi-supervised learning only support CIFAR10 dataset at the moment!'
        )
    test_data_loader, semi_data_loader, train_data_loader = _prepare_semi_dataset(
        options)

    # 7. Set the parameters
    ngpu = int(options.ngpu)  # num of GPUs
    nz = int(
        options.nz)  # size of latent vector, also the number of the generators
    ngf = int(options.ngf)  # depth of feature maps through G
    ndf = int(options.ndf)  # depth of feature maps through D
    nc = int(options.nc
             )  # num of channels of the input images, 3 indicates color images
    M = int(options.mcmc)  # num of SGHMC chains run concurrently
    nd = int(options.nd)  # num of discriminators (unused: single-D variant)
    nsetz = int(options.nsetz)  # num of noise batches

    # 8. Special preparations for Bayesian GAN for Generators

    # In order to inject the SGHMAC into the training process, instead of pause the gradient descent at
    # each training step, which can be easily defined with static computation graph(Tensorflow), in PyTorch,
    # we have to move the Generator Sampling to the very beginning of the whole training process, and use
    # a trick that initializing all of the generators explicitly for later usages.
    Generator_chains = []
    for _ in range(nsetz):
        for __ in range(M):
            netG = Generator(ngpu, nz, ngf, nc).to(device)
            netG.apply(weights_init)
            Generator_chains.append(netG)

    logging.info(
        f'Showing the first generator of the Generator chain: \n {Generator_chains[0]}\n'
    )

    # 9. Special preparations for Bayesian GAN for Discriminators
    assert options.dataset == 'cifar10', 'Semi-supervised learning only support CIFAR10 dataset at the moment!'

    # 10 CIFAR classes plus one extra "fake" class at index 0.
    num_class = 10 + 1

    # To simplify the implementation we only consider the situation of 1 discriminator
    netD = Discriminator(ngpu, ndf, nc, num_class=num_class).to(device)
    netD.apply(weights_init)
    logging.info(f'Showing the Discriminator model: \n {netD}\n')

    # 10. Loss function
    criterion = nn.CrossEntropyLoss()
    all_criterion = ComplementCrossEntropyLoss(except_index=0, device=device)

    # 11. Set up optimizers (one per generator in the chain)
    optimizerG_chains = [
        optim.Adam(netG.parameters(),
                   lr=options.lr,
                   betas=(options.beta1, 0.999)) for netG in Generator_chains
    ]

    optimizerD = optim.Adam(netD.parameters(),
                            lr=options.lr,
                            betas=(options.beta1, 0.999))

    # 12. Set up the losses for priors and noises
    gprior = PriorLoss(prior_std=1., total=500.)
    gnoise = NoiseLoss(params=Generator_chains[0].parameters(),
                       device=device,
                       scale=math.sqrt(2 * options.alpha / options.lr),
                       total=500.)
    dprior = PriorLoss(prior_std=1., total=50000.)
    dnoise = NoiseLoss(params=netD.parameters(),
                       device=device,
                       scale=math.sqrt(2 * options.alpha * options.lr),
                       total=50000.)

    gprior.to(device=device)
    gnoise.to(device=device)
    dprior.to(device=device)
    dnoise.to(device=device)

    # In order to let G condition on a specific noise, we attach the noise to a fixed Tensor
    fixed_noise = torch.FloatTensor(options.batchSize, options.nz, 1,
                                    1).normal_(0, 1).to(device=device)
    inputT = torch.FloatTensor(options.batchSize, 3, options.imageSize,
                               options.imageSize).to(device=device)
    noiseT = torch.FloatTensor(options.batchSize, options.nz, 1,
                               1).to(device=device)
    labelT = torch.FloatTensor(options.batchSize).to(device=device)
    real_label = 1
    fake_label = 0

    # 13. Transfer all the tensors and modules to GPU if applicable
    netD.to(device=device)

    for netG in Generator_chains:
        netG.to(device=device)
    criterion.to(device=device)
    all_criterion.to(device=device)

    # ========================
    # === Training Process ===
    # ========================

    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    stats = []
    iters = 0

    try:
        print("\nStarting Training Loop...\n")
        for epoch in range(options.niter):
            top1 = Metrics()
            for i, data in enumerate(train_data_loader, 0):
                # ##################
                # Train with real
                # ##################
                netD.zero_grad()
                real_cpu = data[0].to(device)
                batch_size = real_cpu.size(0)

                inputT.resize_as_(real_cpu).copy_(real_cpu)
                labelT.resize_(batch_size).fill_(real_label)

                inputv = torch.autograd.Variable(inputT)
                labelv = torch.autograd.Variable(labelT)

                output = netD(inputv)
                errD_real = all_criterion(output)
                errD_real.backward()
                # P(real) = 1 - P(class 0); dim=1 made explicit (output is
                # (batch, num_class)), matching the old implicit behavior.
                D_x = 1 - torch.nn.functional.softmax(
                    output, dim=1).data[:, 0].mean().item()

                # ##################
                # Train with fake
                # ##################
                fake_images = []
                for i_z in range(nsetz):
                    noiseT.resize_(batch_size, nz, 1, 1).normal_(
                        0, 1)  # prior, sample from N(0, 1) distribution
                    noisev = torch.autograd.Variable(noiseT)
                    for m in range(M):
                        idx = i_z * M + m
                        netG = Generator_chains[idx]
                        _fake = netG(noisev)
                        fake_images.append(_fake)
                fake = torch.cat(fake_images)
                output = netD(fake.detach())

                labelv = torch.autograd.Variable(
                    torch.LongTensor(fake.data.shape[0]).to(
                        device=device).fill_(fake_label))
                errD_fake = criterion(output, labelv)
                errD_fake.backward()
                D_G_z1 = 1 - torch.nn.functional.softmax(
                    output, dim=1).data[:, 0].mean().item()

                # ##################
                # Semi-supervised learning
                # ##################
                # Grab a single labeled batch from the semi loader.
                for ii, (input_sup, target_sup) in enumerate(semi_data_loader):
                    input_sup, target_sup = input_sup.to(
                        device=device), target_sup.to(device=device)
                    break
                input_sup_v = input_sup.to(device=device)
                # Shift targets by one: class 0 is reserved for "fake".
                target_sup_v = (target_sup + 1).to(device=device)
                output_sup = netD(input_sup_v)
                err_sup = criterion(output_sup, target_sup_v)
                err_sup.backward()
                pred1 = accuracy(output_sup.data, target_sup + 1,
                                 topk=(1, ))[0]
                top1.update(value=pred1.item(), N=input_sup.size(0))

                errD_prior = dprior(netD.parameters())
                errD_prior.backward()
                errD_noise = dnoise(netD.parameters())
                errD_noise.backward()
                errD = errD_real + errD_fake + err_sup + errD_prior + errD_noise
                optimizerD.step()

                # ##################
                # Sample and construct generator(s)
                # ##################
                for netG in Generator_chains:
                    netG.zero_grad()
                # (an unused `labelv` construction was removed here;
                # all_criterion only takes the raw discriminator output)
                output = netD(fake)
                errG = all_criterion(output)

                for netG in Generator_chains:
                    errG = errG + gprior(netG.parameters())
                    errG = errG + gnoise(netG.parameters())
                errG.backward()
                D_G_z2 = 1 - torch.nn.functional.softmax(
                    output, dim=1).data[:, 0].mean().item()
                for optimizerG in optimizerG_chains:
                    optimizerG.step()

                # ##################
                # Evaluate testing accuracy
                # ##################
                # Pause and compute the test accuracy after every 10 times of the notefreq.
                # Fixed precedence bug: the original `iters % 10 * int(options.notefreq)`
                # parsed as `(iters % 10) * notefreq`, which evaluated every
                # 10th iteration regardless of notefreq.
                if iters % (10 * int(options.notefreq)) == 0:
                    # get test accuracy on train and test
                    netD.eval()
                    compute_test_accuracy(discriminator=netD,
                                          testing_data_loader=test_data_loader,
                                          device=device)
                    netD.train()

                # ##################
                # Note down
                # ##################
                # Report status for the current iteration
                training_status = f"[{epoch}/{options.niter}][{i}/{len(train_data_loader)}] Loss_D: {errD.item():.4f} " \
                                  f"Loss_G: " \
                                  f"{errG.item():.4f} D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}" \
                                  f" | Acc {top1.value:.1f} / {top1.mean:.1f}"
                print(training_status)

                # Save samples to disk
                if i % int(options.notefreq) == 0:
                    vutils.save_image(
                        real_cpu,
                        f"{options.outf}/real_samples_epoch_{epoch:{0}{3}}_{i}.png",
                        normalize=True)
                    for _iz in range(nsetz):
                        for _m in range(M):
                            gidx = _iz * M + _m
                            netG = Generator_chains[gidx]
                            fake = netG(fixed_noise)
                            vutils.save_image(
                                fake.detach(),
                                f"{options.outf}/fake_samples_epoch_{epoch:{0}{3}}_{i}_z{_iz}_m{_m}.png",
                                normalize=True)

                    # Save Losses statistics for post-mortem
                    G_losses.append(errG.item())
                    D_losses.append(errD.item())
                    stats.append(training_status)

                # Count every iteration. The original only incremented inside
                # the notefreq branch above, which made the eval gate fire on
                # almost every iteration once iters stalled at a multiple.
                iters += 1
            # TODO: find an elegant way to support saving checkpoints in Bayesian GAN context
    except Exception as e:
        # Best-effort: keep whatever statistics we collected before the crash.
        # NOTE(review): on a clean run nothing is saved — consider `finally`.
        print(e)

        # save training stats no matter what kind of errors occur in the processes
        _save_stats(statistic=G_losses, save_name='G_losses', options=options)
        _save_stats(statistic=D_losses, save_name='D_losses', options=options)
        _save_stats(statistic=stats,
                    save_name='Training_stats',
                    options=options)
Exemplo n.º 23
0
class Train():
    """Minimal GAN trainer that applies DiffAugment to both real and fake batches.

    Builds a Generator/Discriminator pair, trains for ``args.iter`` steps with
    Adam, periodically checkpoints the generator, and finally displays one
    generated sample.
    """

    def __init__(self, args):
        root = args.root
        im_size = args.im_size
        batch_size = args.batch_size

        self.iterations = args.iter
        self.latent_dim = args.latent

        # DiffAugment policy applied identically to real and generated images.
        self.policy = 'color,translation'

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.G = Generator(self.latent_dim).to(self.device)
        self.D = Discriminator().to(self.device)

        self.G.apply(weights_init)
        self.D.apply(weights_init)

        self.G_optim = optim.Adam(self.G.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.D_optim = optim.Adam(self.D.parameters(), lr=0.0002, betas=(0.5, 0.999))

        trans = transforms.Compose([
            transforms.Resize((im_size, im_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        dataset = CustomImageDataset(root=root, transform=trans)
        self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

        self.mse = nn.MSELoss()

    def _next_batch(self):
        """Return the next batch of real images, restarting the loader on exhaustion.

        BUGFIX: the original code called ``next(iter(self.dataloader))`` every
        training step, which rebuilds the DataLoader iterator (and any worker
        processes) each time -- a severe slowdown. A persistent iterator that
        is re-created only on StopIteration preserves random sampling while
        avoiding that cost.
        """
        try:
            return next(self._data_iter)
        except (AttributeError, StopIteration):
            self._data_iter = iter(self.dataloader)
            return next(self._data_iter)

    def train(self):
        for i in tqdm(range(self.iterations)):
            real_imgs = self._next_batch()
            real_imgs = real_imgs.to(self.device, non_blocking=True)
            real_imgs = DiffAugment(real_imgs, policy=self.policy)

            cur_batch_size = real_imgs.shape[0]
            noise = torch.Tensor(cur_batch_size, self.latent_dim, 1, 1).normal_(0, 1).to(self.device, non_blocking=True)
            gen_imgs = self.G(noise)

            fake_imgs = DiffAugment(gen_imgs, policy=self.policy)

            # --- Discriminator step: hinge-style losses on real and fake ---
            self.D.zero_grad()
            self.train_discriminator(real_imgs, label='real')
            self.train_discriminator(fake_imgs, label='fake')
            self.D_optim.step()

            # --- Generator step: maximize the discriminator's score on fakes ---
            self.G.zero_grad()
            pred = self.D(fake_imgs, label='fake')
            g_loss = -pred.mean()
            g_loss.backward()
            self.G_optim.step()

            # Periodic generator checkpoint.
            if i % 5000 == 0:
                model_path = 'model' + str(i) + '.pth'
                torch.save(self.G.state_dict(), model_path)

        # Show one final generated sample.
        # NOTE(review): if self.iterations == 0, `cur_batch_size` is unbound
        # here (pre-existing behavior) -- confirm iterations is always > 0.
        noise = torch.Tensor(cur_batch_size, self.latent_dim, 1, 1).normal_(0, 1).to(self.device, non_blocking=True)
        gen_imgs = self.G(noise)
        img = gen_imgs[0].cpu().detach().numpy().copy()
        plt.imshow(img.transpose(1, 2, 0))
        plt.show()

    def train_discriminator(self, image, label):
        """Backpropagate the discriminator loss for one real or fake batch.

        For real images the discriminator also returns auto-encoder style
        reconstructions; their MSE against downsampled inputs is added to the
        (noise-softened) hinge loss. Gradients accumulate on self.D; the
        caller performs the optimizer step.
        """
        if label == 'real':
            logits, cropimg512, decoded_img1, decoded_img2 = self.D(image, label)
            # Hinge loss with a randomized margin in [0.8, 1.0] plus
            # reconstruction terms (a self-supervision regularizer).
            loss = F.relu(torch.rand_like(logits)*0.2 + 0.8 - logits).mean() + \
                self.mse(decoded_img1, F.interpolate(image, decoded_img1.shape[2])).sum() + \
                self.mse(decoded_img2, F.interpolate(cropimg512, decoded_img2.shape[2])).sum()
            loss.backward(retain_graph=True)

        else:
            logits = self.D(image, label)
            loss = F.relu(torch.rand_like(logits)*0.2 + 0.8 + logits).mean()
            loss.backward(retain_graph=True)
Exemplo n.º 24
0
def main(options):
    """DCGAN training entry point: build G/D, run the adversarial loop, save stats.

    Arguments:
        options: an ``argparse.Namespace`` from the CLI parser; must provide
            loglevel, outf, ngpu, nz, ngf, ndf, nc, netG, netD, batchSize,
            lr, beta1, niter and notefreq.
    """
    # 1. Make sure the options are valid argparse CLI options indeed
    assert isinstance(options, argparse.Namespace)

    # 2. Set up the logger
    logging.basicConfig(level=str(options.loglevel).upper())

    # 3. Make sure the output dir `outf` exists
    _check_out_dir(options)

    # 4. Set the random state
    _set_random_state(options)

    # 5. Configure CUDA and Cudnn, set the global `device` for PyTorch
    device = _configure_cuda(options)

    # 6. Prepare the datasets
    data_loader = _prepare_dataset(options)

    # 7. Set the parameters
    ngpu = int(options.ngpu)  # num of GPUs
    nz = int(options.nz)  # size of latent vector
    ngf = int(options.ngf)  # depth of feature maps through G
    ndf = int(options.ndf)  # depth of feature maps through D
    nc = int(options.nc)  # num of channels of the input images, 3 indicates color images

    # 8. Initialize (or load checkpoints for) the Generator model
    netG = Generator(ngpu, nz, ngf, nc).to(device)
    netG.apply(weights_init)
    if options.netG != '':
        logging.info(
            f'Found checkpoint of Generator at {options.netG}, loading from the saved model.\n'
        )
        netG.load_state_dict(torch.load(options.netG))
    logging.info(f'Showing the Generator model: \n {netG}\n')

    # 9. Initialize (or load checkpoints for) the Discriminator model
    netD = Discriminator(ngpu, ndf, nc).to(device)
    netD.apply(weights_init)
    if options.netD != '':
        # BUGFIX: this message previously interpolated options.netG.
        logging.info(
            f'Found checkpoint of Discriminator at {options.netD}, loading from the saved model.\n'
        )
        netD.load_state_dict(torch.load(options.netD))
    logging.info(f'Showing the Discriminator model: \n {netD}\n')

    # ========================
    # === Training Process ===
    # ========================

    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    stats = []
    iters = 0

    # Set the loss function to Binary Cross Entropy between the target and the output
    # See https://pytorch.org/docs/stable/nn.html#torch.nn.BCELoss
    criterion = nn.BCELoss()

    fixed_noise = torch.randn(options.batchSize, nz, 1, 1, device=device)
    # BUGFIX: BCELoss requires float targets; integer labels made torch.full
    # produce a LongTensor, which raises at criterion() time.
    real_label = 1.
    fake_label = 0.

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=options.lr,
                            betas=(options.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=options.lr,
                            betas=(options.beta1, 0.999))

    print("\nStarting Training Loop...\n")
    for epoch in range(options.niter):
        for i, data in enumerate(data_loader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            netD.zero_grad()
            real_cpu = data[0].to(device)
            batch_size = real_cpu.size(0)
            label = torch.full((batch_size, ), real_label, device=device)

            output = netD(real_cpu)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            # train with fake
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(fake_label)
            output = netD(fake.detach())
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            output = netD(fake)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            training_status = f"[{epoch}/{options.niter}][{i}/{len(data_loader)}] Loss_D: {errD.item():.4f} Loss_G: " \
                f"{errG.item():.4f} D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}"
            print(training_status)

            # Save sample images and statistics every `notefreq` batches.
            if i % int(options.notefreq) == 0:
                vutils.save_image(
                    real_cpu,
                    f"{options.outf}/real_samples_epoch_{epoch:{0}{3}}_{i}.png",
                    normalize=True)
                fake = netG(fixed_noise)
                vutils.save_image(
                    fake.detach(),
                    f"{options.outf}/fake_samples_epoch_{epoch:{0}{3}}_{i}.png",
                    normalize=True)

                # Save Losses statistics for post-mortem
                G_losses.append(errG.item())
                D_losses.append(errD.item())
                stats.append(training_status)

                # Check how the generator is doing by saving G's output on fixed_noise
                # NOTE(review): `iters` only advances inside this branch, so it
                # counts logging events, not batches -- confirm that is intended.
                if (iters % 500 == 0) or ((epoch == options.niter - 1) and
                                          (i == len(data_loader) - 1)):
                    with torch.no_grad():
                        fake = netG(fixed_noise).detach().cpu()
                    img_list.append(
                        vutils.make_grid(fake, padding=2, normalize=True))
                iters += 1

        # do checkpointing
        torch.save(netG.state_dict(), f"{options.outf}/netG_epoch_{epoch}.pth")
        # BUGFIX: the Discriminator checkpoint previously saved netG's weights.
        torch.save(netD.state_dict(), f"{options.outf}/netD_epoch_{epoch}.pth")

    # save training stats
    _save_stats(statistic=G_losses, save_name='G_losses', options=options)
    _save_stats(statistic=D_losses, save_name='D_losses', options=options)
    _save_stats(statistic=stats, save_name='Training_stats', options=options)
Exemplo n.º 25
0
def main(args):
    """Train a 3D volume generator, optionally with adversarial and feature losses.

    Arguments:
        args: an argparse.Namespace carrying dataset paths, model options,
            loss weights and training hyperparameters.
    """
    # log hyperparameters
    print(args)

    # select device
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    # BUGFIX: "cuda: 0" (with a space) is not a valid torch device string.
    device = torch.device("cuda:0" if args.cuda else "cpu")

    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # data loaders
    transform = transforms.Compose([
        utils.Normalize(),
        utils.ToTensor()
    ])
    train_dataset = TVDataset(
        root=args.root,
        sub_size=args.block_size,
        volume_list=args.volume_train_list,
        max_k=args.training_step,
        train=True,
        transform=transform
    )
    test_dataset = TVDataset(
        root=args.root,
        sub_size=args.block_size,
        volume_list=args.volume_test_list,
        max_k=args.training_step,
        train=False,
        transform=transform
    )

    kwargs = {"num_workers": 4, "pin_memory": True} if args.cuda else {}
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True, **kwargs)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             shuffle=False, **kwargs)

    # model weight initializers (He init matched to each net's nonlinearity)
    def generator_weights_init(m):
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def discriminator_weights_init(m):
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    g_model = Generator(args.upsample_mode, args.forward, args.backward, args.gen_sn, args.residual)
    g_model.apply(generator_weights_init)
    if args.data_parallel and torch.cuda.device_count() > 1:
        g_model = nn.DataParallel(g_model)
    g_model.to(device)

    if args.gan_loss != "none":
        d_model = Discriminator(args.dis_sn)
        d_model.apply(discriminator_weights_init)
        if args.data_parallel and torch.cuda.device_count() > 1:
            d_model = nn.DataParallel(d_model)
        d_model.to(device)

    mse_loss = nn.MSELoss()
    adversarial_loss = nn.MSELoss()  # LSGAN-style adversarial objective
    train_losses, test_losses = [], []
    d_losses, g_losses = [], []

    # optimizers
    g_optimizer = optim.Adam(g_model.parameters(), lr=args.lr,
                             betas=(args.beta1, args.beta2))
    if args.gan_loss != "none":
        d_optimizer = optim.Adam(d_model.parameters(), lr=args.d_lr,
                                 betas=(args.beta1, args.beta2))

    Tensor = torch.cuda.FloatTensor if args.cuda else torch.FloatTensor

    # load checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint {}".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            g_model.load_state_dict(checkpoint["g_model_state_dict"])
            if args.gan_loss != "none":
                d_model.load_state_dict(checkpoint["d_model_state_dict"])
                d_losses = checkpoint["d_losses"]
                g_losses = checkpoint["g_losses"]
            train_losses = checkpoint["train_losses"]
            test_losses = checkpoint["test_losses"]
            print("=> load checkpoint {} (epoch {})"
                  .format(args.resume, checkpoint["epoch"]))

    # main loop
    for epoch in tqdm(range(args.start_epoch, args.epochs)):
        # training...
        g_model.train()
        if args.gan_loss != "none":
            d_model.train()
        train_loss = 0.
        volume_loss_part = np.zeros(args.training_step)
        for i, sample in enumerate(train_loader):
            # adversarial ground truths: one decision per intermediate volume
            # (torch.autograd.Variable is deprecated; a plain tensor is not
            # tracked by autograd unless requested)
            label_shape = (sample["v_i"].shape[0], sample["v_i"].shape[1], 1, 1, 1, 1)
            real_label = Tensor(*label_shape).fill_(1.0)
            fake_label = Tensor(*label_shape).fill_(0.0)

            v_f = sample["v_f"].to(device)
            v_b = sample["v_b"].to(device)
            v_i = sample["v_i"].to(device)
            g_optimizer.zero_grad()
            fake_volumes = g_model(v_f, v_b, args.training_step, args.wo_ori_volume, args.norm)

            # adversarial loss -- update discriminator
            if args.gan_loss != "none":
                avg_d_loss = 0.
                avg_d_loss_real = 0.
                avg_d_loss_fake = 0.
                for k in range(args.n_d):
                    d_optimizer.zero_grad()
                    decisions = d_model(v_i)
                    d_loss_real = adversarial_loss(decisions, real_label)
                    fake_decisions = d_model(fake_volumes.detach())

                    d_loss_fake = adversarial_loss(fake_decisions, fake_label)
                    d_loss = d_loss_real + d_loss_fake
                    d_loss.backward()
                    avg_d_loss += d_loss.item() / args.n_d
                    # BUGFIX: accumulate python floats, not tensors, so no
                    # autograd graph is kept alive across iterations
                    avg_d_loss_real += d_loss_real.item() / args.n_d
                    avg_d_loss_fake += d_loss_fake.item() / args.n_d

                    d_optimizer.step()

            # update generator
            if args.gan_loss != "none":
                avg_g_loss = 0.
            avg_loss = 0.
            for k in range(args.n_g):
                loss = 0.
                g_optimizer.zero_grad()

                # adversarial loss
                if args.gan_loss != "none":
                    fake_decisions = d_model(fake_volumes)
                    g_loss = args.gan_loss_weight * adversarial_loss(fake_decisions, real_label)
                    loss += g_loss
                    avg_g_loss += g_loss.item() / args.n_g

                # volume (reconstruction) loss
                if args.volume_loss:
                    volume_loss = args.volume_loss_weight * mse_loss(v_i, fake_volumes)
                    for j in range(v_i.shape[1]):
                        volume_loss_part[j] += mse_loss(v_i[:, j, :], fake_volumes[:, j, :]).item() / args.n_g / args.log_every
                    loss += volume_loss

                # feature (perceptual) loss from discriminator activations
                if args.feature_loss:
                    feat_real = d_model.extract_features(v_i)
                    feat_fake = d_model.extract_features(fake_volumes)
                    for m in range(len(feat_real)):
                        loss += args.feature_loss_weight / len(feat_real) * mse_loss(feat_real[m], feat_fake[m])

                # BUGFIX: use .item() so `avg_loss`/`train_loss` are floats and
                # the graph is not retained past the optimizer step
                avg_loss += loss.item() / args.n_g
                loss.backward()
                g_optimizer.step()

            train_loss += avg_loss

            # log training status
            subEpoch = (i + 1) // args.log_every
            if (i+1) % args.log_every == 0:
                print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch, (i+1) * args.batch_size, len(train_loader.dataset), 100. * (i+1) / len(train_loader),
                    avg_loss
                ))
                print("Volume Loss: ")
                for j in range(volume_loss_part.shape[0]):
                    print("\tintermediate {}: {:.6f}".format(
                        j+1, volume_loss_part[j]
                    ))

                if args.gan_loss != "none":
                    print("DLossReal: {:.6f} DLossFake: {:.6f} DLoss: {:.6f}, GLoss: {:.6f}".format(
                        avg_d_loss_real, avg_d_loss_fake, avg_d_loss, avg_g_loss
                    ))
                    d_losses.append(avg_d_loss)
                    g_losses.append(avg_g_loss)
                train_losses.append(train_loss / args.log_every)
                print("====> SubEpoch: {} Average loss: {:.6f} Time {}".format(
                    subEpoch, train_loss / args.log_every, time.asctime(time.localtime(time.time()))
                ))
                train_loss = 0.
                volume_loss_part = np.zeros(args.training_step)

            # testing...
            if (i + 1) % args.test_every == 0:
                g_model.eval()
                if args.gan_loss != "none":
                    d_model.eval()
                test_loss = 0.
                with torch.no_grad():
                    # BUGFIX: the inner loop previously reused `i`, clobbering
                    # the outer training-batch index for the rest of the epoch
                    for test_sample in test_loader:
                        v_f = test_sample["v_f"].to(device)
                        v_b = test_sample["v_b"].to(device)
                        v_i = test_sample["v_i"].to(device)
                        fake_volumes = g_model(v_f, v_b, args.training_step, args.wo_ori_volume, args.norm)
                        test_loss += args.volume_loss_weight * mse_loss(v_i, fake_volumes).item()

                test_losses.append(test_loss * args.batch_size / len(test_loader.dataset))
                print("====> SubEpoch: {} Test set loss {:4f} Time {}".format(
                    subEpoch, test_losses[-1], time.asctime(time.localtime(time.time()))
                ))
                # BUGFIX: restore training mode; previously the models stayed
                # in eval() until the next epoch began
                g_model.train()
                if args.gan_loss != "none":
                    d_model.train()

            # saving...
            if (i+1) % args.check_every == 0:
                print("=> saving checkpoint at epoch {}".format(epoch))
                # BUGFIX: the extension was previously "..._pth.tar"
                ckpt_path = os.path.join(args.save_dir, "model_{}_{}.pth.tar".format(epoch, subEpoch))
                state = {"epoch": epoch + 1,
                         "g_model_state_dict": g_model.state_dict(),
                         "g_optimizer_state_dict": g_optimizer.state_dict(),
                         "train_losses": train_losses,
                         "test_losses": test_losses}
                if args.gan_loss != "none":
                    state.update({"d_model_state_dict": d_model.state_dict(),
                                  "d_optimizer_state_dict": d_optimizer.state_dict(),
                                  "d_losses": d_losses,
                                  "g_losses": g_losses})
                torch.save(state, ckpt_path)
                torch.save(g_model.state_dict(),
                           os.path.join(args.save_dir, "model_" + str(epoch) + "_" + str(subEpoch) + ".pth"))

        num_subEpoch = len(train_loader) // args.log_every
        print("====> Epoch: {} Average loss: {:.6f} Time {}".format(
            epoch, np.array(train_losses[-num_subEpoch:]).mean(), time.asctime(time.localtime(time.time()))
        ))
class AdvGAN_Pretrain:
    """AdvGAN trainer: learns a generator that produces adversarial perturbations.

    The generator G outputs a perturbation for each input image; the perturbed
    ("adversarial") image is clipped, scored by a discriminator, and pushed to
    fool the target classifier via a C&W-style loss.
    """

    def __init__(self,
                 device,
                 model,
                 model_num_labels,
                 box_min,
                 box_max):
        self.device = device
        self.model_num_labels = model_num_labels
        self.model = model  # target classifier to attack (frozen; never stepped)
        self.box_min = box_min  # valid pixel range for adversarial images
        self.box_max = box_max

        self.netG = Generator().to(device)
        self.netDisc = Discriminator().to(device)

        # initialize all weights
        self.netG.apply(weights_init)
        self.netDisc.apply(weights_init)

        # initialize optimizers
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=1e-3)
        self.optimizer_D = torch.optim.Adam(self.netDisc.parameters(),
                                            lr=1e-3)

        if not os.path.exists(models_path):
            os.makedirs(models_path)

    def train_batch(self, x, labels):
        """Run one D step and one G step on a batch; return the four loss scalars.

        Returns:
            (loss_D_GAN, loss_G_fake, loss_perturb, loss_adv) as python floats.
        """
        # optimize D
        for i in range(1):
            perturbation = self.netG(x)

            # clipping trick: bound the perturbation, then the image range
            adv_images = torch.clamp(perturbation, -0.3, 0.3) + x
            adv_images = torch.clamp(adv_images, self.box_min, self.box_max)

            self.optimizer_D.zero_grad()
            pred_real = self.netDisc(x)
            loss_D_real = F.mse_loss(pred_real, torch.ones_like(pred_real, device=self.device))
            loss_D_real.backward()

            pred_fake = self.netDisc(adv_images.detach())
            loss_D_fake = F.mse_loss(pred_fake, torch.zeros_like(pred_fake, device=self.device))
            loss_D_fake.backward()
            loss_D_GAN = loss_D_fake + loss_D_real
            self.optimizer_D.step()

        # optimize G
        for i in range(1):
            self.optimizer_G.zero_grad()

            # G's GAN loss: make D score adversarial images as real
            pred_fake = self.netDisc(adv_images)
            loss_G_fake = F.mse_loss(pred_fake, torch.ones_like(pred_fake, device=self.device))
            loss_G_fake.backward(retain_graph=True)

            # L2 norm of the perturbation (keeps it small)
            loss_perturb = torch.mean(torch.norm(perturbation.view(perturbation.shape[0], -1), 2, dim=1))

            # adversarial (attack) loss against the target classifier
            logits_model = self.model(adv_images)
            probs_model = F.softmax(logits_model, dim=1)
            onehot_labels = torch.eye(self.model_num_labels, device=self.device)[labels]

            # C&W loss: push the true-class probability below the best other class
            real = torch.sum(onehot_labels * probs_model, dim=1)
            other, _ = torch.max((1 - onehot_labels) * probs_model - onehot_labels * 10000, dim=1)
            zeros = torch.zeros_like(other)
            loss_adv_arr = torch.max(real - other, zeros)
            # BUGFIX: previously summed the not-yet-defined `loss_adv`
            # (a NameError); also removed the debug prints.
            loss_adv = torch.sum(loss_adv_arr)

            adv_lambda = 10
            pert_lambda = 1
            loss_G = adv_lambda * loss_adv + pert_lambda * loss_perturb
            loss_G.backward()
            self.optimizer_G.step()

        return loss_D_GAN.item(), loss_G_fake.item(), loss_perturb.item(), loss_adv.item()

    def train(self, train_dataloader, epochs):
        """Train for `epochs` epochs, logging to TensorBoard and checkpointing G/D."""
        writer = SummaryWriter(log_dir="visualization/orig_advgan/", comment='Original AdvGAN stats')

        for epoch in range(1, epochs+1):

            # staircase learning-rate schedule: 1e-3 -> 1e-4 -> 1e-5
            if epoch == 50:
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                    lr=1e-4)
                self.optimizer_D = torch.optim.Adam(self.netDisc.parameters(),
                                                    lr=1e-4)

            if epoch == 80:
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                    lr=1e-5)
                self.optimizer_D = torch.optim.Adam(self.netDisc.parameters(),
                                                    lr=1e-5)

            loss_D_sum = 0
            loss_G_fake_sum = 0
            loss_perturb_sum = 0
            loss_adv_sum = 0
            for i, data in enumerate(train_dataloader, start=0):
                images, labels = data
                images, labels = images.to(self.device), labels.to(self.device)

                loss_D_batch, loss_G_fake_batch, loss_perturb_batch, loss_adv_batch = \
                    self.train_batch(images, labels)
                loss_D_sum += loss_D_batch
                loss_G_fake_sum += loss_G_fake_batch
                loss_perturb_sum += loss_perturb_batch
                loss_adv_sum += loss_adv_batch

            # print statistics
            num_batch = len(train_dataloader)
            writer.add_scalar('discriminator_loss', loss_D_sum/num_batch, epoch)
            writer.add_scalar('generator_loss', loss_G_fake_sum/num_batch, epoch)
            writer.add_scalar('perturbation_loss', loss_perturb_sum/num_batch, epoch)
            writer.add_scalar('adversarial_loss', loss_adv_sum/num_batch, epoch)
            print("epoch %d:\nloss_D: %.5f, loss_G_fake: %.5f,\
             \nloss_perturb: %.5f, loss_adv: %.5f\n" %
                  (epoch, loss_D_sum/num_batch, loss_G_fake_sum/num_batch,
                   loss_perturb_sum/num_batch, loss_adv_sum/num_batch))

            # save generator
            if epoch%20==0:
                netG_file_name = models_path + 'netG_original_epoch_' + str(epoch) + '.pth'
                torch.save(self.netG.state_dict(), netG_file_name)
                netDisc_file_name = models_path + 'netDisc_original_epoch_' + str(epoch) + '.pth'
                torch.save(self.netDisc.state_dict(), netDisc_file_name)

        writer.close()
Exemplo n.º 27
0
    def __init__(self, generator: Generator, discriminator: Discriminator):
        """Store a generator/discriminator pair, applying weight init to each."""
        super().__init__()

        # nn.Module.apply returns the module itself, so initializing and
        # storing can be done in one pass over both networks.
        for network in (generator, discriminator):
            network.apply(weights_init)
        self.generator = generator
        self.discriminator = discriminator
Exemplo n.º 28
0
# Training configuration for a CycleGAN-style setup: two generators
# (A->B and B->A), two discriminators, and adversarial + cycle + identity
# losses on the monet2photo dataset.
# NOTE(review): `h`, `w`, `c`, `cuda`, `weight_init`, `Generator` and
# `Discriminator` are assumed to be defined earlier in the file -- confirm.
epochs=200
save_imgs=100  # presumably the interval for saving sample images -- TODO confirm against the training loop
lambda_cyc=10.  # weight of the cycle-consistency (A->B->A) loss
lambda_id=5.  # weight of the identity-mapping loss
dataset="monet2photo"
out_dir="results"
generator_AB=Generator(l=2,n_filters=8)  # translates domain A -> B
generator_BA=Generator(l=2,n_filters=8)  # translates domain B -> A
discriminator_A=Discriminator(h,w,c)
discriminator_B=Discriminator(h,w,c)
gan_loss=nn.MSELoss()  # least-squares adversarial loss
cycle_loss=nn.L1Loss()
ident_loss=nn.L1Loss()
generator_AB.apply(weight_init)
generator_BA.apply(weight_init)
discriminator_A.apply(weight_init)
discriminator_B.apply(weight_init)
if cuda:
	generator_AB=generator_AB.cuda()
	generator_BA=generator_BA.cuda()
	discriminator_A=discriminator_A.cuda()
	discriminator_B=discriminator_B.cuda()
	gan_loss.cuda()
	cycle_loss.cuda()
	ident_loss.cuda()

os.makedirs(out_dir,exist_ok=True)
# Discriminator decision shape after 4 stride-2 downsamplings (2**4 = 16);
# presumably a PatchGAN-style output -- confirm against the Discriminator.
patch=(1,h//2**4,w//2**4)
transforms_=[
	transforms.Resize((h,w),Image.BICUBIC),
	transforms.ToTensor(),
Exemplo n.º 29
0
def main():
    # データセットの準備
    make_data()
    train_img_list = make_datapath_list()
    mean, std = (0.5, ), (0.5, )
    train_dataset = GAN_Img_Dataset(train_img_list, ImageTransform(mean, std))
    batch_size = 64
    train_dataloader = data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True)

    # モデルの定義と重み初期化
    G = Generator(z_dim=20, image_size=64)
    D = Discriminator(image_size=64)

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0, 0.02)
            nn.init.constant_(m.bias.data, 0)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    G.apply(weights_init)
    D.apply(weights_init)

    # decide device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("device:\t", device)
    G.to(device)
    D.to(device)

    # define optimizer
    g_lr, d_lr = 0.0001, 0.0004
    g_optimizer = torch.optim.Adam(G.parameters(), g_lr, [0, 0.9])
    d_optimizer = torch.optim.Adam(D.parameters(), d_lr, [0, 0.9])

    # パラメタ
    z_dim = 20  # 乱数の次元

    G.train()
    D.train()
    torch.backends.cudnn.benchmark = True

    num_train_imgs = len(train_dataloader.dataset)
    iteration = 1
    logs = []

    # 学習 (300 Epochs)
    for epoch in range(300):
        t_epoch_start = time.time()
        epoch_g_loss = 0.0
        epoch_d_loss = 0.0

        for batch in train_dataloader:
            # バッチサイズ確認
            if batch.size()[0] == 1:
                continue

            # ラベルの準備
            batch = batch.to(device)
            batch_num = batch.size()[0]
            label_real = torch.full((batch_num, ), 1).to(device)
            label_fake = torch.full((batch_num, ), 0).to(device)

            # --- Discriminatorの学習 --- #
            # 真の画像を判定
            d_out_real, _, _ = D(batch)

            # 偽の画像を生成・判定
            input_z = torch.randn(batch_num, z_dim).to(device)
            input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)
            fake_images, _, _ = G(input_z)
            d_out_fake, _, _ = D(fake_images)

            # 損失を計算・パラメータ更新
            d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()
            d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()
            d_loss = d_loss_real + d_loss_fake

            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # --- Generatorの学習 --- #
            input_z = torch.randn(batch_num, z_dim).to(device)
            input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)
            fake_images, _, _ = G(input_z)
            d_out_fake, _, _ = D(fake_images)

            # 損失を計算・パラメータ更新
            g_loss = -d_out_fake.mean()
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            epoch_d_loss += d_loss.item()
            epoch_g_loss += g_loss.item()
            iteration += 1

        t_epoch_finish = time.time()
        print(
            'epoch {:3d}/300 || D_Loss: {:.4f} || G_Loss: {:.4f} || time: {:.4f} sec.'
            .format(epoch, epoch_d_loss / batch_size,
                    epoch_g_loss / batch_size, t_epoch_finish - t_epoch_start))

    # --- 画像生成・可視化する --- #
    test_size = 5  # 可視化する個数
    input_z = torch.randn(test_size, z_dim)
    input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)
    G.eval()
    fake_images, at_map1, at_map2 = G(input_z.to(device))

    fig = plt.figure(figsize=(15, 6))
    for i in range(0, 5):
        # top: fake image
        plt.subplot(2, 5, i + 1)
        plt.imshow(fake_images[i][0].cpu().detach().numpy(), 'gray')

        # middle: atmap1
        plt.subplot(2, 5, i + 6)
        am = at_map1[i].view(16, 16, 16, 16)
        am = am[7][7]
        plt.imshow(am.cpu().detach().numpy(), 'Reds')

    plt.savefig('visualization.png')