Example #1
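# Sets up a U-Net generator and a discriminator as a GAN for mask prediction,
# then alternates discriminator and generator training each round, periodically
# saving comparison images and test-set predictions.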
all_datasets_path = '/media/tdanka/B8703F33703EF828/tdanka/data'
train_dataset_loc = os.path.join(all_datasets_path, 'stage1_train_merged/loc.csv')
test_dataset_loc = os.path.join(all_datasets_path, 'stage1_test/loc.csv')
results_root_path = '/media/tdanka/B8703F33703EF828/tdanka/results'

tf = make_transform(size=(128, 128), p_flip=0.5, color_jitter_params=(0, 0, 0, 0))
train_dataset = JointlyTransformedDataset(train_dataset_loc, transform=tf, remove_alpha=True)
test_dataset = TestFromFolder(test_dataset_loc, transform=T.ToTensor(), remove_alpha=True)
train_original_dataset = TestFromFolder(train_dataset_loc, transform=T.ToTensor(), remove_alpha=True)

model_name = 'GAN_overnight_2018_03_06'
g = UNet(3, 1)
g_optimizer = optim.Adam(g.parameters(), lr=1e-4)
d = Discriminator(4, 1)
d_loss = nn.BCELoss()
d_optimizer = optim.Adam(d.parameters(), lr=1e-4)

gan = GAN(
    g=g, g_optim=g_optimizer, d=d, d_loss=d_loss, d_optim=d_optimizer,
    model_name=model_name, results_root_path=results_root_path
)

n_rounds = 1000
for round_idx in range(n_rounds):
    print('***** Round no. %d *****' % round_idx)
    gan.train_discriminator(train_dataset, n_epochs=10, n_batch=16, verbose=False)
    gan.train_generator(train_dataset, n_epochs=10, n_batch=16, verbose=False)

    if round_idx % 10 == 0:
        gan.visualize(train_dataset, folder_name='compare_round_%d' % round_idx)
        gan.predict(test_dataset, folder_name='test_round_%d' % round_idx)
Example #2
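# GAN setup: a generator and a discriminator are moved to the GPU and
# weight-initialized, using BCE loss, SGD for the discriminator and Adam for
# the generator; the loader yields RGB patches paired with binary ground truth.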
device = torch.device("cuda:0")


netG = Generator(1).to(device)
netG.apply(weights_init)


netD = Discriminator(1).to(device)
netD.apply(weights_init)

criterion = nn.BCELoss()

# Setup Adam optimizers for both G and D
#optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))# SGD
optimizerD = optim.SGD(netD.parameters(), lr=0.01)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

print("Data setup")

rgb_wi_gt_data = loader(
    ["../../data/small_potsdam/rgb", "../../data/small_potsdam/y"],
    img_size, batch_size,
    transformations=[lambda x: x - load.get_mean("../../data/vaihingen/rgb"),
                     rgb_to_binary])
data_wi_gt = rgb_wi_gt_data.generate_patch()



ones = torch.FloatTensor(batch_size).fill_(1).cuda()
zero = torch.FloatTensor(batch_size).fill_(0).cuda()

print("Running...")
counter = 1
for epoch in range(1000):
Example #3
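# A training wrapper class: it loads pretrained generator/discriminator weights
# when available, trains both networks with BCE loss on an ImageFolder dataset,
# keeps the checkpoints with the lowest average loss, and periodically writes
# generated sample images.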
class ModuleTrain:
    def __init__(self, opt, best_loss=0.2):
        self.opt = opt
        self.best_loss = best_loss  # the model is only saved when the average loss beats this value

        self.netd = Discriminator(self.opt)
        self.netg = Generator(self.opt)
        self.use_gpu = False

        # Load pretrained models if they exist
        if os.path.exists(self.opt.netd_path):
            self.load_netd(self.opt.netd_path)
        else:
            print('[Load model] error: %s does not exist !!!' % self.opt.netd_path)
        if os.path.exists(self.opt.netg_path):
            self.load_netg(self.opt.netg_path)
        else:
            print('[Load model] error: %s does not exist !!!' % self.opt.netg_path)

        # DataLoader initialization
        self.transform_train = T.Compose([
            T.Resize((self.opt.img_size, self.opt.img_size)),
            T.ToTensor(),
            T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
        ])
        train_dataset = ImageFolder(root=self.opt.data_path,
                                    transform=self.transform_train)
        self.train_loader = DataLoader(dataset=train_dataset,
                                       batch_size=self.opt.batch_size,
                                       shuffle=True,
                                       num_workers=self.opt.num_workers,
                                       drop_last=True)

        # Optimizers and loss function
        # self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.5)
        self.optimizer_g = optim.Adam(self.netg.parameters(),
                                      lr=self.opt.lr1,
                                      betas=(self.opt.beta1, 0.999))
        self.optimizer_d = optim.Adam(self.netd.parameters(),
                                      lr=self.opt.lr2,
                                      betas=(self.opt.beta1, 0.999))
        self.criterion = torch.nn.BCELoss()

        self.true_labels = Variable(torch.ones(self.opt.batch_size))
        self.fake_labels = Variable(torch.zeros(self.opt.batch_size))
        self.fix_noises = Variable(
            torch.randn(self.opt.batch_size, self.opt.nz, 1, 1))
        self.noises = Variable(
            torch.randn(self.opt.batch_size, self.opt.nz, 1, 1))

        # gpu or cpu
        if self.opt.use_gpu and torch.cuda.is_available():
            self.use_gpu = True
        else:
            self.use_gpu = False
        if self.use_gpu:
            print('[use gpu] ...')
            self.netd.cuda()
            self.netg.cuda()
            self.criterion.cuda()
            self.true_labels = self.true_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            self.fix_noises = self.fix_noises.cuda()
            self.noises = self.noises.cuda()
        else:
            print('[use cpu] ...')

        pass

    def train(self, save_best=True):
        print('[train] epoch: %d' % self.opt.max_epoch)
        for epoch_i in range(self.opt.max_epoch):
            loss_netd = 0.0
            loss_netg = 0.0
            correct = 0

            print('================================================')
            for ii, (img, target) in enumerate(self.train_loader):  # training loop
                real_img = Variable(img)
                if self.use_gpu:
                    real_img = real_img.cuda()

                # Train the discriminator
                if (ii + 1) % self.opt.d_every == 0:
                    self.optimizer_d.zero_grad()
                    # push the discriminator to label real images as 1
                    output = self.netd(real_img)
                    error_d_real = self.criterion(output, self.true_labels)
                    error_d_real.backward()

                    # push the discriminator to label fake images as 0
                    self.noises.data.copy_(
                        torch.randn(self.opt.batch_size, self.opt.nz, 1, 1))
                    fake_img = self.netg(self.noises).detach()  # generate fake images from noise
                    fake_output = self.netd(fake_img)
                    error_d_fake = self.criterion(fake_output,
                                                  self.fake_labels)
                    error_d_fake.backward()
                    self.optimizer_d.step()

                    loss_netd += (error_d_real.item() + error_d_fake.item())

                # Train the generator
                if (ii + 1) % self.opt.g_every == 0:
                    self.optimizer_g.zero_grad()
                    self.noises.data.copy_(
                        torch.randn(self.opt.batch_size, self.opt.nz, 1, 1))
                    fake_img = self.netg(self.noises)
                    fake_output = self.netd(fake_img)
                    # push the discriminator to label the fake images as 1
                    error_g = self.criterion(fake_output, self.true_labels)
                    error_g.backward()
                    self.optimizer_g.step()

                    loss_netg += error_g.item()

            loss_netd /= (len(self.train_loader) * 2)
            loss_netg /= len(self.train_loader)
            print('[Train] Epoch: {} \tNetD Loss: {:.6f} \tNetG Loss: {:.6f}'.
                  format(epoch_i, loss_netd, loss_netg))
            if save_best is True:
                if (loss_netg + loss_netd) / 2 < self.best_loss:
                    self.best_loss = (loss_netg + loss_netd) / 2
                    self.save(self.netd, self.opt.best_netd_path)  # save the best discriminator
                    self.save(self.netg, self.opt.best_netg_path)  # save the best generator
                    print('[save best] ...')

            # self.vis()

            if (epoch_i + 1) % 5 == 0:
                self.image_gan()

        self.save(self.netd, self.opt.netd_path)  # save the final discriminator
        self.save(self.netg, self.opt.netg_path)  # save the final generator

    def vis(self):
        fix_fake_imgs = self.netg(self.fix_noises)
        visdom.images(fix_fake_imgs.data.cpu().numpy()[:64] * 0.5 + 0.5,
                      win='fixfake')

    def image_gan(self):
        noises = torch.randn(self.opt.gen_search_num, self.opt.nz, 1,
                             1).normal_(self.opt.gen_mean, self.opt.gen_std)
        with torch.no_grad():
            noises = Variable(noises)

        if self.use_gpu:
            noises = noises.cuda()

        fake_img = self.netg(noises)
        scores = self.netd(fake_img).data
        indexs = scores.topk(self.opt.gen_num)[1]
        result = list()
        for ii in indexs:
            result.append(fake_img.data[ii])

        torchvision.utils.save_image(torch.stack(result),
                                     self.opt.gen_img,
                                     normalize=True,
                                     range=(-1, 1))

        #     # print(correct)
        #     # print(len(self.train_loader.dataset))
        #     train_loss /= len(self.train_loader)
        #     acc = float(correct) / float(len(self.train_loader.dataset))
        #     print('[Train] Epoch: {} \tLoss: {:.6f}\tAcc: {:.6f}\tlr: {}'.format(epoch_i, train_loss, acc, self.lr))
        #
        #     test_acc = self.test()
        #     if save_best is True:
        #         if test_acc > self.best_acc:
        #             self.best_acc = test_acc
        #             str_list = self.model_file.split('.')
        #             best_model_file = ""
        #             for str_index in range(len(str_list)):
        #                 best_model_file = best_model_file + str_list[str_index]
        #                 if str_index == (len(str_list) - 2):
        #                     best_model_file += '_best'
        #                 if str_index != (len(str_list) - 1):
        #                     best_model_file += '.'
        #             self.save(best_model_file)                                  # 保存最好的模型
        #
        # self.save(self.model_file)

    def test(self):
        test_loss = 0.0
        correct = 0

        time_start = time.time()
        # iterate over the test set
        for data, target in self.test_loader:
            data, target = Variable(data), Variable(target)

            if self.use_gpu:
                data = data.cuda()
                target = target.cuda()

            output = self.model(data)
            # sum up batch loss
            loss = self.loss(output, target)
            test_loss += loss.item()

            predict = torch.argmax(output, 1)
            correct += (predict == target).sum().data

        time_end = time.time()
        time_avg = float(time_end - time_start) / float(
            len(self.test_loader.dataset))
        test_loss /= len(self.test_loader)
        acc = float(correct) / float(len(self.test_loader.dataset))

        print('[Test] set: Test loss: {:.6f}\t Acc: {:.6f}\t time: {:.6f} \n'.
              format(test_loss, acc, time_avg))
        return acc

    def load_netd(self, name):
        print('[Load model netd] %s ...' % name)
        self.netd.load_state_dict(torch.load(name))

    def load_netg(self, name):
        print('[Load model netg] %s ...' % name)
        self.netg.load_state_dict(torch.load(name))

    def save(self, model, name):
        print('[Save model] %s ...' % name)
        torch.save(model.state_dict(), name)
Example #4
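# SRGAN-style adversarial training for super-resolution: the generator upscales
# low-resolution images, a truncated VGG19 provides the perceptual (content)
# loss, the discriminator uses BCE-with-logits, and a checkpoint is written
# every epoch.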
def adversarial_train(img_data_loader, vgg_cutoff_layer=36, num_epochs=2000, decay_factor=0.1, initial_lr=0.0001, adversarial_loss_weight=0.001, checkpoint=None, save=True):
    if checkpoint is not None:
        imported_checkpoint = torch.load(checkpoint)
        generator = imported_checkpoint['generator']
        starting_epoch = 0
        discriminator = Discriminator()
        generator_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, generator.parameters()), lr=initial_lr)
        discriminator_optimizer = optim.Adam(params=filter(lambda p: p.requires_grad, discriminator.parameters()), lr=initial_lr)
    else:
        generator = Generator()
        starting_epoch = 0
        discriminator = Discriminator()
        generator_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, generator.parameters()), lr=initial_lr)
        discriminator_optimizer = optim.Adam(params=filter(lambda p: p.requires_grad, discriminator.parameters()), lr=initial_lr)
    
    vgg = ChoppedVGG19(vgg_cutoff_layer)
    
    # generator_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, generator.parameters()), lr=initial_lr)
    # discriminator_optimizer = optim.Adam(params=filter(lambda p: p.requires_grad, discriminator.parameters()), lr=initial_lr)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Push everything to gpu if it's available
    content_criterion = nn.MSELoss().to(device)
    adversarial_criterion = nn.BCEWithLogitsLoss().to(device)
    generator.to(device)
    discriminator.to(device)
    vgg.to(device)

    for epoch in range(starting_epoch, num_epochs):
        running_perceptual_loss = 0.0
        running_adversarial_loss = 0.0
        for ii, (hr_imgs, lr_imgs) in enumerate(tqdm(img_data_loader)):
            hr_imgs, lr_imgs = hr_imgs.to(device), lr_imgs.to(device)

            # Forwardpropagate through generator
            sr_imgs = generator(lr_imgs)
            sr_vgg_feature_maps = vgg(sr_imgs)
            hr_vgg_feature_maps = vgg(hr_imgs).detach()

            # Try and discriminate fakes
            sr_discriminator_logprob = discriminator(sr_imgs)
            
            # Calculate loss for generator
            content_loss = content_criterion(sr_vgg_feature_maps, hr_vgg_feature_maps)
            adversarial_loss = adversarial_criterion(sr_discriminator_logprob, torch.ones_like(sr_discriminator_logprob))
            perceptual_loss = content_loss + adversarial_loss_weight*adversarial_loss
            running_perceptual_loss += perceptual_loss.item()
            del sr_vgg_feature_maps, hr_vgg_feature_maps, sr_discriminator_logprob

            # Backpropagate and update generator
            generator_optimizer.zero_grad()
            perceptual_loss.backward()
            generator_optimizer.step()

            # Now for the discriminator
            sr_discriminator_logprob = discriminator(sr_imgs.detach())
            hr_discriminator_logprob = discriminator(hr_imgs)
            adversarial_loss = adversarial_criterion(sr_discriminator_logprob, torch.zeros_like(sr_discriminator_logprob)) + adversarial_criterion(hr_discriminator_logprob, torch.ones_like(hr_discriminator_logprob))
            running_adversarial_loss += adversarial_loss.item()

            # Backpropagate and update discriminator
            discriminator_optimizer.zero_grad()
            adversarial_loss.backward()
            discriminator_optimizer.step()
            del lr_imgs, hr_imgs, sr_imgs, sr_discriminator_logprob, hr_discriminator_logprob
        print("Epoch number {}".format(epoch))
        print("Average Perceptual Loss: {}".format(running_perceptual_loss/len(img_data_loader)))
        print("Average Adversarial Loss: {}".format(running_adversarial_loss/len(img_data_loader)))

        if save:
            # Save a checkpoint so training can be resumed later
            torch.save({'epoch': epoch,
                        'generator': generator,
                        'generator_optimizer': generator_optimizer,
                        'discriminator': discriminator,
                        'discriminator_optimizer':discriminator_optimizer},
                        'adversarial_training_checkpoint_CelebA_HQ.pth.tar')
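Example #5
# A GAN trainer for audio signals: the generator loss combines adversarial,
# time-domain L2, spectrogram L2 and optional autoencoder-embedding L2 terms,
# with StepLR schedulers, checkpointing, and SNR/LSD evaluation.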
class GanTrainer(Trainer):
    def __init__(self, train_loader, test_loader, valid_loader, general_args,
                 trainer_args):
        super(GanTrainer, self).__init__(train_loader, test_loader,
                                         valid_loader, general_args)
        # Paths
        self.loadpath = trainer_args.loadpath
        self.savepath = trainer_args.savepath

        # Load the auto-encoder
        self.use_autoencoder = False
        if trainer_args.autoencoder_path and os.path.exists(
                trainer_args.autoencoder_path):
            self.use_autoencoder = True
            self.autoencoder = AutoEncoder(general_args=general_args).to(
                self.device)
            self.load_pretrained_autoencoder(trainer_args.autoencoder_path)
            self.autoencoder.eval()

        # Load the generator
        self.generator = Generator(general_args=general_args).to(self.device)
        if trainer_args.generator_path and os.path.exists(
                trainer_args.generator_path):
            self.load_pretrained_generator(trainer_args.generator_path)

        self.discriminator = Discriminator(general_args=general_args).to(
            self.device)

        # Optimizers and schedulers
        self.generator_optimizer = torch.optim.Adam(
            params=self.generator.parameters(), lr=trainer_args.generator_lr)
        self.discriminator_optimizer = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=trainer_args.discriminator_lr)
        self.generator_scheduler = lr_scheduler.StepLR(
            optimizer=self.generator_optimizer,
            step_size=trainer_args.generator_scheduler_step,
            gamma=trainer_args.generator_scheduler_gamma)
        self.discriminator_scheduler = lr_scheduler.StepLR(
            optimizer=self.discriminator_optimizer,
            step_size=trainer_args.discriminator_scheduler_step,
            gamma=trainer_args.discriminator_scheduler_gamma)

        # Load saved states
        if os.path.exists(self.loadpath):
            self.load()

        # Loss function and stored losses
        self.adversarial_criterion = nn.BCEWithLogitsLoss()
        self.generator_time_criterion = nn.MSELoss()
        self.generator_frequency_criterion = nn.MSELoss()
        self.generator_autoencoder_criterion = nn.MSELoss()

        # Define labels
        self.real_label = 1
        self.generated_label = 0

        # Loss scaling factors
        self.lambda_adv = trainer_args.lambda_adversarial
        self.lambda_freq = trainer_args.lambda_freq
        self.lambda_autoencoder = trainer_args.lambda_autoencoder

        # Spectrogram converter
        self.spectrogram = Spectrogram(normalized=True).to(self.device)

        # Boolean indicating if the model needs to be saved
        self.need_saving = True

        # Boolean if the generator receives the feedback from the discriminator
        self.use_adversarial = trainer_args.use_adversarial

    def load_pretrained_generator(self, generator_path):
        """
        Loads a pre-trained generator. Can be used to stabilize the training.
        :param generator_path: location of the pre-trained generator (string).
        :return: None
        """
        checkpoint = torch.load(generator_path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])

    def load_pretrained_autoencoder(self, autoencoder_path):
        """
        Loads a pre-trained auto-encoder. Can be used to infer
        :param autoencoder_path: location of the pre-trained auto-encoder (string).
        :return: None
        """
        checkpoint = torch.load(autoencoder_path, map_location=self.device)
        self.autoencoder.load_state_dict(checkpoint['autoencoder_state_dict'])

    def train(self, epochs):
        """
        Trains the GAN for a given number of pseudo-epochs.
        :param epochs: Number of time to iterate over a part of the dataset (int).
        :return: None
        """
        for epoch in range(epochs):
            for i in range(self.train_batches_per_epoch):
                self.generator.train()
                self.discriminator.train()
                # Transfer to GPU
                local_batch = next(self.train_loader_iter)
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)
                batch_size = input_batch.shape[0]

                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                # Train the discriminator with real data
                self.discriminator_optimizer.zero_grad()
                label = torch.full((batch_size, ),
                                   self.real_label,
                                   device=self.device)
                output = self.discriminator(target_batch)

                # Compute and store the discriminator loss on real data
                loss_discriminator_real = self.adversarial_criterion(
                    output, torch.unsqueeze(label, dim=1))
                self.train_losses['discriminator_adversarial']['real'].append(
                    loss_discriminator_real.item())
                loss_discriminator_real.backward()

                # Train the discriminator with fake data
                generated_batch = self.generator(input_batch)
                label.fill_(self.generated_label)
                output = self.discriminator(generated_batch.detach())

                # Compute and store the discriminator loss on fake data
                loss_discriminator_generated = self.adversarial_criterion(
                    output, torch.unsqueeze(label, dim=1))
                self.train_losses['discriminator_adversarial']['fake'].append(
                    loss_discriminator_generated.item())
                loss_discriminator_generated.backward()

                # Update the discriminator weights
                self.discriminator_optimizer.step()

                ############################
                # Update G network: maximize log(D(G(z)))
                ###########################
                self.generator_optimizer.zero_grad()

                # Get the spectrogram
                specgram_target_batch = self.spectrogram(target_batch)
                specgram_fake_batch = self.spectrogram(generated_batch)

                # Fake labels are real for the generator cost
                label.fill_(self.real_label)
                output = self.discriminator(generated_batch)

                # Compute the generator loss on fake data
                # Get the adversarial loss
                loss_generator_adversarial = torch.zeros(size=[1],
                                                         device=self.device)
                if self.use_adversarial:
                    loss_generator_adversarial = self.adversarial_criterion(
                        output, torch.unsqueeze(label, dim=1))
                self.train_losses['generator_adversarial'].append(
                    loss_generator_adversarial.item())

                # Get the L2 loss in time domain
                loss_generator_time = self.generator_time_criterion(
                    generated_batch, target_batch)
                self.train_losses['time_l2'].append(loss_generator_time.item())

                # Get the L2 loss in frequency domain
                loss_generator_frequency = self.generator_frequency_criterion(
                    specgram_fake_batch, specgram_target_batch)
                self.train_losses['freq_l2'].append(
                    loss_generator_frequency.item())

                # Get the L2 loss in embedding space
                loss_generator_autoencoder = torch.zeros(size=[1],
                                                         device=self.device,
                                                         requires_grad=True)
                if self.use_autoencoder:
                    # Get the embeddings
                    _, embedding_target_batch = self.autoencoder(target_batch)
                    _, embedding_generated_batch = self.autoencoder(
                        generated_batch)
                    loss_generator_autoencoder = self.generator_autoencoder_criterion(
                        embedding_generated_batch, embedding_target_batch)
                    self.train_losses['autoencoder_l2'].append(
                        loss_generator_autoencoder.item())

                # Combine the different losses
                loss_generator = self.lambda_adv * loss_generator_adversarial + loss_generator_time + \
                                 self.lambda_freq * loss_generator_frequency + \
                                 self.lambda_autoencoder * loss_generator_autoencoder

                # Back-propagate and update the generator weights
                loss_generator.backward()
                self.generator_optimizer.step()

                # Print message
                if not (i % 10):
                    message = 'Batch {}: \n' \
                              '\t Generator: \n' \
                              '\t\t Time: {} \n' \
                              '\t\t Frequency: {} \n' \
                              '\t\t Autoencoder {} \n' \
                              '\t\t Adversarial: {} \n' \
                              '\t Discriminator: \n' \
                              '\t\t Real {} \n' \
                              '\t\t Fake {} \n'.format(i,
                                                       loss_generator_time.item(),
                                                       loss_generator_frequency.item(),
                                                       loss_generator_autoencoder.item(),
                                                       loss_generator_adversarial.item(),
                                                       loss_discriminator_real.item(),
                                                       loss_discriminator_generated.item())
                    print(message)

            # Evaluate the model
            with torch.no_grad():
                self.eval()

            # Save the trainer state
            self.save()
            # if self.need_saving:
            #     self.save()

            # Increment epoch counter
            self.epoch += 1
            self.generator_scheduler.step()
            self.discriminator_scheduler.step()

    def eval(self):
        self.generator.eval()
        self.discriminator.eval()
        batch_losses = {'time_l2': [], 'freq_l2': []}
        for i in range(self.valid_batches_per_epoch):
            # Transfer to GPU
            local_batch = next(self.valid_loader_iter)
            input_batch, target_batch = local_batch[0].to(
                self.device), local_batch[1].to(self.device)

            generated_batch = self.generator(input_batch)

            # Get the spectrogram
            specgram_target_batch = self.spectrogram(target_batch)
            specgram_generated_batch = self.spectrogram(generated_batch)

            loss_generator_time = self.generator_time_criterion(
                generated_batch, target_batch)
            batch_losses['time_l2'].append(loss_generator_time.item())
            loss_generator_frequency = self.generator_frequency_criterion(
                specgram_generated_batch, specgram_target_batch)
            batch_losses['freq_l2'].append(loss_generator_frequency.item())

        # Store the validation losses
        self.valid_losses['time_l2'].append(np.mean(batch_losses['time_l2']))
        self.valid_losses['freq_l2'].append(np.mean(batch_losses['freq_l2']))

        # Display validation losses
        message = 'Epoch {}: \n' \
                  '\t Time: {} \n' \
                  '\t Frequency: {} \n'.format(self.epoch,
                                               np.mean(batch_losses['time_l2']),
                                               np.mean(batch_losses['freq_l2']))
        print(message)

        # Check if the loss is decreasing
        self.check_improvement()

    def save(self):
        """
        Saves the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        torch.save(
            {
                'epoch':
                self.epoch,
                'generator_state_dict':
                self.generator.state_dict(),
                'discriminator_state_dict':
                self.discriminator.state_dict(),
                'generator_optimizer_state_dict':
                self.generator_optimizer.state_dict(),
                'discriminator_optimizer_state_dict':
                self.discriminator_optimizer.state_dict(),
                'generator_scheduler_state_dict':
                self.generator_scheduler.state_dict(),
                'discriminator_scheduler_state_dict':
                self.discriminator_scheduler.state_dict(),
                'train_losses':
                self.train_losses,
                'test_losses':
                self.test_losses,
                'valid_losses':
                self.valid_losses
            }, self.savepath)

    def load(self):
        """
        Loads the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        checkpoint = torch.load(self.loadpath, map_location=self.device)
        self.epoch = checkpoint['epoch']
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        self.discriminator.load_state_dict(
            checkpoint['discriminator_state_dict'])
        self.generator_optimizer.load_state_dict(
            checkpoint['generator_optimizer_state_dict'])
        self.discriminator_optimizer.load_state_dict(
            checkpoint['discriminator_optimizer_state_dict'])
        self.generator_scheduler.load_state_dict(
            checkpoint['generator_scheduler_state_dict'])
        self.discriminator_scheduler.load_state_dict(
            checkpoint['discriminator_scheduler_state_dict'])
        self.train_losses = checkpoint['train_losses']
        self.test_losses = checkpoint['test_losses']
        self.valid_losses = checkpoint['valid_losses']

    def evaluate_metrics(self, n_batches):
        """
        Evaluates the quality of the reconstruction with the SNR and LSD metrics on a specified number of batches
        :param: n_batches: number of batches to process
        :return: mean and std for each metric
        """
        with torch.no_grad():
            snrs = []
            lsds = []
            generator = self.generator.eval()
            for k in range(n_batches):
                # Transfer to GPU
                local_batch = next(self.test_loader_iter)
                # Transfer to GPU
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)

                # Generates a batch
                generated_batch = generator(input_batch)

                # Get the metrics
                snrs.append(
                    snr(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))
                lsds.append(
                    lsd(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))

            snrs = torch.cat(snrs).cpu().numpy()
            lsds = torch.cat(lsds).cpu().numpy()

            # Some signals corresponding to silence will be all zeroes and cause troubles due to the logarithm
            snrs[np.isinf(snrs)] = np.nan
            lsds[np.isinf(lsds)] = np.nan
        return np.nanmean(snrs), np.nanstd(snrs), np.nanmean(lsds), np.nanstd(
            lsds)
Example #6
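# GAIL-style dialog training: the actor generates actions scored by a reward
# derived from the discriminator, the discriminator is trained to separate
# expert from learner actions, and the actor/critic are updated on the
# collected memory; metrics go to TensorBoard and checkpoints are saved
# periodically.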
def main():
    env = DialogEnvironment()
    experiment_name = args.logdir.split('/')[1] #model name

    torch.manual_seed(args.seed)

    #TODO
    actor = Actor(hidden_size=args.hidden_size,num_layers=args.num_layers,device='cuda',input_size=args.input_size,output_size=args.input_size)
    critic = Critic(hidden_size=args.hidden_size,num_layers=args.num_layers,input_size=args.input_size,seq_len=args.seq_len)
    discrim = Discriminator(hidden_size=args.hidden_size,num_layers=args.num_layers,input_size=args.input_size,seq_len=args.seq_len)
    
    actor.to(device), critic.to(device), discrim.to(device)
    
    actor_optim = optim.Adam(actor.parameters(), lr=args.learning_rate)
    critic_optim = optim.Adam(critic.parameters(), lr=args.learning_rate, 
                              weight_decay=args.l2_rate) 
    discrim_optim = optim.Adam(discrim.parameters(), lr=args.learning_rate)

    # load demonstrations

    writer = SummaryWriter(args.logdir)

    if args.load_model is not None: #TODO
        saved_ckpt_path = os.path.join(os.getcwd(), 'save_model', str(args.load_model))
        ckpt = torch.load(saved_ckpt_path)

        actor.load_state_dict(ckpt['actor'])
        critic.load_state_dict(ckpt['critic'])
        discrim.load_state_dict(ckpt['discrim'])


    
    episodes = 0
    train_discrim_flag = True

    for iter in range(args.max_iter_num):
        actor.eval(), critic.eval()
        memory = deque()

        steps = 0
        scores = []
        similarity_scores = []
        while steps < args.total_sample_size:
            state, expert_action, raw_state, raw_expert_action = env.reset()
            score = 0
            similarity_score = 0
            state = state[:args.seq_len,:]
            expert_action = expert_action[:args.seq_len,:]
            state = state.to(device)
            expert_action = expert_action.to(device)
            for _ in range(10000): 

                steps += 1

                mu, std = actor(state.view(1, args.seq_len, args.input_size))
                action = get_action(mu.cpu(), std.cpu())[0]
                for i in range(5):
                    emb_sum = expert_action[i,:].sum().cpu().item()
                    if emb_sum == 0:
                       # print(i)
                        action[i:,:] = 0 # manual padding
                        break

                done = env.step(action)
                irl_reward = get_reward(discrim, state, action, args)
                if done:
                    mask = 0
                else:
                    mask = 1


                memory.append([state, torch.from_numpy(action).to(device), irl_reward, mask,expert_action])
                score += irl_reward
                similarity_score += get_cosine_sim(expert=expert_action,action=action.squeeze(),seq_len=5)
                #print(get_cosine_sim(s1=expert_action,s2=action.squeeze(),seq_len=5),'sim')
                if done:
                    break

            episodes += 1
            scores.append(score)
            similarity_scores.append(similarity_score)

        score_avg = np.mean(scores)
        similarity_score_avg = np.mean(similarity_scores)
        print('{}:: {} episode score is {:.2f}'.format(iter, episodes, score_avg))
        print('{}:: {} episode similarity score is {:.2f}'.format(iter, episodes, similarity_score_avg))

        actor.train(), critic.train(), discrim.train()
        if train_discrim_flag:
            expert_acc, learner_acc = train_discrim(discrim, memory, discrim_optim, args) 
            print("Expert: %.2f%% | Learner: %.2f%%" % (expert_acc * 100, learner_acc * 100))
            writer.add_scalar('log/expert_acc', float(expert_acc), iter) #logg
            writer.add_scalar('log/learner_acc', float(learner_acc), iter) #logg
            writer.add_scalar('log/avg_acc', float(learner_acc + expert_acc)/2, iter) #logg
            if args.suspend_accu_exp is not None: #only if not None do we check.
                if expert_acc > args.suspend_accu_exp and learner_acc > args.suspend_accu_gen:
                    train_discrim_flag = False

        train_actor_critic(actor, critic, memory, actor_optim, critic_optim, args)
        writer.add_scalar('log/score', float(score_avg), iter)
        writer.add_scalar('log/similarity_score', float(similarity_score_avg), iter)
        writer.add_text('log/raw_state', raw_state[0],iter)
        raw_action = get_raw_action(action) #TODO
        writer.add_text('log/raw_action', raw_action,iter)
        writer.add_text('log/raw_expert_action', raw_expert_action,iter)

        if iter % 100 == 0:
            score_avg = int(score_avg)
            # Append this iteration's result to the experiment log file
            result_str = str(iter) + '|' + raw_state[0] + '|' + raw_action + '|' + raw_expert_action + '\n'
            with open(experiment_name + '.txt', 'a') as file_object:
                file_object.write(result_str)

            model_path = os.path.join(os.getcwd(),'save_model')
            if not os.path.isdir(model_path):
                os.makedirs(model_path)

            ckpt_path = os.path.join(model_path, experiment_name + '_ckpt_'+ str(score_avg)+'.pth.tar')

            save_checkpoint({
                'actor': actor.state_dict(),
                'critic': critic.state_dict(),
                'discrim': discrim.state_dict(),
                'args': args,
                'score': score_avg,
            }, filename=ckpt_path)
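Example #7
# VAE with an adversarial reconstruction term: loss_function below scores the
# VAE reconstruction with a discriminator and applies BCE against an all-ones
# label in place of a pixel-wise reconstruction loss (the snippet begins
# mid-class, inside the VAE definition).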
        return h3

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        # z = z.view(args.batch_size, z_dim, 1, 1)
        z = z.view(args.batch_size, z_dim)
        return self.decode(z), mu, logvar

model = VAE()
discrim = Discriminator()
if use_cuda:
    model.cuda()
    discrim.cuda()
optimizer = optim.Adam(model.parameters(), lr = args.lr)
discrim_optimizer = optim.Adam(discrim.parameters(), lr = args.discrim_lr)

# Reconstruction + KL divergence losses summed over all elements and batch
A, B, C = 224, 224, 3
image_size = A * B * C
def loss_function(recon_x, x, mu, logvar):
    # BCE = F.binary_cross_entropy(recon_x.view(-1, image_size), x.view(-1, image_size), size_average=False)
    # BCE = F.binary_cross_entropy(recon_x, x)
    ## define the GAN loss here
    label = Variable(torch.ones(args.batch_size).type(Tensor))
    ## TODO: see if its a variable
    discrim_output = discrim(recon_x)
    BCE = F.binary_cross_entropy(discrim_output, label)


    # see Appendix B from VAE paper:
Example #8
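# CycleGAN training script: two generators (A2B, B2A) and two discriminators
# are trained with identity (L1), adversarial (MSE) and cycle-consistency (L1)
# losses, image replay buffers, linearly decaying learning rates, and
# TensorBoard logging; checkpoints are saved when the generator loss is low or
# periodically after epoch 100.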
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=0, help='starting epoch')
    parser.add_argument('--n_epochs',
                        type=int,
                        default=400,
                        help='number of epochs of training')
    parser.add_argument('--batchSize',
                        type=int,
                        default=10,
                        help='size of the batches')
    parser.add_argument('--dataroot',
                        type=str,
                        default='datasets/genderchange/',
                        help='root directory of the dataset')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='initial learning rate')
    parser.add_argument(
        '--decay_epoch',
        type=int,
        default=100,
        help='epoch to start linearly decaying the learning rate to 0')
    parser.add_argument('--size',
                        type=int,
                        default=256,
                        help='size of the data crop (squared assumed)')
    parser.add_argument('--input_nc',
                        type=int,
                        default=3,
                        help='number of channels of input data')
    parser.add_argument('--output_nc',
                        type=int,
                        default=3,
                        help='number of channels of output data')
    parser.add_argument('--cuda',
                        action='store_true',
                        help='use GPU computation')
    parser.add_argument(
        '--n_cpu',
        type=int,
        default=8,
        help='number of cpu threads to use during batch generation')
    opt = parser.parse_args()
    print(opt)

    if torch.cuda.is_available() and not opt.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    ###### Definition of variables ######
    # Networks
    netG_A2B = Generator(opt.input_nc, opt.output_nc)
    netG_B2A = Generator(opt.output_nc, opt.input_nc)
    netD_A = Discriminator(opt.input_nc)
    netD_B = Discriminator(opt.output_nc)

    if opt.cuda:
        netG_A2B.cuda()
        netG_B2A.cuda()
        netD_A.cuda()
        netD_B.cuda()

    netG_A2B.apply(weights_init_normal)
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

    # Losses
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    # Optimizers & LR schedulers
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                                   netG_B2A.parameters()),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    # Inputs & targets memory allocation
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
    target_real = Variable(Tensor(opt.batchSize).fill_(1.0),
                           requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize).fill_(0.0),
                           requires_grad=False)

    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Dataset loader
    transforms_ = [
        transforms.Resize(int(opt.size * 1.2), Image.BICUBIC),
        transforms.CenterCrop(opt.size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]

    dataloader = DataLoader(ImageDataset(opt.dataroot,
                                         transforms_=transforms_,
                                         unaligned=True),
                            batch_size=opt.batchSize,
                            shuffle=True,
                            num_workers=opt.n_cpu,
                            drop_last=True)

    # Plot Loss and Images in Tensorboard
    experiment_dir = 'logs/{}@{}'.format(
        opt.dataroot.split('/')[1],
        datetime.now().strftime("%d.%m.%Y-%H:%M:%S"))
    os.makedirs(experiment_dir, exist_ok=True)
    writer = SummaryWriter(os.path.join(experiment_dir, "tb"))

    metric_dict = defaultdict(list)
    n_iters_total = 0

    ###################################
    ###### Training ######
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(
                same_B, real_B) * 5.0  # [batchSize, 3, ImgSize, ImgSize]

            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(
                same_A, real_A) * 5.0  # [batchSize, 3, ImgSize, ImgSize]

            # GAN loss
            fake_B = netG_A2B(real_A)
            pred_fake = netD_B(fake_B).view(-1)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)  # [batchSize]

            fake_A = netG_B2A(real_B)
            pred_fake = netD_A(fake_A).view(-1)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)  # [batchSize]

            # Cycle loss
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(
                recovered_A, real_A) * 10.0  # [batchSize, 3, ImgSize, ImgSize]

            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(
                recovered_B, real_B) * 10.0  # [batchSize, 3, ImgSize, ImgSize]

            # Total loss
            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB

            loss_G.backward()
            optimizer_G.step()
            ###################################

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = netD_A(real_A).view(-1)
            loss_D_real = criterion_GAN(pred_real, target_real)  # [batchSize]

            # Fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach()).view(-1)
            loss_D_fake = criterion_GAN(pred_fake, target_fake)  # [batchSize]

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()

            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = netD_B(real_B).view(-1)
            loss_D_real = criterion_GAN(pred_real, target_real)  # [batchSize]

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach()).view(-1)
            loss_D_fake = criterion_GAN(pred_fake, target_fake)  # [batchSize]

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()

            optimizer_D_B.step()
            ###################################

            metric_dict['loss_G'].append(loss_G.item())
            metric_dict['loss_G_identity'].append(loss_identity_A.item() +
                                                  loss_identity_B.item())
            metric_dict['loss_G_GAN'].append(loss_GAN_A2B.item() +
                                             loss_GAN_B2A.item())
            metric_dict['loss_G_cycle'].append(loss_cycle_ABA.item() +
                                               loss_cycle_BAB.item())
            metric_dict['loss_D'].append(loss_D_A.item() + loss_D_B.item())

            for title, value in metric_dict.items():
                writer.add_scalar('train/{}'.format(title), value[-1],
                                  n_iters_total)

            n_iters_total += 1

        print("""
        -----------------------------------------------------------
        Epoch : {} Finished
        Loss_G : {}
        Loss_G_identity : {}
        Loss_G_GAN : {}
        Loss_G_cycle : {}
        Loss_D : {}
        -----------------------------------------------------------
        """.format(epoch, loss_G, loss_identity_A + loss_identity_B,
                   loss_GAN_A2B + loss_GAN_B2A,
                   loss_cycle_ABA + loss_cycle_BAB, loss_D_A + loss_D_B))

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        # Save models checkpoints

        if loss_G.item() < 2.5:
            os.makedirs(os.path.join(experiment_dir, str(epoch)),
                        exist_ok=True)
            torch.save(netG_A2B.state_dict(),
                       '{}/{}/netG_A2B.pth'.format(experiment_dir, epoch))
            torch.save(netG_B2A.state_dict(),
                       '{}/{}/netG_B2A.pth'.format(experiment_dir, epoch))
            torch.save(netD_A.state_dict(),
                       '{}/{}/netD_A.pth'.format(experiment_dir, epoch))
            torch.save(netD_B.state_dict(),
                       '{}/{}/netD_B.pth'.format(experiment_dir, epoch))
        elif epoch > 100 and epoch % 40 == 0:
            os.makedirs(os.path.join(experiment_dir, str(epoch)),
                        exist_ok=True)
            torch.save(netG_A2B.state_dict(),
                       '{}/{}/netG_A2B.pth'.format(experiment_dir, epoch))
            torch.save(netG_B2A.state_dict(),
                       '{}/{}/netG_B2A.pth'.format(experiment_dir, epoch))
            torch.save(netD_A.state_dict(),
                       '{}/{}/netD_A.pth'.format(experiment_dir, epoch))
            torch.save(netD_B.state_dict(),
                       '{}/{}/netD_B.pth'.format(experiment_dir, epoch))

        for title, value in metric_dict.items():
            writer.add_scalar("train/{}_epoch".format(title), np.mean(value),
                              epoch)
Example #9
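# Domain-adaptation setup: a classifier, a critic (discriminator) and two
# generators get Adam optimizers, with MNIST as the source domain and USPS as
# the target domain.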
# Initialize the models
classifier = Classifier()
critic = Discriminator(input_dims=params.d_input_dims,
                       hidden_dims=params.d_hidden_dims,
                       output_dims=params.d_output_dims)
generator = Generator()

criterion = nn.CrossEntropyLoss()

# special for target
generator_larger = Generator_Larger()

optimizer_c = optim.Adam(classifier.parameters(),
                         lr=params.learning_rate,
                         betas=(params.beta1, params.beta2))
optimizer_d = optim.Adam(critic.parameters(),
                         lr=params.learning_rate,
                         betas=(params.beta1, params.beta2))
optimizer_g = optim.Adam(generator.parameters(),
                         lr=params.learning_rate,
                         betas=(params.beta1, params.beta2))

optimizer_g_l = optim.Adam(generator_larger.parameters(),
                           lr=params.learning_rate,
                           betas=(params.beta1, params.beta2))
data_itr_src = get_data_iter("MNIST", train=True)
data_itr_tgt = get_data_iter("USPS", train=True)

pos_labels = Variable(torch.Tensor([1]))
neg_labels = Variable(torch.Tensor([-1]))
g_step = 0
Example #10
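# A sequence-to-sequence CycleGAN for text style transfer between Shakespearean
# and modern English: a shared embedding layer, BCE adversarial losses for the
# discriminators, and cross-entropy cycle/identity losses for the generators,
# with gradient clipping.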
class Seq2SeqCycleGAN:
    def __init__(self,
                 model_config,
                 train_config,
                 vocab,
                 max_len,
                 mode='train'):
        self.mode = mode

        self.model_config = model_config
        self.train_config = train_config

        self.vocab = vocab
        self.vocab_size = self.vocab.num_words
        self.max_len = max_len

        # self.embedding_layer = nn.Embedding(vocab_size, model_config['embedding_size'], padding_idx=PAD_token)
        self.embedding_layer = nn.Sequential(
            nn.Linear(self.vocab_size, self.model_config['embedding_size']),
            nn.Sigmoid())

        self.G_AtoB = Generator(self.embedding_layer,
                                self.model_config,
                                self.train_config,
                                self.vocab_size,
                                self.max_len,
                                mode=self.mode).cuda()
        self.G_BtoA = Generator(self.embedding_layer,
                                self.model_config,
                                self.train_config,
                                self.vocab_size,
                                self.max_len,
                                mode=self.mode).cuda()

        if self.mode == 'train':
            self.D_B = Discriminator(self.embedding_layer, self.model_config,
                                     self.train_config).cuda()
            self.D_A = Discriminator(self.embedding_layer, self.model_config,
                                     self.train_config).cuda()

            if self.train_config['continue_train']:
                self.embedding_layer.load_state_dict(
                    torch.load(self.train_config['which_epoch'] +
                               '_embedding_layer.pth'))
                self.G_AtoB.load_state_dict(
                    torch.load(self.train_config['which_epoch'] +
                               '_G_AtoB.pth'))
                self.G_BtoA.load_state_dict(
                    torch.load(self.train_config['which_epoch'] +
                               '_G_BtoA.pth'))
                self.D_B.load_state_dict(
                    torch.load(self.train_config['which_epoch'] + '_D_B.pth'))
                self.D_A.load_state_dict(
                    torch.load(self.train_config['which_epoch'] + '_D_A.pth'))

            self.embedding_layer.train()
            self.G_AtoB.train()
            self.G_BtoA.train()
            self.D_B.train()
            self.D_A.train()

            self.criterionBCE = nn.BCELoss().cuda()
            self.criterionCE = nn.CrossEntropyLoss().cuda()

            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.embedding_layer.parameters(), self.G_AtoB.parameters(),
                self.G_BtoA.parameters()),
                                                lr=train_config['base_lr'],
                                                betas=(0.9, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                self.embedding_layer.parameters(), self.D_A.parameters(),
                self.D_B.parameters()),
                                                lr=train_config['base_lr'],
                                                betas=(0.9, 0.999))

            self.real_label = torch.ones(
                (train_config['batch_size'], 1)).cuda()
            self.fake_label = torch.zeros(
                (train_config['batch_size'], 1)).cuda()
        else:
            self.embedding_layer.load_state_dict(
                torch.load(self.train_config['which_epoch'] +
                           '_embedding_layer.pth'))
            self.G_AtoB.load_state_dict(
                torch.load(self.train_config['which_epoch'] + '_G_AtoB.pth'))
            self.G_BtoA.load_state_dict(
                torch.load(self.train_config['which_epoch'] + '_G_BtoA.pth'))

            self.embedding_layer.eval()
            self.G_AtoB.eval()
            self.G_BtoA.eval()

    def backward_D_basic(self, netD, real, real_addn_feats, fake,
                         fake_addn_feats):
        netD.hidden = netD.init_hidden()
        pred_real = netD(real, real_addn_feats)
        loss_D_real = self.criterionBCE(pred_real, self.real_label)

        netD.hidden = netD.init_hidden()
        pred_fake = netD(fake.detach(), fake_addn_feats)
        loss_D_fake = self.criterionBCE(pred_fake, self.fake_label)

        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()

        self.clip_gradient(self.embedding_layer)
        self.clip_gradient(netD)

        return loss_D

    def backward_D_A(self):
        self.loss_D_A = self.backward_D_basic(
            self.D_A, self.real_A, self.real_A_addn_feats, self.fake_A,
            self.fake_A_addn_feats) * 10

    def backward_D_B(self):
        self.loss_D_B = self.backward_D_basic(
            self.D_B, self.real_B, self.real_B_addn_feats, self.fake_B,
            self.fake_B_addn_feats) * 10

    def backward_G(self):
        self.D_B.hidden = self.D_B.init_hidden()
        self.fake_B_addn_feats = get_addn_feats(self.fake_B, self.vocab).cuda()
        self.loss_G_AtoB = self.criterionBCE(
            self.D_B(self.fake_B, self.fake_B_addn_feats),
            self.real_label) * 10

        self.D_A.hidden = self.D_A.init_hidden()
        self.fake_A_addn_feats = get_addn_feats(self.fake_A, self.vocab).cuda()
        self.loss_G_BtoA = self.criterionBCE(
            self.D_A(self.fake_A, self.fake_A_addn_feats),
            self.real_label) * 10

        if self.rec_A.size(0) != self.real_A_label.size(0):
            self.real_A, self.rec_A, self.real_A_label = self.update_label_sizes(
                self.real_A, self.rec_A, self.real_A_label)
        self.loss_cycle_A = self.criterionCE(self.rec_A,
                                             self.real_A_label)  #* lambda_A

        if self.rec_B.size(0) != self.real_B_label.size(0):
            self.real_B, self.rec_B, self.real_B_label = self.update_label_sizes(
                self.real_B, self.rec_B, self.real_B_label)
        self.loss_cycle_B = self.criterionCE(self.rec_B,
                                             self.real_B_label)  #* lambda_B

        self.idt_B = self.G_AtoB(self.real_B)
        if self.idt_B.size(0) != self.real_B_label.size(0):
            self.real_B, self.idt_B, self.real_B_label = self.update_label_sizes(
                self.real_B, self.idt_B, self.real_B_label)
        self.loss_idt_B = self.criterionCE(
            self.idt_B, self.real_B_label)  #* lambda_B * lambda_idt

        self.idt_A = self.G_BtoA(self.real_A)
        if self.idt_A.size(0) != self.real_A_label.size(0):
            self.real_A, self.idt_A, self.real_A_label = self.update_label_sizes(
                self.real_A, self.idt_A, self.real_A_label)
        self.loss_idt_A = self.criterionCE(
            self.idt_A, self.real_A_label)  #* lambda_A * lambda_idt

        self.loss_G = self.loss_G_AtoB + self.loss_G_BtoA + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

        self.clip_gradient(self.embedding_layer)
        self.clip_gradient(self.G_AtoB)
        self.clip_gradient(self.G_BtoA)

    def forward(self, real_A, real_A_addn_feats, real_B, real_B_addn_feats):
        self.real_A = real_A
        self.real_A_addn_feats = real_A_addn_feats
        self.real_A_label = self.real_A.max(dim=1)[1]

        self.real_B = real_B
        self.real_B_addn_feats = real_B_addn_feats
        self.real_B_label = self.real_B.max(dim=1)[1]

        self.fake_B = F.softmax(self.G_AtoB.forward(self.real_A), dim=1)
        self.fake_A = F.softmax(self.G_BtoA.forward(self.real_B), dim=1)

        if self.mode == 'train':
            self.rec_A = self.G_BtoA.forward(self.fake_B)
            self.rec_B = self.G_AtoB.forward(self.fake_A)

        else:
            real_A_list = self.real_A.max(dim=1)[1].tolist()
            real_B_list = self.real_B.max(dim=1)[1].tolist()

            fake_B_list = self.fake_B.max(dim=1)[1].tolist()
            fake_A_list = self.fake_A.max(dim=1)[1].tolist()

            print('Input (Shakespeare):', idx_to_sent(real_A_list, self.vocab))
            print('Output (Modern):', idx_to_sent(fake_B_list, self.vocab))
            print('\n')
            print('Input (Modern):', idx_to_sent(real_B_list, self.vocab))
            print('Output (Shakespeare):', idx_to_sent(fake_A_list,
                                                       self.vocab))
            print('\n')

    def optimize_parameters(self):
        self.set_requires_grad([self.D_A, self.D_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        self.set_requires_grad([self.D_A, self.D_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_B()
        self.backward_D_A()
        self.optimizer_D.step()

    def update_label_sizes(self, real, rec, real_label):

        if rec.size(0) > real.size(0):
            real_label = torch.cat(
                (real_label, torch.zeros((rec.size(0) - real.size(0))).type(
                    torch.LongTensor).cuda()), 0)
        elif rec.size(0) < real.size(0):
            diff = real.size(0) - rec.size(0)
            to_concat = torch.zeros((diff, self.vocab_size)).cuda()
            to_concat[:, 0] = 1
            rec = torch.cat((rec, to_concat), 0)

        return real, rec, real_label

    def indices_to_one_hot(self, idx_tensor):
        one_hot_tensor = torch.empty((idx_tensor.size(0), self.vocab_size))
        for idx in range(idx_tensor.size(0)):
            zeros = torch.zeros((self.vocab_size))
            zeros[idx_tensor[idx].item()] = 1.0
            one_hot_tensor[idx] = zeros

        return one_hot_tensor

    def set_requires_grad(self, nets, requires_grad=False):
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def clip_gradient(self, model):
        nn.utils.clip_grad_norm_(model.parameters(), 0.25)
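
Side note: the per-index loop in indices_to_one_hot above can be written with torch's built-in one-hot helper. A minimal equivalent sketch (illustrative only, not part of the original class):

import torch
import torch.nn.functional as F

idx_tensor = torch.tensor([2, 0, 3])   # example token indices
vocab_size = 5
# Same result as the loop above: one row per index, a single 1.0 per row.
one_hot_tensor = F.one_hot(idx_tensor, num_classes=vocab_size).float()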
Пример #11
0
def main(args):
    with open(args.params, "r") as f:
        params = json.load(f)

    generator = Generator(params["dim_latent"])
    discriminator = Discriminator()

    if args.device is not None:
        generator = generator.cuda(args.device)
        discriminator = discriminator.cuda(args.device)

    # dataloading
    train_dataset = datasets.MNIST(root=args.datadir,
                                   transform=transforms.ToTensor(),
                                   train=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=params["batch_size"],
                              num_workers=4,
                              shuffle=True)

    # optimizer
    betas = (params["beta_1"], params["beta_2"])
    optimizer_G = optim.Adam(generator.parameters(),
                             lr=params["learning_rate"],
                             betas=betas)
    optimizer_D = optim.Adam(discriminator.parameters(),
                             lr=params["learning_rate"],
                             betas=betas)

    if not os.path.exists(args.modeldir): os.mkdir(args.modeldir)
    if not os.path.exists(args.logdir): os.mkdir(args.logdir)
    writer = SummaryWriter(args.logdir)

    steps_per_epoch = len(train_loader)

    msg = ["\t{0}: {1}".format(key, val) for key, val in params.items()]
    print("hyperparameters: \n" + "\n".join(msg))

    # main training loop
    for n in range(params["num_epochs"]):
        loader = iter(train_loader)

        print("epoch: {0}/{1}".format(n + 1, params["num_epochs"]))
        for i in tqdm.trange(steps_per_epoch):
            batch, _ = next(loader)
            if args.device is not None: batch = batch.cuda(args.device)

            loss_D = update_discriminator(batch, discriminator, generator,
                                          optimizer_D, params)
            loss_G = update_generator(discriminator, generator, optimizer_G,
                                      params, args.device)

            writer.add_scalar("loss_discriminator/train", loss_D,
                              i + n * steps_per_epoch)
            writer.add_scalar("loss_generator/train", loss_G,
                              i + n * steps_per_epoch)

        torch.save(generator.state_dict(),
                   args.o + ".generator." + str(n) + ".tmp")
        torch.save(discriminator.state_dict(),
                   args.o + ".discriminator." + str(n) + ".tmp")

        # eval
        with torch.no_grad():
            latent = torch.randn(args.num_fake_samples_eval,
                                 params["dim_latent"])
            if args.device is not None:
                latent = latent.cuda(args.device)
            imgs_fake = generator(latent)

            writer.add_images("generated fake images", imgs_fake, n)
            del latent, imgs_fake

    writer.close()

    torch.save(generator.state_dict(), args.o + ".generator.pt")
    torch.save(discriminator.state_dict(), args.o + ".discriminator.pt")
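
The update_discriminator and update_generator helpers are not shown in this excerpt. Below is a minimal, self-contained sketch of what such helpers typically do for a vanilla GAN on MNIST, matching the call signatures used above; the non-saturating BCE losses and the (N, 1) logit shape are assumptions, not the original implementation:

import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()

def update_discriminator(batch, discriminator, generator, optimizer_D, params):
    # Real images are labelled 1, generated images 0 (assumes D returns raw logits of shape (N, 1)).
    device = batch.device
    latent = torch.randn(batch.size(0), params["dim_latent"], device=device)
    fake = generator(latent).detach()
    loss = bce(discriminator(batch), torch.ones(batch.size(0), 1, device=device)) \
         + bce(discriminator(fake), torch.zeros(batch.size(0), 1, device=device))
    optimizer_D.zero_grad()
    loss.backward()
    optimizer_D.step()
    return loss.item()

def update_generator(discriminator, generator, optimizer_G, params, device):
    # Non-saturating generator loss: push D's score on fresh fakes towards 1.
    latent = torch.randn(params["batch_size"], params["dim_latent"], device=device)
    fake = generator(latent)
    loss = bce(discriminator(fake), torch.ones(params["batch_size"], 1, device=device))
    optimizer_G.zero_grad()
    loss.backward()
    optimizer_G.step()
    return loss.item()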
Пример #12
0
class MNISTTrainer:

    def __init__(self, args):
        self.args = args
        # To write to Tensorboard
        self.writer = SummaryWriter()

        # Holds all data classes needed
        self.data = GANData(args, root='./data')

        # Instantiate models and load to device
        self.D = Discriminator(784, args.d_hidden_size).to(args.device)
        self.G = Generator(784, args.latent_size, args.g_hidden_size).to(args.device)

        # Instantiate criterion used for both D and G
        self.criterion = nn.BCELoss()

        # Instantiate an optimizer for both D and G
        self.d_optim = optim.Adam(self.D.parameters(), lr=args.lr)
        self.g_optim = optim.Adam(self.G.parameters(), lr=args.lr)

    def train(self):
        """
        Main training loop for this trainer. To be called in train.py.
        """
        device = self.args.device

        print(f'Training on device: {device}')
        print(f'Beginning training for {self.args.epochs} epochs...')

        for epoch in range(self.args.epochs):

            running_d_loss, running_g_loss = 0.0, 0.0

            for real_imgs, real_labels in tqdm(self.data.real_loader):
                # Load MNIST images and labels to device
                real_imgs, real_labels = real_imgs.to(device), real_labels.to(device)
                # Load latent vectors and labels to device
                z, fake_labels = self.data.sample_latent_space().to(device), self.data.get_fake_labels().to(device)

                #####################################
                #       Update Discriminator        #
                #####################################

                # Get probability scores for real and fake data
                real_logits = self.D(real_imgs)

                fake_imgs = self.G(z)
                fake_logits = self.D(fake_imgs)

                d_real_loss = self.criterion(real_logits, real_labels)
                d_fake_loss = self.criterion(fake_logits, fake_labels)
                d_loss = d_real_loss + d_fake_loss

                # Backpropagation and update
                self.d_optim.zero_grad()
                d_loss.backward()
                self.d_optim.step()

                #####################################
                #       Update Generator            #
                #####################################

                # Load another batch of latent vectors to the device
                z = self.data.sample_latent_space(batch_size=len(real_imgs)).to(device)

                # Get generated images and record loss
                fake_imgs = self.G(z)
                fake_logits = self.D(fake_imgs)
                g_loss = self.criterion(fake_logits, real_labels)

                # Backpropagation and update
                self.g_optim.zero_grad()
                g_loss.backward()
                self.g_optim.step()

                # Keep track of losses and global step
                running_g_loss += g_loss.item()
                running_d_loss += d_loss.item()

            #####################################
            #       Log Info for Epoch          #
            #####################################

            log_str = f"\n{'Completed Epoch:':<20}{epoch + 1:<10}"
            # Value to normalize so we get loss/sample
            norm = len(self.data.mnist_dataset)
            log_str += f"\n{'Discriminator Loss:':<20}{running_d_loss/norm:<10}"
            log_str += f"\n{'Generator Loss:':<20}{running_g_loss/norm:<10}\n"
            print(log_str)

            # Add information to Tensorboard
            self.writer.add_scalar('discriminator_loss', running_d_loss/norm, epoch)
            self.writer.add_scalar('generator_loss', running_g_loss/norm, epoch)

            self.writer.add_scalar('avg_real_logit', t.mean(real_logits).item(), epoch)
            self.writer.add_scalar('avg_fake_logit', t.mean(fake_logits).item(), epoch)

            self.writer.add_scalar('avg_gen_grad', t.mean(self.G.model[0].weight.grad).item(), epoch)
            self.writer.add_scalar('avg_dis_grad', t.mean(self.D.model[0].weight.grad).item(), epoch)

            z = self.data.sample_latent_space(batch_size=36).to(device)
            generated_imgs = self.G(z)

            img_grid = torchvision.utils.make_grid(generated_imgs.reshape(36, 1, 28, 28), nrow=6)
            self.writer.add_image('generated_images', img_grid, epoch)

        # Close tensorboard writer when we're done with training
        self.writer.close()
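
GANData is not defined in this excerpt. A plausible minimal stand-in for the pieces the trainer relies on, offered purely to illustrate the assumed interface (real_loader presumably yields images together with all-ones discriminator targets; only the latent/label helpers are sketched here):

import torch

class GANDataSketch:
    """Hypothetical stand-in for GANData, mirroring the calls made in MNISTTrainer."""

    def __init__(self, batch_size, latent_size):
        self.batch_size = batch_size
        self.latent_size = latent_size

    def sample_latent_space(self, batch_size=None):
        # Standard-normal latent vectors, defaulting to the configured batch size.
        n = batch_size or self.batch_size
        return torch.randn(n, self.latent_size)

    def get_fake_labels(self, batch_size=None):
        # All-zeros targets for generated samples.
        n = batch_size or self.batch_size
        return torch.zeros(n, 1)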
Пример #13
0
    model_source = VAE2()
else:
    print('Loading source model from {}'.format(args.model_source_path))
    model_source = torch.load(args.model_source_path)
discriminator_model = Discriminator(20, 20)

if args.cuda:
    model_target.cuda()
    model_source.cuda()
    discriminator_model.cuda()

# target_optimizer_encoder_params = [{'params': model_target.fc1.parameters()}, {'params': model_target.fc2.parameters()}]
target_optimizer = optim.Adam(model_target.parameters(), lr=args.lr)
# target_optimizer_encoder = optim.Adam(target_optimizer_encoder_params, lr=args.lr)
source_optimizer = optim.Adam(model_source.parameters(), lr=args.lr)
d_optimizer = optim.Adam(discriminator_model.parameters(), lr=args.lr)

criterion = nn.BCELoss()

if args.source == 'mnist':
    tests = Tests(model_source, model_target, classifyMNIST, 'mnist',
                  'fashionMnist', args, graph)
elif args.source == 'fashionMnist':
    tests = Tests(model_source, model_target, classifyMNIST, 'fashionMnist',
                  'mnist', args, graph)
else:
    raise Exception('args.source is not defined')


def reset_grads():
    model_target.zero_grad()
Пример #14
0
    def train_ei_adv(self,
                     dataloader,
                     physics,
                     transform,
                     epochs,
                     lr,
                     alpha,
                     ckp_interval,
                     schedule,
                     residual=True,
                     pretrained=None,
                     task='',
                     loss_type='l2',
                     cat=True,
                     report_psnr=False,
                     lr_cos=False):
        save_path = './ckp/{}_ei_adv_{}'.format(get_timestamp(), task)

        os.makedirs(save_path, exist_ok=True)

        generator = UNet(in_channels=self.in_channels,
                         out_channels=self.out_channels,
                         compact=4,
                         residual=residual,
                         circular_padding=True,
                         cat=cat)

        if pretrained:
            checkpoint = torch.load(pretrained)
            generator.load_state_dict(checkpoint['state_dict'])

        discriminator = Discriminator(
            (self.in_channels, self.img_width, self.img_height))

        generator = generator.to(self.device)
        discriminator = discriminator.to(self.device)

        if loss_type == 'l2':
            criterion_mc = torch.nn.MSELoss().to(self.device)
            criterion_ei = torch.nn.MSELoss().to(self.device)
        elif loss_type == 'l1':
            criterion_mc = torch.nn.L1Loss().to(self.device)
            criterion_ei = torch.nn.L1Loss().to(self.device)

        criterion_gan = torch.nn.MSELoss().to(self.device)

        optimizer_G = Adam(generator.parameters(),
                           lr=lr['G'],
                           weight_decay=lr['WD'])
        optimizer_D = Adam(discriminator.parameters(),
                           lr=lr['D'],
                           weight_decay=0)

        if report_psnr:
            log = LOG(save_path,
                      filename='training_loss',
                      field_name=[
                          'epoch', 'loss_mc', 'loss_ei', 'loss_g', 'loss_G',
                          'loss_D', 'psnr', 'mse'
                      ])
        else:
            log = LOG(save_path,
                      filename='training_loss',
                      field_name=[
                          'epoch', 'loss_mc', 'loss_ei', 'loss_g', 'loss_G',
                          'loss_D'
                      ])

        for epoch in range(epochs):
            adjust_learning_rate(optimizer_G, epoch, lr['G'], lr_cos, epochs,
                                 schedule)
            adjust_learning_rate(optimizer_D, epoch, lr['D'], lr_cos, epochs,
                                 schedule)

            loss = closure_ei_adv(generator, discriminator, dataloader,
                                  physics, transform, optimizer_G, optimizer_D,
                                  criterion_mc, criterion_ei, criterion_gan,
                                  alpha, self.dtype, self.device, report_psnr)

            log.record(epoch + 1, *loss)

            if report_psnr:
                print(
                    '{}\tEpoch[{}/{}]\tfc={:.4e}\tti={:.4e}\tg={:.4e}\tG={:.4e}\tD={:.4e}\tpsnr={:.4f}\tmse={:.4e}'
                    .format(get_timestamp(), epoch, epochs, *loss))
            else:
                print(
                    '{}\tEpoch[{}/{}]\tfc={:.4e}\tti={:.4e}\tg={:.4e}\tG={:.4e}\tD={:.4e}'
                    .format(get_timestamp(), epoch, epochs, *loss))

            if epoch % ckp_interval == 0 or epoch + 1 == epochs:
                state = {
                    'epoch': epoch,
                    'state_dict_G': generator.state_dict(),
                    'state_dict_D': discriminator.state_dict(),
                    'optimizer_G': optimizer_G.state_dict(),
                    'optimizer_D': optimizer_D.state_dict()
                }
                torch.save(
                    state,
                    os.path.join(save_path, 'ckp_{}.pth.tar'.format(epoch)))
        log.close()
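
adjust_learning_rate is referenced above but not defined in this excerpt. A minimal sketch of a schedule consistent with the lr_cos flag and the `schedule` milestone list (an assumption; the original helper may differ):

import math

def adjust_learning_rate(optimizer, epoch, base_lr, lr_cos, epochs, schedule):
    """Cosine decay when lr_cos is set, otherwise step decay at the milestones in `schedule`."""
    if lr_cos:
        lr = base_lr * 0.5 * (1 + math.cos(math.pi * epoch / epochs))
    else:
        lr = base_lr * (0.1 ** sum(epoch >= milestone for milestone in schedule))
    for group in optimizer.param_groups:
        group['lr'] = lr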
Пример #15
0
class Model(object):
    def __init__(self, opt):
        super(Model, self).__init__()

        # Generator
        self.gen = Generator(opt).cuda(opt.gpu_id)

        self.gen_params = self.gen.parameters()

        num_params = 0
        for p in self.gen.parameters():
            num_params += p.numel()
        print(self.gen)
        print(num_params)

        # Discriminator
        self.dis = Discriminator(opt).cuda(opt.gpu_id)

        self.dis_params = self.dis.parameters()

        num_params = 0
        for p in self.dis.parameters():
            num_params += p.numel()
        print(self.dis)
        print(num_params)

        # Regressor
        if opt.mse_weight:
            self.reg = torch.load('data/utils/classifier.pth').cuda(
                opt.gpu_id).eval()
        else:
            self.reg = None

        # Losses
        self.criterion_gan = GANLoss(opt, self.dis)
        self.criterion_mse = lambda x, y: l1_loss(x, y) * opt.mse_weight

        self.loss_mse = Variable(torch.zeros(1).cuda())
        self.loss_adv = Variable(torch.zeros(1).cuda())
        self.loss = Variable(torch.zeros(1).cuda())

        self.path = opt.experiments_dir + opt.experiment_name + '/checkpoints/'
        self.gpu_id = opt.gpu_id
        self.noise_channels = opt.in_channels - len(opt.input_idx.split(','))

    def forward(self, inputs):

        input, input_orig, target = inputs

        self.input = Variable(input.cuda(self.gpu_id))
        self.input_orig = Variable(input_orig.cuda(self.gpu_id))
        self.target = Variable(target.cuda(self.gpu_id))

        noise = Variable(
            torch.randn(self.input.size(0),
                        self.noise_channels).cuda(self.gpu_id))

        self.fake = self.gen(torch.cat([self.input, noise], 1))

    def backward_G(self):

        # Regressor loss
        if self.reg is not None:

            fake_input = self.reg(self.fake)

            self.loss_mse = self.criterion_mse(fake_input, self.input_orig)

        # GAN loss
        loss_adv, _ = self.criterion_gan(self.fake)

        loss_G = self.loss_mse + loss_adv
        loss_G.backward()

    def backward_D(self):

        loss_adv, self.loss_adv = self.criterion_gan(self.target, self.fake)

        loss_D = loss_adv
        loss_D.backward()

    def train(self):

        self.gen.train()
        self.dis.train()

    def eval(self):

        self.gen.eval()
        self.dis.eval()

    def save_checkpoint(self, epoch):

        torch.save(
            {
                'epoch': epoch,
                'gen_state_dict': self.gen.state_dict(),
                'dis_state_dict': self.dis.state_dict()
            }, self.path + '%d.pkl' % epoch)

    def load_checkpoint(self, path, pretrained=True):

        weights = torch.load(path)

        self.gen.load_state_dict(weights['gen_state_dict'])
        self.dis.load_state_dict(weights['dis_state_dict'])
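
The Model class above exposes forward/backward_G/backward_D but leaves the optimizer steps to an outer loop. A minimal sketch of one training iteration around it, assuming Adam optimizers built from gen_params and dis_params (names and update order here are assumptions, not the original runner):

import torch.optim as optim

def train_step(model, inputs, optim_G, optim_D):
    # Discriminator update: forward() caches model.fake, backward_D compares target vs. fake.
    model.forward(inputs)
    optim_D.zero_grad()
    model.backward_D()
    optim_D.step()

    # Generator update on a fresh forward pass so the cached fake matches the current weights.
    model.forward(inputs)
    optim_G.zero_grad()
    model.backward_G()
    optim_G.step()

# optim_G = optim.Adam(model.gen_params, lr=2e-4, betas=(0.5, 0.999))
# optim_D = optim.Adam(model.dis_params, lr=2e-4, betas=(0.5, 0.999))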
Пример #16
0
class AdvGAN_Attack:
    def __init__(self, device, model, image_nc, box_min, box_max):
        output_nc = image_nc
        self.device = device
        self.model = model
        self.input_nc = image_nc
        self.output_nc = output_nc

        self.box_min = box_min
        self.box_max = box_max

        self.gen_input_nc = image_nc
        self.netG = Generator(self.gen_input_nc, image_nc).to(device)
        self.netDisc = Discriminator(image_nc).to(device)

        # initialize all weights
        self.netG.apply(weights_init)
        self.netDisc.apply(weights_init)

        # initialize optimizers
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=0.001)
        self.optimizer_D = torch.optim.Adam(self.netDisc.parameters(), lr=0.001)

        if not os.path.exists(models_path):
            os.makedirs(models_path)

    def train_batch(self, x, path, alignment):
        """x is the large not cropped face. TODO find a way to associate image with the image it came from (see if we can do it by filename)"""
        # x is the cropped 256x256 to perturb

        # optimize D
        perturbation = self.netG(x)

        # add a clipping trick
        adv_images = torch.clamp(perturbation, -0.3, 0.3) + x
        adv_images = torch.clamp(adv_images, self.box_min, self.box_max)    # 256 x 256

        original_deepfake = y = ... # TODO load image

        # apply the adversarial image
        protected_image = compose(adv_images, path, alignment)     # TODO: Original image size

        for i in range(1):
            self.optimizer_D.zero_grad()
            pred_real = self.netDisc(x)
            loss_D_real = F.mse_loss(
                pred_real, torch.ones_like(pred_real, device=self.device)
            )
            loss_D_real.backward()

            pred_fake = self.netDisc(adv_images.detach())
            loss_D_fake = F.mse_loss(
                pred_fake, torch.zeros_like(pred_fake, device=self.device)
            )
            loss_D_fake.backward()
            loss_D_GAN = loss_D_fake + loss_D_real
            self.optimizer_D.step()

        # optimize G
        for i in range(1):
            self.optimizer_G.zero_grad()

            # cal G's loss in GAN
            pred_fake = self.netDisc(adv_images)
            loss_G_fake = F.mse_loss(
                pred_fake, torch.ones_like(pred_fake, device=self.device)
            )
            loss_G_fake.backward(retain_graph=True)

            # calculate perturbation norm
            C = 0.1
            loss_perturb = torch.mean(
                torch.norm(perturbation.view(perturbation.shape[0], -1), 2, dim=1)
            )

            # 1 - image similarity
            # TODO apply image back to original image
            # ex. perturbed_original = (adv_images patched onto the original image)
            # Clamp it
            # perform face swap with the images

            # Need to see how it affects the 

            y_ = swapfaces(protected_image)
            norm_similarity = torch.abs(torch.dot(torch.norm(y_, 2), torch.norm(original_deepfake, 2)))
            loss_adv = norm_similarity
            loss_adv.backward(retain_graph=True)  # retain graph so loss_G below can backpropagate through the same path

            adv_lambda = 10
            pert_lambda = 1
            loss_G = adv_lambda * loss_adv + pert_lambda * loss_perturb
            loss_G.backward()
            self.optimizer_G.step()

        return (
            loss_D_GAN.item(),
            loss_G_fake.item(),
            loss_perturb.item(),
            loss_adv.item(),
        )

    def train(self, train_dataloader, epochs):
        for epoch in range(1, epochs + 1):

            if epoch == 50:
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=0.0001)
                self.optimizer_D = torch.optim.Adam(
                    self.netDisc.parameters(), lr=0.0001
                )
            if epoch == 80:
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=0.00001)
                self.optimizer_D = torch.optim.Adam(
                    self.netDisc.parameters(), lr=0.00001
                )
            loss_D_sum = 0
            loss_G_fake_sum = 0
            loss_perturb_sum = 0
            loss_adv_sum = 0
            for i, data in enumerate(train_dataloader, start=0):
                (images, _, paths) = data
                images = images.to(self.device)

                (
                    loss_D_batch,
                    loss_G_fake_batch,
                    loss_perturb_batch,
                    loss_adv_batch,
                ) = self.train_batch(images, paths, None)  # alignment not wired up yet (see TODO in train_batch)
                loss_D_sum += loss_D_batch
                loss_G_fake_sum += loss_G_fake_batch
                loss_perturb_sum += loss_perturb_batch
                loss_adv_sum += loss_adv_batch

            # print statistics
            num_batch = len(train_dataloader)
            print(
                "epoch %d:\nloss_D: %.3f, loss_G_fake: %.3f,\
             \nloss_perturb: %.3f, loss_adv: %.3f, \n"
                % (
                    epoch,
                    loss_D_sum / num_batch,
                    loss_G_fake_sum / num_batch,
                    loss_perturb_sum / num_batch,
                    loss_adv_sum / num_batch,
                )
            )

            # save generator
            if epoch % 5 == 0:
                netG_file_name = models_path + "netG_epoch_" + str(epoch) + ".pth"
                torch.save(self.netG.state_dict(), netG_file_name)

                netD_file_name = models_path + "netD_epoch_" + str(epoch) + ".pth"
                torch.save(self.netDisc.state_dict(), netD_file_name)
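
weights_init is applied to both networks above but not shown in this excerpt. It conventionally refers to the DCGAN-style initializer below (an assumption about the original helper):

import torch.nn as nn

def weights_init(m):
    # Normal(0, 0.02) for conv weights; Normal(1, 0.02) plus zero bias for batch-norm layers.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)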
Пример #17
0
class GANModelAPI:
    """Класс для упрощенного создания и обучения модели"""
    def __init__(self,
                 files_a,
                 files_b,
                 shift=True,
                 gen_optimizer='Adam',
                 discr_optimizer='Adam',
                 gen_scheduler='default',
                 discr_scheduler='default',
                 step=25,
                 criterion='bceloss',
                 gen_lr=2e-4,
                 discr_lr=2e-4,
                 image_size=256):
        if not torch.cuda.is_available():
            raise RuntimeError('GPU is not available')
        device = torch.device('cuda')
        if shift:
            self.dataloader1 = DataLoader(ShiftDataset(files_a, image_size),
                                          batch_size=1,
                                          shuffle=True)
            self.dataloader2 = DataLoader(ShiftDataset(files_b, image_size),
                                          batch_size=1,
                                          shuffle=True)
        else:
            self.dataloader = DataLoader(ImgDataset(files_a, files_b,
                                                    image_size),
                                         batch_size=1,
                                         shuffle=True)
        self.generator_a2b = Generator().to(device)
        self.generator_b2a = Generator().to(device)
        self.discriminator_a = Discriminator().to(device)
        self.discriminator_b = Discriminator().to(device)
        if gen_optimizer == 'Adam':
            self.gen_optimizer = optim.Adam(chain(
                self.generator_a2b.parameters(),
                self.generator_b2a.parameters()),
                                            lr=gen_lr)
        elif gen_optimizer == 'AdamW':
            self.gen_optimizer = optim.AdamW(chain(
                self.generator_a2b.parameters(),
                self.generator_b2a.parameters()),
                                             lr=gen_lr)
        else:
            raise NotImplementedError(
                f'Optimizer {gen_optimizer} is not supported yet')
        if discr_optimizer == 'Adam':
            self.discr_optimizer = optim.Adam(chain(
                self.discriminator_a.parameters(),
                self.discriminator_b.parameters()),
                                              lr=discr_lr)
        elif discr_optimizer == 'AdamW':
            self.discr_optimizer = optim.AdamW(chain(
                self.discriminator_a.parameters(),
                self.discriminator_b.parameters()),
                                               lr=discr_lr)
        else:
            raise NotImplementedError(
                f'Optimizer {discr_optimizer} is not supported yet')
        if gen_scheduler == 'default':
            self.gen_sched = optim.lr_scheduler.LambdaLR(
                self.gen_optimizer,
                lr_lambda=lambda epoch: (1 / 0.9)**epoch
                if epoch < 5 else (0.9**(epoch - step)
                                   if epoch > step else (1 / 0.9)**5))
        elif gen_scheduler == 'step10warmup':
            self.gen_sched = optim.lr_scheduler.LambdaLR(
                self.gen_optimizer,
                lr_lambda=lambda epoch: (1 / 0.9)**epoch
                if epoch < 5 else (1 / 0.9)**5 * 0.9**((epoch - 5) // 10))
        else:
            raise NotImplementedError(
                f'Generator lr scheduler {gen_scheduler} is not supported yet'
            )
        if discr_scheduler == 'default':
            self.discr_sched = optim.lr_scheduler.LambdaLR(
                self.discr_optimizer,
                lr_lambda=lambda epoch: (1 / 0.9)**epoch
                if epoch < 5 else (0.9**(epoch - step)
                                   if epoch > step else (1 / 0.9)**5))
        elif discr_scheduler == 'step10warmup':
            self.discr_sched = optim.lr_scheduler.LambdaLR(
                self.discr_optimizer,
                lr_lambda=lambda epoch: (1 / 0.9)**epoch
                if epoch < 5 else (1 / 0.9)**5 * 0.9**((epoch - 5) // 10))
        else:
            raise NotImplementedError(
                f'Discriminator lr scheduler {discr_scheduler} is not supported yet'
            )
        if criterion == 'bceloss':
            self.criterion = nn.BCELoss()
        else:
            raise NotImplementedError(f'Criterion {criterion} is not supported yet')
        self.shift = shift

    def train_models(self,
                     max_epochs=200,
                     hold_discr=True,
                     threshold=0.5,
                     intermediate_results=None):
        if self.shift:
            return shift_train(self.generator_a2b,
                               self.generator_b2a,
                               self.discriminator_a,
                               self.discriminator_b,
                               self.gen_optimizer,
                               self.discr_optimizer,
                               self.gen_sched,
                               self.discr_sched,
                               self.criterion,
                               self.dataloader1,
                               self.dataloader2,
                               max_epochs,
                               hold_discr,
                               threshold,
                               intermediate_results=intermediate_results)
        else:
            return train(self.generator_a2b, self.generator_b2a,
                         self.discriminator_a, self.discriminator_b,
                         self.gen_optimizer, self.discr_optimizer,
                         self.gen_sched, self.discr_sched, self.criterion,
                         self.dataloader, max_epochs, hold_discr, threshold,
                         intermediate_results)

    def load_models(self, gen_a2b, gen_b2a, discr_a, discr_b):
        self.generator_a2b = gen_a2b
        self.generator_b2a = gen_b2a
        self.discriminator_a = discr_a
        self.discriminator_b = discr_b

    def save_models(self, mode='torch'):
        if mode == 'torch':
            torch.save(self.generator_a2b, 'gen_a2b.model')
            torch.save(self.generator_b2a, 'gen_b2a.model')
            torch.save(self.discriminator_a, 'discr_a.model')
            torch.save(self.discriminator_b, 'discr_b.model')
        elif mode == 'onnx':
            torch.onnx.export(torch.jit.script(self.generator_a2b),
                              torch.ones((1, 1, 512, 512),
                                         dtype=torch.float32),
                              'gen_a2b.onnx',
                              input_names=['img'])
            torch.onnx.export(torch.jit.script(self.generator_b2a),
                              torch.ones((1, 1, 512, 512),
                                         dtype=torch.float32),
                              'gen_b2a.onnx',
                              input_names=['img'])
        else:
            raise NotImplementedError(f'Mode {mode} is not supported')
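
The 'default' LambdaLR schedule above ramps the multiplier up by a factor of 1/0.9 per epoch for the first five epochs, holds it while epoch <= step, and switches to 0.9**(epoch - step) afterwards. Evaluating the lambda directly makes this concrete (step=25, as in the constructor default):

step = 25
lr_lambda = lambda epoch: (1 / 0.9)**epoch if epoch < 5 else (
    0.9**(epoch - step) if epoch > step else (1 / 0.9)**5)
for epoch in (0, 4, 5, 25, 26, 30):
    print(epoch, round(lr_lambda(epoch), 4))
# multipliers: 0 -> 1.0, 4 -> 1.5242, 5 -> 1.6935, 25 -> 1.6935, 26 -> 0.9, 30 -> 0.5905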
Пример #18
0
class WGanTrainer(Trainer):
    def __init__(self, train_loader, test_loader, valid_loader, general_args,
                 trainer_args):
        super(WGanTrainer, self).__init__(train_loader, test_loader,
                                          valid_loader, general_args)
        # Paths
        self.loadpath = trainer_args.loadpath
        self.savepath = trainer_args.savepath

        # Load the generator
        self.generator = Generator(general_args=general_args).to(self.device)
        if trainer_args.generator_path and os.path.exists(
                trainer_args.generator_path):
            self.load_pretrained_generator(trainer_args.generator_path)

        self.discriminator = Discriminator(general_args=general_args).to(
            self.device)

        # Optimizers and schedulers
        self.generator_optimizer = torch.optim.Adam(
            params=self.generator.parameters(), lr=trainer_args.generator_lr)
        self.discriminator_optimizer = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=trainer_args.discriminator_lr)
        self.generator_scheduler = lr_scheduler.StepLR(
            optimizer=self.generator_optimizer,
            step_size=trainer_args.generator_scheduler_step,
            gamma=trainer_args.generator_scheduler_gamma)
        self.discriminator_scheduler = lr_scheduler.StepLR(
            optimizer=self.discriminator_optimizer,
            step_size=trainer_args.discriminator_scheduler_step,
            gamma=trainer_args.discriminator_scheduler_gamma)

        # Load saved states
        if os.path.exists(self.loadpath):
            self.load()

        # Loss function and stored losses
        self.generator_time_criterion = nn.MSELoss()

        # Loss scaling factors
        self.lambda_adv = trainer_args.lambda_adversarial
        self.lambda_time = trainer_args.lambda_time

        # Boolean indicating if the model needs to be saved
        self.need_saving = True

        # Overrides losses from parent class
        self.train_losses = {
            'generator': {
                'time_l2': [],
                'adversarial': []
            },
            'discriminator': {
                'penalty': [],
                'adversarial': []
            }
        }
        self.test_losses = {
            'generator': {
                'time_l2': [],
                'adversarial': []
            },
            'discriminator': {
                'penalty': [],
                'adversarial': []
            }
        }
        self.valid_losses = {
            'generator': {
                'time_l2': [],
                'adversarial': []
            },
            'discriminator': {
                'penalty': [],
                'adversarial': []
            }
        }

        # Select either wgan or wgan-gp method
        self.use_penalty = trainer_args.use_penalty
        self.gamma = trainer_args.gamma_wgan_gp
        self.clipping_limit = trainer_args.clipping_limit
        self.n_critic = trainer_args.n_critic
        self.coupling_epoch = trainer_args.coupling_epoch

    def load_pretrained_generator(self, generator_path):
        """
        Loads a pre-trained generator. Can be used to stabilize the training.
        :param generator_path: location of the pre-trained generator (string).
        :return: None
        """
        checkpoint = torch.load(generator_path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])

    def compute_gradient_penalty(self, input_batch, generated_batch):
        """
        Compute the gradient penalty as described in the original paper
        (https://papers.nips.cc/paper/7159-improved-training-of-wasserstein-gans.pdf).
        :param input_batch: batch of input data (torch tensor).
        :param generated_batch: batch of generated data (torch tensor).
        :return: penalty as a scalar (torch tensor).
        """
        batch_size = input_batch.size(0)
        epsilon = torch.rand(batch_size, 1, 1)
        epsilon = epsilon.expand_as(input_batch).to(self.device)

        # Interpolate
        interpolation = epsilon * input_batch.data + (
            1 - epsilon) * generated_batch.data
        interpolation = interpolation.requires_grad_(True).to(self.device)

        # Computes the discriminator's prediction for the interpolated input
        interpolation_logits = self.discriminator(interpolation)

        # Computes a vector of outputs to make it work with 2 output classes if needed
        grad_outputs = torch.ones_like(interpolation_logits).to(
            self.device).requires_grad_(True)

        # Get the gradients and retain the graph so that the penalty can be back-propagated
        gradients = autograd.grad(outputs=interpolation_logits,
                                  inputs=interpolation,
                                  grad_outputs=grad_outputs,
                                  create_graph=True,
                                  retain_graph=True,
                                  only_inputs=True)[0]
        gradients = gradients.view(batch_size, -1)

        # Computes the norm of the gradients
        gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1))
        return ((gradients_norm - 1)**2).mean()

    def train_discriminator_step(self, input_batch, target_batch):
        """
        Trains the discriminator for a single step based on the Wasserstein GAN-GP framework.
        :param input_batch: batch of input data (torch tensor).
        :param target_batch: batch of target data (torch tensor).
        :return: a batch of generated data (torch tensor).
        """
        # Activate gradient tracking for the discriminator
        self.change_discriminator_grad_requirement(requires_grad=True)

        # Set the discriminator's gradients to zero
        self.discriminator_optimizer.zero_grad()

        # Generate a batch and compute the penalty
        generated_batch = self.generator(input_batch)

        # Compute the loss
        loss_d = self.discriminator(generated_batch.detach()).mean(
        ) - self.discriminator(target_batch).mean()
        self.train_losses['discriminator']['adversarial'].append(loss_d.item())
        if self.use_penalty:
            penalty = self.compute_gradient_penalty(input_batch,
                                                    generated_batch.detach())
            self.train_losses['discriminator']['penalty'].append(
                penalty.item())
            loss_d = loss_d + self.gamma * penalty

        # Update the discriminator's weights
        loss_d.backward()
        self.discriminator_optimizer.step()

        # Apply the weight constraint if needed
        if not self.use_penalty:
            for p in self.discriminator.parameters():
                p.data.clamp_(min=-self.clipping_limit,
                              max=self.clipping_limit)

        # Return the generated batch to avoid redundant computation
        return generated_batch

    def train_generator_step(self, target_batch, generated_batch):
        """
        Trains the generator for a single step based on the Wasserstein GAN-GP framework.
        :param target_batch: batch of target data (torch tensor).
        :param generated_batch: batch of generated data (torch tensor).
        :return: None
        """
        # Deactivate gradient tracking for the discriminator
        self.change_discriminator_grad_requirement(requires_grad=False)

        # Set generator's gradients to zero
        self.generator_optimizer.zero_grad()

        # Get the generator losses
        loss_g_adversarial = -self.discriminator(generated_batch).mean()
        loss_g_time = self.generator_time_criterion(generated_batch,
                                                    target_batch)

        # Combine the different losses
        loss_g = loss_g_time
        if self.epoch >= self.coupling_epoch:
            loss_g = loss_g + self.lambda_adv * loss_g_adversarial

        # Back-propagate and update the generator weights
        loss_g.backward()
        self.generator_optimizer.step()

        # Store the losses
        self.train_losses['generator']['time_l2'].append(loss_g_time.item())
        self.train_losses['generator']['adversarial'].append(
            loss_g_adversarial.item())

    def change_discriminator_grad_requirement(self, requires_grad):
        """
        Changes the requires_grad flag of discriminator's parameters. This action is not absolutely needed as the
        discriminator's optimizer is not called after the generator's update, but it reduces the computational cost.
        :param requires_grad: flag indicating if the discriminator's parameter require gradient tracking (boolean).
        :return: None
        """
        for p in self.discriminator.parameters():
            p.requires_grad_(requires_grad)

    def train(self, epochs):
        """
        Trains the WGAN-GP for a given number of pseudo-epochs.
        :param epochs: Number of time to iterate over a part of the dataset (int).
        :return: None
        """
        self.generator.train()
        self.discriminator.train()
        for epoch in range(epochs):
            for i in range(self.train_batches_per_epoch):
                # Transfer to GPU
                local_batch = next(self.train_loader_iter)
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)

                # Train the discriminator
                generated_batch = self.train_discriminator_step(
                    input_batch, target_batch)

                # Train the generator every n_critic
                if not (i % self.n_critic):
                    self.train_generator_step(target_batch, generated_batch)

                # Print message
                if not (i % 10):
                    message = 'Batch {}: \n' \
                              '\t Generator: \n' \
                              '\t\t Time: {} \n' \
                              '\t\t Adversarial: {} \n' \
                              '\t Discriminator: \n' \
                              '\t\t Penalty: {}\n' \
                              '\t\t Adversarial: {} \n'.format(i,
                                                               self.train_losses['generator']['time_l2'][-1],
                                                               self.train_losses['generator']['adversarial'][-1],
                                                               self.train_losses['discriminator']['penalty'][-1],
                                                               self.train_losses['discriminator']['adversarial'][-1])
                    print(message)

            # Evaluate the model
            with torch.no_grad():
                self.eval()

            # Save the trainer state
            self.save()

            # Increment epoch counter
            self.epoch += 1
            self.generator_scheduler.step()
            self.discriminator_scheduler.step()

    def eval(self):
        # Set the models in evaluation mode
        self.generator.eval()
        self.discriminator.eval()
        batch_losses = {'time_l2': []}
        for i in range(self.valid_batches_per_epoch):
            # Transfer to GPU
            local_batch = next(self.valid_loader_iter)
            input_batch, target_batch = local_batch[0].to(
                self.device), local_batch[1].to(self.device)

            generated_batch = self.generator(input_batch)

            loss_g_time = self.generator_time_criterion(
                generated_batch, target_batch)
            batch_losses['time_l2'].append(loss_g_time.item())

        # Store the validation losses
        self.valid_losses['generator']['time_l2'].append(
            np.mean(batch_losses['time_l2']))

        # Display validation losses
        message = 'Epoch {}: \n' \
                  '\t Time: {} \n'.format(self.epoch, np.mean(batch_losses['time_l2']))
        print(message)

        # Set the models in train mode
        self.generator.train()
        self.discriminator.train()

    def save(self):
        """
        Saves the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        savepath = self.savepath.split('.')[0] + '_' + str(
            self.epoch // 5) + '.' + self.savepath.split('.')[1]
        torch.save(
            {
                'epoch':
                self.epoch,
                'generator_state_dict':
                self.generator.state_dict(),
                'discriminator_state_dict':
                self.discriminator.state_dict(),
                'generator_optimizer_state_dict':
                self.generator_optimizer.state_dict(),
                'discriminator_optimizer_state_dict':
                self.discriminator_optimizer.state_dict(),
                'generator_scheduler_state_dict':
                self.generator_scheduler.state_dict(),
                'discriminator_scheduler_state_dict':
                self.discriminator_scheduler.state_dict(),
                'train_losses':
                self.train_losses,
                'test_losses':
                self.test_losses,
                'valid_losses':
                self.valid_losses
            }, savepath)

    def load(self):
        """
        Loads the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        checkpoint = torch.load(self.loadpath, map_location=self.device)
        self.epoch = checkpoint['epoch']
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        # self.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        self.generator_optimizer.load_state_dict(
            checkpoint['generator_optimizer_state_dict'])
        # self.discriminator_optimizer.load_state_dict(checkpoint['discriminator_optimizer_state_dict'])
        self.generator_scheduler.load_state_dict(
            checkpoint['generator_scheduler_state_dict'])
        self.discriminator_scheduler.load_state_dict(
            checkpoint['discriminator_scheduler_state_dict'])
        self.train_losses = checkpoint['train_losses']
        self.test_losses = checkpoint['test_losses']
        self.valid_losses = checkpoint['valid_losses']

    def evaluate_metrics(self, n_batches):
        """
        Evaluates the quality of the reconstruction with the SNR and LSD metrics on a specified number of batches
        :param n_batches: number of batches to process
        :return: mean and std for each metric
        """
        with torch.no_grad():
            snrs = []
            lsds = []
            generator = self.generator.eval()
            for k in range(n_batches):
                # Transfer to GPU
                local_batch = next(self.test_loader_iter)
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)

                # Generates a batch
                generated_batch = generator(input_batch)

                # Get the metrics
                snrs.append(
                    snr(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))
                lsds.append(
                    lsd(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))

            snrs = torch.cat(snrs).cpu().numpy()
            lsds = torch.cat(lsds).cpu().numpy()

            # Some signals corresponding to silence will be all zeroes and cause troubles due to the logarithm
            snrs[np.isinf(snrs)] = np.nan
            lsds[np.isinf(lsds)] = np.nan
        return np.nanmean(snrs), np.nanstd(snrs), np.nanmean(lsds), np.nanstd(
            lsds)
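
For reference, the penalty term in compute_gradient_penalty above can be exercised end to end with a toy critic; this is a self-contained illustration of the same computation, not part of the trainer:

import torch
from torch import autograd, nn

critic = nn.Linear(16, 1)
real, fake = torch.randn(8, 16), torch.randn(8, 16)
eps = torch.rand(8, 1)
interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
out = critic(interp)
grads = autograd.grad(outputs=out, inputs=interp,
                      grad_outputs=torch.ones_like(out),
                      create_graph=True, retain_graph=True,
                      only_inputs=True)[0]
penalty = ((grads.view(8, -1).norm(2, dim=1) - 1) ** 2).mean()
print(penalty.item())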
Пример #19
0
        # f = torch.nn.DataParallel(f).cuda()
        f = f.cuda()

    criterion_gan = nn.MSELoss()
    criterion_l1 = nn.L1Loss()
    criterion_l2 = nn.MSELoss()
    criterion_crossentropy = nn.CrossEntropyLoss()

    if use_gpu:
        criterion_gan.to(device)
        criterion_l1.to(device)
        criterion_l2.to(device)
        criterion_crossentropy.to(device)

    g_optimizer = optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(D.parameters(), lr=args.d_lr, betas=(0.5, 0.999))

    start_timestamp = int(time.time() * 1000)
    start_epoch = 0
    best_accuracy = 0
    best_loss = 1e100
    global_step = 0

    if args.pre_trained:
        print("Loading a pretrained model ")
        # checkpoint = torch.load(os.path.join(args.checkpoint, 'speechcommand/last-speech-commands-checkpoint_adv_10_classes.pth'))
        # f.load_state_dict(checkpoint['state_dict'])

        checkpoint = torch.load(
            os.path.join(args.checkpoint, 'speechcommand/sampleCNN_49.pth'))
        f.load_state_dict(checkpoint)
Пример #20
0
def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Config
    batch_size = 9
    image_size = 256
    learning_rate = 1e-3
    beta1, beta2 = (.5, .99)
    weight_decay = 1e-3
    epochs = 10

    # Models
    netD = Discriminator().to(device)
    netG = Generator().to(device)

    optimizerD = AdamW(netD.parameters(),
                       lr=learning_rate,
                       betas=(beta1, beta2),
                       weight_decay=weight_decay)
    optimizerG = AdamW(netG.parameters(),
                       lr=learning_rate,
                       betas=(beta1, beta2),
                       weight_decay=weight_decay)

    # Labels
    cartoon_labels = torch.ones(batch_size, 1, image_size // 4,
                                image_size // 4).to(device)
    fake_labels = torch.zeros(batch_size, 1, image_size // 4,
                              image_size // 4).to(device)

    # Loss functions
    content_loss = ContentLoss(device)
    adv_loss = AdversialLoss(cartoon_labels, fake_labels)
    BCE_loss = nn.BCELoss().to(device)

    # Dataloaders
    real_dataloader = get_dataloader("./datasets/real_images",
                                     size=image_size,
                                     bs=batch_size)
    cartoon_dataloader = get_dataloader("./datasets/cartoon_images",
                                        size=image_size,
                                        bs=batch_size)
    edge_dataloader = get_dataloader("./datasets/cartoon_images_smooth",
                                     size=image_size,
                                     bs=batch_size)

    # --------------------------------------------------------------------------------------------- #
    # Training Loop

    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0

    tracked_images = next(iter(real_dataloader))[0].to(device)

    print("Starting Training Loop...")
    # For each epoch.
    for epoch in range(epochs):
        # For each batch in the dataloader.
        for i, ((cartoon_data, _), (edge_data, _),
                (real_data, _)) in enumerate(
                    zip(cartoon_dataloader, edge_dataloader, real_dataloader)):

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################

            # Reset Discriminator gradient.
            netD.zero_grad()

            # Format batch.
            cartoon_data = cartoon_data.to(device)
            edge_data = edge_data.to(device)
            real_data = real_data.to(device)

            # Generate image
            generated_data = netG(real_data)

            # Forward pass all batches through D.
            cartoon_pred = netD(cartoon_data)  #.view(-1)
            edge_pred = netD(edge_data)  #.view(-1)
            generated_pred = netD(generated_data.detach())  #.view(-1)

            print(generated_data.is_cuda, real_data.is_cuda)

            # Calculate discriminator loss on all batches.
            errD = adv_loss(cartoon_pred, generated_pred, edge_pred)

            # Calculate gradients for D in backward pass
            errD.backward()
            D_x = cartoon_pred.mean().item()  # Should be close to 1
            D_G_z1 = generated_pred.mean().item()  # D's score on fakes before the G update

            # Update D
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################

            # Reset Generator gradient.
            netG.zero_grad()

            # Since we just updated D, perform another forward pass of all-fake batch through D
            generated_pred = netD(generated_data)  #.view(-1)

            # Calculate G's loss based on this output
            print(generated_data.is_cuda, real_data.is_cuda)
            print("generated_pred:", generated_pred.is_cuda, "cartoon_labels:",
                  cartoon_labels.is_cuda)
            errG = BCE_loss(generated_pred, cartoon_labels) + content_loss(
                generated_data, real_data)

            # Calculate gradients for G
            errG.backward()

            D_G_z2 = generated_pred.mean().item()  # Should be close to 1

            # Update G
            optimizerG.step()

            # ---------------------------------------------------------------------------------------- #

            # Output training stats
            if i % 50 == 0:
                print(
                    '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                    % (epoch, epochs, i, len(real_dataloader), errD.item(),
                       errG.item(), D_x, D_G_z1, D_G_z2))

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Check how the generator is doing by saving G's output on tracked_images
            if (iters % 500 == 0) or ((epoch == epochs - 1) and
                                      (i == len(real_dataloader) - 1)):
                with torch.no_grad():
                    fake = netG(tracked_images).detach().cpu()
                img_list.append(
                    vutils.make_grid(fake, padding=2, normalize=True))

            iters += 1
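
AdversialLoss and ContentLoss are not defined in this excerpt. A plausible minimal version of the adversarial part, consistent with how it is constructed and called above (real cartoons scored as 1, generated and edge-smoothed images as 0); the BCE-with-logits formulation is an assumption, not the original implementation:

import torch.nn as nn

class AdversialLoss(nn.Module):
    """Sketch of the discriminator-side CartoonGAN-style loss used above (assumed, not original)."""

    def __init__(self, cartoon_labels, fake_labels):
        super().__init__()
        self.cartoon_labels = cartoon_labels
        self.fake_labels = fake_labels
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, cartoon_pred, generated_pred, edge_pred):
        return (self.bce(cartoon_pred, self.cartoon_labels)
                + self.bce(generated_pred, self.fake_labels)
                + self.bce(edge_pred, self.fake_labels))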
Пример #21
0
save_freq = config['log']['save_freq']
save_everything = config['log']['save_everything']
path = config['train']['folder']

device = config['train']['device']
batch_size = config['train']['batch_size']
num_workers = config['train']['num_workers']
resolution = config['train']['resolution']
n_epochs = config['train']['epochs']

lr = config['optimizer']['lr']
beta1 = config['optimizer']['beta_1']
beta2 = config['optimizer']['beta_2']
amsgrad = config['optimizer']['amsgrad']

trainloader = DataSampler.build(path, batch_size, num_workers, resolution)

generator = Generator(resolution).to(device)
optim_g = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))


discriminator = Discriminator(resolution).to(device)
optim_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

trainer(generator, discriminator, optim_g, optim_d, trainloader, n_epochs,
        device, log_interval, logging_dir, save_freq, checkpoint_dir,
        resolution, num_samples, save_everything)
Пример #22
0
                                           'stage1_test/loc.csv'),
                              transform=T.ToTensor(),
                              remove_alpha=True)
"""
-----------------
----- Model -----
-----------------
"""

generator = UNet(3, 1)
discriminator = Discriminator(4, 1)
generator.cuda()
discriminator.cuda()
# lr = 0.001 seems to work WITHOUT PRETRAINING
g_optim = optim.Adam(generator.parameters(), lr=0.001)
d_optim = optim.Adam(discriminator.parameters(), lr=0.001)
#g_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(g_optim, factor=0.1, verbose=True, patience=5)
#d_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(d_optim, factor=0.1, verbose=True, patience=5)

gan = GAN(
    g=generator,
    d=discriminator,
    g_optim=g_optim,
    d_optim=d_optim,
    g_loss=nn.MSELoss().cuda(),
    d_loss=nn.MSELoss().cuda(),
    #g_scheduler=g_scheduler, d_scheduler=d_scheduler
)

# pretraining generator
"""
Пример #23
0
def train():
    torch.manual_seed(1337)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Config
    batch_size = 32
    image_size = 256
    learning_rate = 1e-4
    beta1, beta2 = (.5, .99)
    weight_decay = 1e-4
    epochs = 1000

    # Models
    netD = Discriminator().to(device)
    netG = Generator().to(device)
    # Here you should load the pretrained G
    netG.load_state_dict(torch.load("./checkpoints/pretrained_netG.pth").state_dict())

    optimizerD = AdamW(netD.parameters(), lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay)
    optimizerG = AdamW(netG.parameters(), lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay)

    scaler = torch.cuda.amp.GradScaler()

    # Labels
    cartoon_labels = torch.ones (batch_size, 1, image_size // 4, image_size // 4).to(device)
    fake_labels    = torch.zeros(batch_size, 1, image_size // 4, image_size // 4).to(device)

    # Loss functions
    content_loss = ContentLoss().to(device)
    adv_loss     = AdversialLoss(cartoon_labels, fake_labels).to(device)
    BCE_loss     = nn.BCEWithLogitsLoss().to(device)

    # Dataloaders
    real_dataloader    = get_dataloader("./datasets/real_images/flickr30k_images/",           size = image_size, bs = batch_size)
    cartoon_dataloader = get_dataloader("./datasets/cartoon_images_smoothed/Studio Ghibli",   size = image_size, bs = batch_size, trfs=get_pair_transforms(image_size))

    # --------------------------------------------------------------------------------------------- #
    # Training Loop

    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0

    tracked_images = next(iter(real_dataloader)).to(device)

    print("Starting Training Loop...")
    # For each epoch.
    for epoch in range(epochs):
        print("training epoch ", epoch)
        # For each batch in the dataloader.
        for i, (cartoon_edge_data, real_data) in enumerate(zip(cartoon_dataloader, real_dataloader)):

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            
            # Reset Discriminator gradient.
            netD.zero_grad()
            for param in netD.parameters():
                param.requires_grad = True

            # Format batch.
            cartoon_data   = cartoon_edge_data[:, :, :, :image_size].to(device)
            edge_data      = cartoon_edge_data[:, :, :, image_size:].to(device)
            real_data      = real_data.to(device)

            with torch.cuda.amp.autocast():
                # Generate image
                generated_data = netG(real_data)

                # Forward pass all batches through D.
                cartoon_pred   = netD(cartoon_data)      #.view(-1)
                edge_pred      = netD(edge_data)         #.view(-1)
                generated_pred = netD(generated_data.detach())    #.view(-1)

                # Calculate discriminator loss on all batches.
                errD = adv_loss(cartoon_pred, generated_pred, edge_pred)
            
            # Calculate gradients for D in backward pass
            scaler.scale(errD).backward()
            D_x = cartoon_pred.mean().item() # Should be close to 1

            # Update D
            scaler.step(optimizerD)


            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            
            # Reset Generator gradient.
            netG.zero_grad()
            for param in netD.parameters():
                param.requires_grad = False

            with torch.cuda.amp.autocast():
                # Since we just updated D, perform another forward pass of all-fake batch through D
                generated_pred = netD(generated_data) #.view(-1)

                # Calculate G's loss based on this output
                errG = BCE_loss(generated_pred, cartoon_labels) + content_loss(generated_data, real_data)

            # Calculate gradients for G
            scaler.scale(errG).backward()

            D_G_z2 = generated_pred.mean().item() # Should be close to 1
            
            # Update G
            scaler.step(optimizerG)

            scaler.update()
            
            # ---------------------------------------------------------------------------------------- #

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Check how the generator is doing by saving G's output on tracked_images
            if iters % 200 == 0:
                with torch.no_grad():
                    fake = netG(tracked_images)
                vutils.save_image(unnormalize(fake), f"images/{epoch}_{i}.png", padding=2)
                with open("images/log.txt", "a+") as f:
                    f.write(f"{datetime.now().isoformat(' ', 'seconds')}\tD: {np.mean(D_losses)}\tG: {np.mean(G_losses)}\n")
                D_losses = []
                G_losses = []

            if iters % 1000 == 0:
                torch.save(netG.state_dict(), f"checkpoints/netG_e{epoch}_i{iters}_l{errG.item()}.pth")
                torch.save(netD.state_dict(), f"checkpoints/netD_e{epoch}_i{iters}_l{errG.item()}.pth")

            iters += 1
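
The loop above follows the usual torch.cuda.amp recipe for two optimizers: each loss is scaled and back-propagated separately, scaler.step() is called once per optimizer, and scaler.update() runs once per iteration. A stripped-down sketch of that pattern with toy linear models (illustration only, requires a CUDA device):

import torch
import torch.nn as nn

device = torch.device("cuda")           # GradScaler is meant for CUDA training
netG = nn.Linear(8, 8).to(device)       # toy stand-ins for the real G and D
netD = nn.Linear(8, 1).to(device)
optG = torch.optim.AdamW(netG.parameters(), lr=1e-4)
optD = torch.optim.AdamW(netD.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 8, device=device)

# Discriminator step: scale the loss, backprop, step through the scaler.
netD.zero_grad()
with torch.cuda.amp.autocast():
    d_loss = netD(netG(x).detach()).mean()
scaler.scale(d_loss).backward()
scaler.step(optD)

# Generator step: same pattern with the generator's loss.
netG.zero_grad()
with torch.cuda.amp.autocast():
    g_loss = -netD(netG(x)).mean()
scaler.scale(g_loss).backward()
scaler.step(optG)

scaler.update()  # exactly one update per iteration, after all step() calls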
Пример #24
0
class Solver(object):
    def __init__(self, data_loader, config):

        self.data_loader = data_loader

        self.noise_n = config.noise_n
        self.G_last_act = last_act(config.G_last_act)
        self.D_out_n = config.D_out_n
        self.D_last_act = last_act(config.D_last_act)

        self.G_lr = config.G_lr
        self.D_lr = config.D_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.epoch = config.epoch
        self.batch_size = config.batch_size
        self.D_train_step = config.D_train_step
        self.save_image_step = config.save_image_step
        self.log_step = config.log_step
        self.model_save_step = config.model_save_step
        self.clip_value = config.clip_value
        self.lambda_gp = config.lambda_gp

        self.model_save_path = config.model_save_path
        self.log_save_path = config.log_save_path
        self.image_save_path = config.image_save_path

        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model
        self.build_model()

        if self.use_tensorboard is not None:
            self.build_tensorboard()
        if self.pretrained_model is not None:
            if len(self.pretrained_model) != 2:
                raise "must have both G and D pretrained parameters, and G is first, D is second"
            self.load_pretrained_model()

    def build_model(self):
        self.G = Generator(self.noise_n, self.G_last_act)
        self.D = Discriminator(self.D_out_n, self.D_last_act)

        self.G_optimizer = torch.optim.Adam(self.G.parameters(), self.G_lr,
                                            [self.beta1, self.beta2])
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), self.D_lr,
                                            [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def build_tensorboard(self):
        from commons.logger import Logger
        self.logger = Logger(self.log_save_path)

    def load_pretrained_model(self):
        self.G.load_state_dict(torch.load(self.pretrained_model[0]))
        self.D.load_state_dict(torch.load(self.pretrained_model[1]))

    def denorm(self, x):
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def reset_grad(self):
        self.G_optimizer.zero_grad()
        self.D_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def train(self):
        print(len(self.data_loader))
        for e in range(self.epoch):
            for i, batch_images in enumerate(self.data_loader):
                batch_size = batch_images.size(0)
                label = torch.FloatTensor(batch_size)
                real_x = self.to_var(batch_images)
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(batch_size, self.noise_n)))
                # train D
                fake_x = self.G(noise_x)
                real_out = self.D(real_x)
                fake_out = self.D(fake_x.detach())

                D_real = -torch.mean(real_out)
                D_fake = torch.mean(fake_out)
                D_loss = D_real + D_fake

                self.reset_grad()
                D_loss.backward()
                self.D_optimizer.step()
                # Log
                loss = {}
                loss['D/loss_real'] = D_real.item()
                loss['D/loss_fake'] = D_fake.item()
                loss['D/loss'] = D_loss.item()

                # choose one of the two options below
                # Option 1: clip weights of D (weight clipping WGAN)
                # for p in self.D.parameters():
                #     p.data.clamp_(-self.clip_value, self.clip_value)
                # Option 2: gradient penalty, WGAN-GP
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real_x)
                # print(alpha.shape, real_x.shape, fake_x.shape)
                interpolated = Variable(alpha * real_x.data +
                                        (1 - alpha) * fake_x.data,
                                        requires_grad=True)
                gp_out = self.D(interpolated)
                grad = torch.autograd.grad(outputs=gp_out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               gp_out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)
                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.D_optimizer.step()
                # Train G
                if (i + 1) % self.D_train_step == 0:
                    fake_out = self.D(self.G(noise_x))
                    G_loss = -torch.mean(fake_out)
                    self.reset_grad()
                    G_loss.backward()
                    self.G_optimizer.step()
                    loss['G/loss'] = G_loss.item()
                # Print log
                if (i + 1) % self.log_step == 0:
                    log = "Epoch: {}/{}, Iter: {}/{}".format(
                        e + 1, self.epoch, i + 1, len(self.data_loader))
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)
                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(
                            tag, value,
                            e * len(self.data_loader) + i + 1)
            # Save images
            if (e + 1) % self.save_image_step == 0:
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(16, self.noise_n)))
                fake_image = self.G(noise_x)
                save_image(
                    self.denorm(fake_image.data),
                    os.path.join(self.image_save_path,
                                 "{}_fake.png".format(e + 1)))
            if (e + 1) % self.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_G.pth".format(e + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_D.pth".format(e + 1)))
Пример #25
0
class PGAgent():
    def __init__(self):
        
        self.env = gym.make('CartPole-v1')
        self.expert_traj_fpath = os.path.join('GAIL', 'expert_trajectories', 'expert_traj_test.npz')
        self.save_policy_fpath = os.path.join('GAIL','gail_actor1.pt')
        self.save_rewards_fig = os.path.join('GAIL', 'gail_rewards.png')
        #
        self.state_space = 4
        self.action_space = 2

        # These values are taken from the env's state and action sizes
        self.actor_net = ActorNet(self.state_space, self.action_space)
        self.actor_net.to(device=Device)
        self.actor_optim = torch.optim.Adam(self.actor_net.parameters(), lr = 0.0001)

        self.critic_net = CriticNet(self.state_space)
        self.critic_net.to(Device)
        self.critic_net_optim = torch.optim.Adam(self.critic_net.parameters(), lr = 0.0001)

        self.discriminator = Discriminator(self.state_space, self.action_space)
        self.discriminator.to(Device)
        self.discriminator_optim = torch.optim.Adam(self.discriminator.parameters(), lr = 0.0001)

        # Storing all the values used to calculate the model losses
        self.traj_obs = []
        self.traj_actions = []
        self.traj_rewards = []
        self.traj_dones = []
        self.traj_logprobs = []
        self.traj_logits = []
        self.traj_state_values = []

        # Discount factor
        self.gamma = 0.95
        # Bias Variance tradeoff (higher value results in high variance, low bias)
        self.gae_lambda = 0.95

        # These two will be used during the training of the policy
        self.ppo_batch_size = 500
        self.ppo_epochs = 12
        self.ppo_eps = 0.2

        # Discriminator
        self.num_expert_transitions = 150

        # These will be used for the agent to play using the current policy
        self.max_eps = 5
        # Max steps in mountaincar ex is 200
        self.max_steps = 800

        # documenting the stats
        self.avg_over = 5 # episodes
        self.stats = {'episode': 0, 'ep_rew': []}

    def clear_lists(self):
        self.traj_obs = []
        self.traj_actions = []
        self.traj_rewards = []
        self.traj_dones = []
        self.traj_logprobs = []
        self.traj_logits = []
        self.traj_state_values = []

    # This returns a categorical torch object, which makes it easier to calculate log_prob, prob and sampling from them
    def get_logits(self, state):
        logits = self.actor_net(state)
        return Categorical(logits=logits)

    def calc_policy_loss(self, states, actions, rewards):
        assert (torch.is_tensor(states) and torch.is_tensor(actions) and torch.is_tensor(rewards)),\
             "states and actions are not in the right format"

        # The negative sign is for gradient ascent
        loss = -(self.get_logits(states).log_prob(actions))*rewards
        return loss.mean()

    def ppo_calc_log_prob(self, states, actions):
        obs_tensor = torch.as_tensor(states).float().to(device=Device)
        actions = torch.as_tensor(actions).float().to(device=Device)
        logits = self.get_logits(obs_tensor)
        entropy = logits.entropy()
        log_prob = logits.log_prob(actions)

        return log_prob, entropy

    def get_action(self, state):
        
        # Finding the logits and state value using the actor and critic net
        logits = self.get_logits(state)
        action = logits.sample()

        # Categorical.sample() builds the probability distribution from the logits and samples an action from it
        return action.item()


    # NOTE: the signature of this helper is missing from the snippet; the name and
    # arguments below are a plausible reconstruction, not the original code.
    def calc_reward_to_go(self, traj_rewards, traj_dones):
        # This gives the reward to go for each transition in the batch
        rew_to_go_list = []
        rew_sum = 0
        for rew, done in zip(reversed(traj_rewards), reversed(traj_dones)):
            if done:
                rew_sum = rew
                rew_to_go_list.append(rew_sum)
            else:
                rew_sum = rew + rew_sum
                rew_to_go_list.append(rew_sum)

        rew_to_go_list = reversed(rew_to_go_list)
        return list(rew_to_go_list)

    # This returns the concatenated state_action tensor used to input to discriminator
    # obs and actions are a list
    def concat_state_action(self, obs_list, actions_list, shuffle=False):
        obs = np.array(obs_list)
        actions_data = np.array(actions_list)
        actions = np.zeros((len(actions_list), self.action_space))
        actions[np.arange(len(actions_list)), actions_data] = 1  # Converting to one hot encoding

        state_action = np.concatenate((obs, actions), axis=1)
        if shuffle:
            np.random.shuffle(state_action)  # Shuffling to break any correlations

        state_action = torch.as_tensor(state_action).float().to(Device)

        return state_action

    # This uses the discriminator and critic networks to calculate the advantage 
    # of each state action pair, and the targets for the critic network
    def calc_gae_targets(self):

        obs_tensor = torch.as_tensor(self.traj_obs).float().to(Device)
        action_tensor = torch.as_tensor(self.traj_actions).float().to(Device)
        state_action = self.concat_state_action(self.traj_obs, self.traj_actions)

        # This calculates how well we have fooled the discriminator, which is equivalent to the 
        # reward at each time step
        disc_rewards = -torch.log(self.discriminator(state_action))
        disc_rewards = disc_rewards.view(-1).tolist()

        traj_state_values = self.critic_net(obs_tensor).view(-1).tolist()
        gae = []
        targets = []

        for i, val, next_val, reward, done in \
        zip(range(len(self.traj_dones)), \
        reversed(traj_state_values), \
        reversed(traj_state_values[1:] + [None]), \
        reversed(disc_rewards), \
        reversed(self.traj_dones)):

            # The last trajectory may be cut short because of the max_steps limit,
            # so the final done flag is not always True.
            if done or i==0: 
                delta = reward - val
                last_gae = delta
            else:
                delta = reward + self.gamma*next_val - val
                last_gae = delta + self.gamma*self.gae_lambda*last_gae
            
            gae.append(last_gae)
            targets.append(last_gae + val)

        return list(reversed(gae)), list(reversed(targets))

    # We use num_transitions samples from expert trajectory, 
    # and all policy trajectory transitions to train discriminator.
    def update_discriminator(self, num_transitions):

        # data stores the expert trajectories used to train the discriminator
        # data is a dictionary with keys: 'obs', 'acts', 'rews', 'dones'
        data = np.load(self.expert_traj_fpath)
        # TODO: Using data this way holds a lot of memory. It would be much more efficient
        # to load the numpy array through a generator.
        obs = data['obs']
        actions = data['acts']

        # Zeroing Discriminator gradient
        self.discriminator_optim.zero_grad()
        loss = nn.BCELoss()

        ## Sampling from expert trajectories. 
        random_samples_ind = np.random.choice(len(obs), num_transitions) 
        expert_state_action = self.concat_state_action(obs[random_samples_ind], actions[random_samples_ind])

        ## Expert Loss, target for expert trajectory taken 0
        expert_output = self.discriminator(expert_state_action)
        expert_traj_loss = loss(expert_output, torch.zeros((num_transitions, 1), device=Device))
        
        ## Sampling policy trajectories
        policy_state_action = self.concat_state_action(self.traj_obs, self.traj_actions, shuffle=True)

        ## Policy Traj loss, target for policy trajectory taken 1
        policy_traj_output = self.discriminator(policy_state_action)
        policy_traj_loss = loss(policy_traj_output, torch.ones((policy_traj_output.shape[0], 1), device=Device))

        # Updating the Discriminator
        D_loss = expert_traj_loss + policy_traj_loss
        D_loss.backward()
        self.discriminator_optim.step()

    def train_gail(self):

        assert (len(self.traj_obs)==len(self.traj_actions)==len(self.traj_dones)), "Size of traj lists don't match"

        # We use self.num_expert_transitions and all saved policy transitions from play()
        self.update_discriminator(self.num_expert_transitions)
        
        # Finding old log prob.
        # If the number of transitions is too large, this could also be broken down and calculated in batches
        # This uses the traj_states and traj_actions to calculate the log_probs of each action
        with torch.no_grad():
            old_logprob, _ = self.ppo_calc_log_prob(self.traj_obs, self.traj_actions)
            old_logprob.detach()

            traj_gae, traj_targets = self.calc_gae_targets()

        # Performing ppo policy updates in batches
        for epoch in range(self.ppo_epochs):
            for batch_offs in range(0, len(self.traj_dones), self.ppo_batch_size):
                batch_obs = self.traj_obs[batch_offs:batch_offs + self.ppo_batch_size]
                batch_actions = self.traj_actions[batch_offs:batch_offs + self.ppo_batch_size]
                batch_gae = traj_gae[batch_offs:batch_offs + self.ppo_batch_size]
                batch_targets = traj_targets[batch_offs:batch_offs + self.ppo_batch_size]
                batch_old_logprob = old_logprob[batch_offs:batch_offs + self.ppo_batch_size]

                # Zero the gradients
                self.actor_optim.zero_grad()
                self.critic_net_optim.zero_grad()

                # Critic Loss
                batch_obs_tensor = torch.as_tensor(batch_obs).float().to(Device)
                state_vals = self.critic_net(batch_obs_tensor).view(-1)
                batch_targets = torch.as_tensor(batch_targets).float().to(Device)
                critic_loss = F.mse_loss(state_vals, batch_targets)
                
                # Policy and Entropy Loss
                log_prob, entropy = self.ppo_calc_log_prob(batch_obs, batch_actions)
                batch_ratio = torch.exp(log_prob - batch_old_logprob)
                batch_gae = torch.as_tensor(batch_gae).float().to(Device)
                unclipped_objective = batch_ratio * batch_gae
                clipped_objective = torch.clamp(batch_ratio, 1 - self.ppo_eps, 1 + self.ppo_eps) * batch_gae
                policy_loss = -torch.min(clipped_objective, unclipped_objective).mean()
                entropy_loss = -entropy.mean()

                # Performing backprop
                critic_loss.backward()
                # Here both policy_loss and entropy_loss calculate grad values in the actor net.
                # By using retain_graph, the next backward call will add onto the previous grad values.
                policy_loss.backward(retain_graph=True)
                entropy_loss.backward()

                # print('Losses: ', (critic_loss.shape, policy_loss.shape, entropy_loss.shape))
                # print('critic grad values: ', self.critic_net.fc1.weight.grad)
                # print('actor grad values: ', self.actor_net.fc1.weight.grad)

                # Updating the networks
                self.actor_optim.step()
                self.critic_net_optim.step()

    # The agent will play self.max_eps episodes using the current policy, and train on that data
    def play(self, rendering):
        
        self.clear_lists()
        saved_transitions = 0
        for ep in range(self.max_eps):
            obs = self.env.reset()
            ep_reward = 0

            for step in range(self.max_steps):
                
                if rendering==True:
                    self.env.render()

                self.traj_obs.append(obs)
                obs = torch.from_numpy(obs).float().to(device=Device)
                
                # get_action() will run obs through actor network and find the action to take
                action = self.get_action(obs)
                
                # We are saving the reward here, but this will not be used in the optimization of the policy
                # or discriminator, it is only used to track our progress.
                obs, rew, done, info = self.env.step(action)
                ep_reward += rew

                self.traj_actions.append(action)
                self.traj_rewards.append(rew)

                saved_transitions += 1

                if done:
                    # We will not save the last observation, since it is essentially a dead state
                    # This keeps the obs, action, reward and dones lists the same length
                    self.traj_dones.append(done)
                    self.stats['ep_rew'].append(ep_reward)
                    self.stats['episode'] += 1
                    break

                else:
                    self.traj_dones.append(done)
            # print(f" {ep} episodes over.", end='\r')
            print('episode over. Reward: ', ep_reward)

            
        self.train_gail()


    def run(self, model_path, policy_iterations = 65, show_renders_every = 20, renders = True):
        for i in range(policy_iterations):
            if i%show_renders_every==0:
                self.play(rendering=renders)
            else:
                self.play(rendering=False)
            print(f" Policy updated {i} times")
        
        torch.save(self.actor_net.state_dict(), model_path)
        print('model saved at: ', model_path)

    def plot_rewards(self, avg_over=10):       
        graph_x = np.arange(self.stats['episode'])
        graph_y = np.array(self.stats['ep_rew'])

        assert (len(graph_x) == len(graph_y)), "Plot axes do not match"

        graph_x_averaged = [mean(arr) for arr in np.array_split(graph_x, len(graph_x)/avg_over)]
        graph_y_averaged = [mean(arr) for arr in np.array_split(graph_y, len(graph_y)/avg_over)]

        plt.plot(graph_x_averaged, graph_y_averaged)
        plt.savefig(self.save_rewards_fig)
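
The policy update in train_gail() is the PPO clipped surrogate objective. A minimal standalone sketch of that loss, mirroring the batch computation above (function name is ours):

import torch

def ppo_clipped_loss(log_prob, old_log_prob, advantage, eps=0.2):
    # Ratio of new to old policy probabilities for the taken actions.
    ratio = torch.exp(log_prob - old_log_prob)
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantage
    # Negative sign: we minimise the loss, i.e. maximise the clipped surrogate.
    return -torch.min(unclipped, clipped).mean()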
Пример #26
0
classifier = Classifier()
critic = Discriminator(input_dims=params.d_input_dims,
                                     hidden_dims=params.d_hidden_dims,
                                     output_dims=params.d_output_dims)
generator = Generator()

criterion = nn.CrossEntropyLoss()

# special for target
generator_larger = Generator_Larger()

optimizer_c = optim.Adam(
    classifier.parameters(), lr=params.learning_rate, betas=(params.beta1, params.beta2)
)
optimizer_d = optim.Adam(
    critic.parameters(), lr=params.learning_rate, betas=(params.beta1, params.beta2)
)
optimizer_g = optim.Adam(
    generator.parameters(), lr=params.learning_rate, betas=(params.beta1, params.beta2)
)

optimizer_g_l = optim.Adam(
    generator_larger.parameters(), lr=params.learning_rate, betas=(params.beta1, params.beta2)
)
data_itr_src = get_data_iter("MNIST", train=True)
data_itr_tgt = get_data_iter("USPS", train=True)

pos_labels = Variable(torch.Tensor([1]))
neg_lables = Variable(torch.Tensor([-1]))
g_step = 0
Пример #27
0
        if i % config.make_img_samples == 0:
            for x in range(5):
                make_img_samples(G)


if __name__ == '__main__':
    dataset = CelebADataset()

    dataloader = get_dataloader(dataset)

    G = Generator(config.latent_size).to(config.device)
    D = Discriminator().to(config.device)

    optim_G = torch.optim.AdamW(G.parameters(),
                                lr=config.lr,
                                betas=(0.5, 0.999))
    optim_D = torch.optim.AdamW(D.parameters(),
                                lr=config.lr,
                                betas=(0.5, 0.999))

    if (config.continue_training):
        G, optim_G, D, optim_D = load_models_with_optims(
            G, optim_G, D, optim_D, config.train_model_path, config.device)
    else:
        G.apply(weights_init)
        D.apply(weights_init)

    criterion = nn.CrossEntropyLoss()

    train(dataloader, D, G, optim_D, optim_G, criterion)
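
weights_init is applied to G and D but is not defined in this snippet; the commonly used DCGAN-style initializer below is an assumption, the project's actual implementation may differ:

import torch.nn as nn

def weights_init(m):
    # DCGAN-style init: N(0, 0.02) for conv weights, N(1, 0.02) for batch-norm scales.
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)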
Пример #28
0
class VaeGanModule(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        # Encoder
        self.encoder = Encoder(ngf=self.hparams.ngf, z_dim=self.hparams.z_dim)
        self.encoder.apply(weights_init)
        device = "cuda" if isinstance(self.hparams.gpus, int) else "cpu"
        # Decoder
        self.decoder = Decoder(ngf=self.hparams.ngf, z_dim=self.hparams.z_dim)
        self.decoder.apply(weights_init)
        # Discriminator
        self.discriminator = Discriminator()
        self.discriminator.apply(weights_init)

        # Losses
        self.criterionFeat = torch.nn.L1Loss()
        self.criterionGAN = GANLoss(gan_mode="lsgan")

        if self.hparams.use_vgg:
            self.criterion_perceptual_style = [Perceptual_Loss(device)]

    @staticmethod
    def reparameterize(mu, logvar, mode='train'):
        if mode == 'train':
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        else:
            return mu

    def discriminate(self, fake_image, real_image):
        input_concat_fake = torch.cat(
            (fake_image.detach(), real_image),
            dim=1)  # not sure whether .detach() is strictly necessary in Lightning
        input_concat_real = torch.cat((real_image, real_image), dim=1)

        return (self.discriminator(input_concat_fake),
                self.discriminator(input_concat_real))

    def training_step(self, batch, batch_idx, optimizer_idx):
        x, _ = batch

        # train VAE
        if optimizer_idx == 0:

            # encode
            mu, log_var = self.encoder(x)
            z_repar = VaeGanModule.reparameterize(mu, log_var)

            # decode
            fake_image = self.decoder(z_repar)

            # reconstruction
            reconstruction_loss = self.criterionFeat(fake_image, x)
            kld_loss = -0.5 * torch.mean(1 + log_var - mu.pow(2) -
                                         log_var.exp())

            # Discriminate
            input_concat_fake = torch.cat((fake_image, x), dim=1)
            pred_fake = self.discriminator(input_concat_fake)

            # Losses
            loss_G_GAN = self.criterionGAN(pred_fake, True)
            if self.hparams.use_vgg:
                loss_G_perceptual = \
                    self.criterion_perceptual_style[0](fake_image, x)
            else:
                loss_G_perceptual = 0.0
            g_loss = (reconstruction_loss *
                      20) + kld_loss + loss_G_GAN + loss_G_perceptual

            # Results are collected in a TrainResult object
            result = pl.TrainResult(g_loss)
            result.log("rec_loss", reconstruction_loss * 10, prog_bar=True)
            result.log("loss_G_GAN", loss_G_GAN, prog_bar=True)
            result.log("kld_loss", kld_loss, prog_bar=True)
            result.log("loss_G_perceptual", loss_G_perceptual, prog_bar=True)

        # train Discriminator
        if optimizer_idx == 1:
            # Measure discriminator's ability to classify real from generated samples

            # Encode
            mu, log_var = self.encoder(x)
            z_repar = VaeGanModule.reparameterize(mu, log_var)

            # Decode
            fake_image = self.decoder(z_repar)

            # how well can it label as real?
            pred_fake, pred_real = self.discriminate(fake_image, x)

            # Fake loss
            d_loss_fake = self.criterionGAN(pred_fake, False)

            # Real Loss
            d_loss_real = self.criterionGAN(pred_real, True)

            # Total loss is average of prediction of fakes and reals
            loss_D = (d_loss_fake + d_loss_real) / 2

            # Results are collected in a TrainResult object
            result = pl.TrainResult(loss_D)
            result.log("loss_D_real", d_loss_real, prog_bar=True)
            result.log("loss_D_fake", d_loss_fake, prog_bar=True)

        return result

    def training_epoch_end(self, training_step_outputs):
        z_appr = torch.normal(mean=0,
                              std=1,
                              size=(16, self.hparams.z_dim),
                              device=training_step_outputs[0].minimize.device)

        # Generate images from latent vector
        sample_imgs = self.decoder(z_appr)
        grid = torchvision.utils.make_grid(sample_imgs,
                                           normalize=True,
                                           range=(-1, 1))

        # where to save the image
        path = os.path.join(self.hparams.generated_images_folder,
                            f"generated_images_{self.current_epoch}.png")
        torchvision.utils.save_image(sample_imgs,
                                     path,
                                     normalize=True,
                                     range=(-1, 1))

        # Log images in tensorboard
        self.logger.experiment.add_image(f'generated_images', grid,
                                         self.current_epoch)

        # Epoch level metrics
        epoch_loss = torch.mean(
            torch.stack([x['minimize'] for x in training_step_outputs]))
        results = pl.TrainResult()
        results.log("epoch_loss", epoch_loss, prog_bar=False)

        return results

    def validation_step(self, batch, batch_idx):
        x, _ = batch

        # Encode
        mu, log_var = self.encoder(x)
        z_repar = VaeGanModule.reparameterize(mu, log_var)

        # Decode
        recons = self.decoder(z_repar)
        reconstruction_loss = nn.functional.mse_loss(recons, x)

        # Results are collected in a EvalResult object
        result = pl.EvalResult(checkpoint_on=reconstruction_loss)
        return result

    test_step = validation_step  # Lightning looks for the hook name test_step

    def configure_optimizers(self):
        params_vae = concat_generators(self.encoder.parameters(),
                                       self.decoder.parameters())
        opt_vae = torch.optim.Adam(params_vae,
                                   lr=self.hparams.learning_rate_vae)

        parameters_discriminator = self.discriminator.parameters()
        opt_d = torch.optim.Adam(parameters_discriminator,
                                 lr=self.hparams.learning_rate_d)

        return [opt_vae, opt_d]

    @staticmethod
    def add_argparse_args(parser):

        parser.add_argument('--generated_images_folder',
                            required=False,
                            default="./output",
                            type=str)
        parser.add_argument('--ngf', type=int, default=128)
        parser.add_argument('--z_dim', type=int, default=128)
        parser.add_argument('--learning_rate_vae',
                            default=1e-03,
                            required=False,
                            type=float)
        parser.add_argument('--learning_rate_d',
                            default=1e-03,
                            required=False,
                            type=float)
        parser.add_argument("--use_vgg", action="store_true", default=False)

        return parser
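
The VAE half of the module rests on the reparameterization trick and a Gaussian KL term; a minimal sketch of both, mirroring reparameterize() and the kld_loss computed in training_step():

import torch

def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I); keeps the sampling step differentiable.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

def kld(mu, logvar):
    # KL divergence between N(mu, sigma^2) and the standard normal prior.
    return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())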
Пример #29
0
class Solver(object):
    def __init__(self, data_loader, config):

        self.data_loader = data_loader

        self.noise_n = config.noise_n
        self.G_last_act = last_act(config.G_last_act)
        self.D_out_n = config.D_out_n
        self.D_last_act = last_act(config.D_last_act)

        self.G_lr = config.G_lr
        self.D_lr = config.D_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.epoch = config.epoch
        self.batch_size = config.batch_size
        self.D_train_step = config.D_train_step
        self.save_image_step = config.save_image_step
        self.log_step = config.log_step
        self.model_save_step = config.model_save_step

        self.model_save_path = config.model_save_path
        self.log_save_path = config.log_save_path
        self.image_save_path = config.image_save_path

        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model
        self.build_model()

        if self.use_tensorboard is not None:
            self.build_tensorboard()
        if self.pretrained_model is not None:
            if len(self.pretrained_model) != 2:
                raise "must have both G and D pretrained parameters, and G is first, D is second"
            self.load_pretrained_model()

    def build_model(self):
        self.G = Generator(self.noise_n, self.G_last_act)
        self.D = Discriminator(self.D_out_n, self.D_last_act)

        self.G_optimizer = torch.optim.Adam(self.G.parameters(), self.G_lr,
                                            [self.beta1, self.beta2])
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), self.D_lr,
                                            [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def build_tensorboard(self):
        from commons.logger import Logger
        self.logger = Logger(self.log_save_path)

    def load_pretrained_model(self):
        self.G.load_state_dict(torch.load(self.pretrained_model[0]))
        self.D.load_state_dict(torch.load(self.pretrained_model[1]))

    def reset_grad(self):
        self.G_optimizer.zero_grad()
        self.D_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def train(self):
        bce_loss = nn.BCELoss()

        print(len(self.data_loader))
        for e in range(self.epoch):
            for i, batch_images in enumerate(self.data_loader):
                batch_size = batch_images.size(0)
                real_x = self.to_var(batch_images)
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(batch_size, self.noise_n)))
                real_label = self.to_var(
                    torch.FloatTensor(batch_size).fill_(1.))
                fake_label = self.to_var(
                    torch.FloatTensor(batch_size).fill_(0.))
                # train D
                fake_x = self.G(noise_x)
                real_out = self.D(real_x)
                fake_out = self.D(fake_x.detach())

                D_real = bce_loss(real_out, real_label)
                D_fake = bce_loss(fake_out, fake_label)
                D_loss = D_real + D_fake

                self.reset_grad()
                D_loss.backward()
                self.D_optimizer.step()
                # Log
                loss = {}
                loss['D/loss_real'] = D_real.item()
                loss['D/loss_fake'] = D_fake.item()
                loss['D/loss'] = D_loss.item()

                # Train G
                if (i + 1) % self.D_train_step == 0:
                    # noise_x = self.to_var(torch.FloatTensor(noise_vector(batch_size, self.noise_n)))
                    fake_out = self.D(self.G(noise_x))
                    G_loss = bce_loss(fake_out, real_label)
                    self.reset_grad()
                    G_loss.backward()
                    self.G_optimizer.step()
                    loss['G/loss'] = G_loss.item()
                # Print log
                if (i + 1) % self.log_step == 0:
                    log = "Epoch: {}/{}, Iter: {}/{}".format(
                        e + 1, self.epoch, i + 1, len(self.data_loader))
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)
                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(
                            tag, value,
                            e * len(self.data_loader) + i + 1)
            # Save images
            if (e + 1) % self.save_image_step == 0:
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(32, self.noise_n)))
                fake_image = self.G(noise_x)
                save_image(
                    fake_image.data,
                    os.path.join(self.image_save_path,
                                 "{}_fake.png".format(e + 1)))
            if (e + 1) % self.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_G.pth".format(e + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_D.pth".format(e + 1)))
Пример #30
0
    netD = nn.DataParallel(netD, list(range(ngpu)))

netG.apply(weight_init)
netD.apply(weight_init)

print(netG)
print(netD)

criterion = nn.BCELoss()

fixed_noise = torch.randn(64, z, 1, 1, device=device)

real_label = 1
fake_label = 0

optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=(beta1, 0.999))

D_loss = []
G_loss = []
image_list = []
iters = 0
# TODO : Remove the training epochs and replace with training steps

# Training Loop
for epoch_idx in range(num_epochs):
    for i, data in enumerate(mnist_loader, 0):
        ##############################
        ###  Discriminator training ###
        ##############################