Example #1
def get_modules(opt):
    modules = {}
    disc = Discriminator()
    gen = Generator()
    clf = Classifier()
    if opt.cuda:
        disc = disc.cuda()
        gen = gen.cuda()
        clf = clf.cuda()

    modules['Discriminator'] = disc
    modules['Generator'] = gen
    modules['Classifier'] = clf
    return modules
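A minimal usage sketch for this helper, assuming the three model classes are importable and that `opt` is an argparse namespace exposing the `cuda` flag used above (everything beyond that flag is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--cuda', action='store_true', help='move the modules to the GPU')
opt = parser.parse_args()

modules = get_modules(opt)
disc, gen, clf = modules['Discriminator'], modules['Generator'], modules['Classifier']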
Example #2
class Solver(object):
    def __init__(self, data_loader, config):

        self.data_loader = data_loader

        self.noise_n = config.noise_n
        self.G_last_act = last_act(config.G_last_act)
        self.D_out_n = config.D_out_n
        self.D_last_act = last_act(config.D_last_act)

        self.G_lr = config.G_lr
        self.D_lr = config.D_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.epoch = config.epoch
        self.batch_size = config.batch_size
        self.D_train_step = config.D_train_step
        self.save_image_step = config.save_image_step
        self.log_step = config.log_step
        self.model_save_step = config.model_save_step

        self.model_save_path = config.model_save_path
        self.log_save_path = config.log_save_path
        self.image_save_path = config.image_save_path

        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model
        self.build_model()

        if self.use_tensorboard is not None:
            self.build_tensorboard()
        if self.pretrained_model is not None:
            if len(self.pretrained_model) != 2:
                raise "must have both G and D pretrained parameters, and G is first, D is second"
            self.load_pretrained_model()

    def build_model(self):
        self.G = Generator(self.noise_n, self.G_last_act)
        self.D = Discriminator(self.D_out_n, self.D_last_act)

        self.G_optimizer = torch.optim.Adam(self.G.parameters(), self.G_lr,
                                            [self.beta1, self.beta2])
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), self.D_lr,
                                            [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def build_tensorboard(self):
        from commons.logger import Logger
        self.logger = Logger(self.log_save_path)

    def load_pretrained_model(self):
        self.G.load_state_dict(torch.load(self.pretrained_model[0]))
        self.D.load_state_dict(torch.load(self.pretrained_model[1]))

    def reset_grad(self):
        self.G_optimizer.zero_grad()
        self.D_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def train(self):
        bce_loss = nn.BCELoss()

        print(len(self.data_loader))
        for e in range(self.epoch):
            for i, batch_images in enumerate(self.data_loader):
                batch_size = batch_images.size(0)
                real_x = self.to_var(batch_images)
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(batch_size, self.noise_n)))
                real_label = self.to_var(
                    torch.FloatTensor(batch_size).fill_(1.))
                fake_label = self.to_var(
                    torch.FloatTensor(batch_size).fill_(0.))
                # train D
                fake_x = self.G(noise_x)
                real_out = self.D(real_x)
                fake_out = self.D(fake_x.detach())

                D_real = bce_loss(real_out, real_label)
                D_fake = bce_loss(fake_out, fake_label)
                D_loss = D_real + D_fake

                self.reset_grad()
                D_loss.backward()
                self.D_optimizer.step()
                # Log
                loss = {}
                loss['D/loss_real'] = D_real.data[0]
                loss['D/loss_fake'] = D_fake.data[0]
                loss['D/loss'] = D_loss.data[0]

                # Train G
                if (i + 1) % self.D_train_step == 0:
                    # noise_x = self.to_var(torch.FloatTensor(noise_vector(batch_size, self.noise_n)))
                    fake_out = self.D(self.G(noise_x))
                    G_loss = bce_loss(fake_out, real_label)
                    self.reset_grad()
                    G_loss.backward()
                    self.G_optimizer.step()
                    loss['G/loss'] = G_loss.data[0]
                # Print log
                if (i + 1) % self.log_step == 0:
                    log = "Epoch: {}/{}, Iter: {}/{}".format(
                        e + 1, self.epoch, i + 1, len(self.data_loader))
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)
                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(
                            tag, value,
                            e * len(self.data_loader) + i + 1)
            # Save images
            if (e + 1) % self.save_image_step == 0:
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(32, self.noise_n)))
                fake_image = self.G(noise_x)
                save_image(
                    fake_image.data,
                    os.path.join(self.image_save_path,
                                 "{}_fake.png".format(e + 1)))
            if (e + 1) % self.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_G.pth".format(e + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_D.pth".format(e + 1)))
Example #3
class ModuleTrain:
    def __init__(self, opt, best_loss=0.2):
        self.opt = opt
        self.best_loss = best_loss  # the model is only saved when the mean loss drops below this value

        self.netd = Discriminator(self.opt)
        self.netg = Generator(self.opt)
        self.use_gpu = False

        # Load pretrained weights if they exist
        if os.path.exists(self.opt.netd_path):
            self.load_netd(self.opt.netd_path)
        else:
            print('[Load model] error: %s does not exist !!!' % self.opt.netd_path)
        if os.path.exists(self.opt.netg_path):
            self.load_netg(self.opt.netg_path)
        else:
            print('[Load model] error: %s does not exist !!!' % self.opt.netg_path)

        # Initialise the DataLoader
        self.transform_train = T.Compose([
            T.Resize((self.opt.img_size, self.opt.img_size)),
            T.ToTensor(),
            T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
        ])
        train_dataset = ImageFolder(root=self.opt.data_path,
                                    transform=self.transform_train)
        self.train_loader = DataLoader(dataset=train_dataset,
                                       batch_size=self.opt.batch_size,
                                       shuffle=True,
                                       num_workers=self.opt.num_workers,
                                       drop_last=True)

        # Optimisers and loss function
        # self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.5)
        self.optimizer_g = optim.Adam(self.netg.parameters(),
                                      lr=self.opt.lr1,
                                      betas=(self.opt.beta1, 0.999))
        self.optimizer_d = optim.Adam(self.netd.parameters(),
                                      lr=self.opt.lr2,
                                      betas=(self.opt.beta1, 0.999))
        self.criterion = torch.nn.BCELoss()

        self.true_labels = Variable(torch.ones(self.opt.batch_size))
        self.fake_labels = Variable(torch.zeros(self.opt.batch_size))
        self.fix_noises = Variable(
            torch.randn(self.opt.batch_size, self.opt.nz, 1, 1))
        self.noises = Variable(
            torch.randn(self.opt.batch_size, self.opt.nz, 1, 1))

        # gpu or cpu
        if self.opt.use_gpu and torch.cuda.is_available():
            self.use_gpu = True
        else:
            self.use_gpu = False
        if self.use_gpu:
            print('[use gpu] ...')
            self.netd.cuda()
            self.netg.cuda()
            self.criterion.cuda()
            self.true_labels = self.true_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            self.fix_noises = self.fix_noises.cuda()
            self.noises = self.noises.cuda()
        else:
            print('[use cpu] ...')

        pass

    def train(self, save_best=True):
        print('[train] epoch: %d' % self.opt.max_epoch)
        for epoch_i in range(self.opt.max_epoch):
            loss_netd = 0.0
            loss_netg = 0.0
            correct = 0

            print('================================================')
            for ii, (img, target) in enumerate(self.train_loader):  # training loop
                real_img = Variable(img)
                if self.opt.use_gpu:
                    real_img = real_img.cuda()

                # Train the discriminator
                if (ii + 1) % self.opt.d_every == 0:
                    self.optimizer_d.zero_grad()
                    # push real images towards a prediction of 1
                    output = self.netd(real_img)
                    error_d_real = self.criterion(output, self.true_labels)
                    error_d_real.backward()

                    # push fake images towards a prediction of 0
                    self.noises.data.copy_(
                        torch.randn(self.opt.batch_size, self.opt.nz, 1, 1))
                    fake_img = self.netg(self.noises).detach()  # generate fake images from the noise
                    fake_output = self.netd(fake_img)
                    error_d_fake = self.criterion(fake_output,
                                                  self.fake_labels)
                    error_d_fake.backward()
                    self.optimizer_d.step()

                    loss_netd += (error_d_real.item() + error_d_fake.item())

                # Train the generator
                if (ii + 1) % self.opt.g_every == 0:
                    self.optimizer_g.zero_grad()
                    self.noises.data.copy_(
                        torch.randn(self.opt.batch_size, self.opt.nz, 1, 1))
                    fake_img = self.netg(self.noises)
                    fake_output = self.netd(fake_img)
                    # try to get the discriminator to classify fake images as 1
                    error_g = self.criterion(fake_output, self.true_labels)
                    error_g.backward()
                    self.optimizer_g.step()

                    loss_netg += error_g.item()

            loss_netd /= (len(self.train_loader) * 2)
            loss_netg /= len(self.train_loader)
            print('[Train] Epoch: {} \tNetD Loss: {:.6f} \tNetG Loss: {:.6f}'.
                  format(epoch_i, loss_netd, loss_netg))
            if save_best is True:
                if (loss_netg + loss_netd) / 2 < self.best_loss:
                    self.best_loss = (loss_netg + loss_netd) / 2
                    self.save(self.netd, self.opt.best_netd_path)  # save the best discriminator
                    self.save(self.netg, self.opt.best_netg_path)  # save the best generator
                    print('[save best] ...')

            # self.vis()

            if (epoch_i + 1) % 5 == 0:
                self.image_gan()

        self.save(self.netd, self.opt.netd_path)  # save the final discriminator
        self.save(self.netg, self.opt.netg_path)  # save the final generator

    def vis(self):
        fix_fake_imgs = self.netg(self.fix_noises)
        visdom.images(fix_fake_imgs.data.cpu().numpy()[:64] * 0.5 + 0.5,
                      win='fixfake')

    def image_gan(self):
        noises = torch.randn(self.opt.gen_search_num, self.opt.nz, 1,
                             1).normal_(self.opt.gen_mean, self.opt.gen_std)
        with torch.no_grad():
            noises = Variable(noises)

        if self.use_gpu:
            noises = noises.cuda()

        fake_img = self.netg(noises)
        scores = self.netd(fake_img).data
        indexs = scores.topk(self.opt.gen_num)[1]
        result = list()
        for ii in indexs:
            result.append(fake_img.data[ii])

        torchvision.utils.save_image(torch.stack(result),
                                     self.opt.gen_img,
                                     normalize=True,
                                     range=(-1, 1))

        #     # print(correct)
        #     # print(len(self.train_loader.dataset))
        #     train_loss /= len(self.train_loader)
        #     acc = float(correct) / float(len(self.train_loader.dataset))
        #     print('[Train] Epoch: {} \tLoss: {:.6f}\tAcc: {:.6f}\tlr: {}'.format(epoch_i, train_loss, acc, self.lr))
        #
        #     test_acc = self.test()
        #     if save_best is True:
        #         if test_acc > self.best_acc:
        #             self.best_acc = test_acc
        #             str_list = self.model_file.split('.')
        #             best_model_file = ""
        #             for str_index in range(len(str_list)):
        #                 best_model_file = best_model_file + str_list[str_index]
        #                 if str_index == (len(str_list) - 2):
        #                     best_model_file += '_best'
        #                 if str_index != (len(str_list) - 1):
        #                     best_model_file += '.'
        #         self.save(best_model_file)                                  # save the best model
        #
        # self.save(self.model_file)

    def test(self):
        test_loss = 0.0
        correct = 0

        time_start = time.time()
        # test set
        for data, target in self.test_loader:
            data, target = Variable(data), Variable(target)

            if self.use_gpu:
                data = data.cuda()
                target = target.cuda()

            output = self.model(data)
            # sum up batch loss
            loss = self.loss(output, target)
            test_loss += loss.item()

            predict = torch.argmax(output, 1)
            correct += (predict == target).sum().item()

        time_end = time.time()
        time_avg = float(time_end - time_start) / float(
            len(self.test_loader.dataset))
        test_loss /= len(self.test_loader)
        acc = float(correct) / float(len(self.test_loader.dataset))

        print('[Test] set: Test loss: {:.6f}\t Acc: {:.6f}\t time: {:.6f} \n'.
              format(test_loss, acc, time_avg))
        return acc

    def load_netd(self, name):
        print('[Load model netd] %s ...' % name)
        self.netd.load_state_dict(torch.load(name))

    def load_netg(self, name):
        print('[Load model netg] %s ...' % name)
        self.netg.load_state_dict(torch.load(name))

    def save(self, model, name):
        print('[Save model] %s ...' % name)
        torch.save(model.state_dict(), name)
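`ModuleTrain` expects a single option object carrying the fields used above. A hypothetical configuration, with field names taken from the code and placeholder values:

from argparse import Namespace

opt = Namespace(
    netd_path='checkpoints/netd.pth', netg_path='checkpoints/netg.pth',
    best_netd_path='checkpoints/netd_best.pth', best_netg_path='checkpoints/netg_best.pth',
    data_path='data/images', img_size=96, batch_size=64, num_workers=4,
    lr1=2e-4, lr2=2e-4, beta1=0.5, nz=100, use_gpu=True,
    max_epoch=200, d_every=1, g_every=5,
    gen_search_num=512, gen_num=64, gen_mean=0, gen_std=1, gen_img='result.png')

trainer = ModuleTrain(opt)
trainer.train()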
Example #4
class Train(object):
    """
    Main GAN trainer. Responsible for training the GAN and pre-training the generator autoencoder.
    """
    def __init__(self, config):
        """
        Construct a new GAN trainer
        :param Config config: The parsed network configuration.
        """
        self.config = config

        LOG.info("CUDA version: {0}".format(version.cuda))
        LOG.info("Creating data loader from path {0}".format(config.FILENAME))

        self.data_loader = Data(
            config.FILENAME,
            config.BATCH_SIZE,
            polarisations=config.POLARISATIONS,  # Polarisations to use
            frequencies=config.FREQUENCIES,  # Frequencies to use
            max_inputs=config.MAX_SAMPLES,  # Max inputs per polarisation and frequency
            normalise=config.NORMALISE)  # Normalise inputs

        shape = self.data_loader.get_input_shape()
        width = shape[1]
        LOG.info("Creating models with input shape {0}".format(shape))
        self._autoencoder = Autoencoder(width)
        self._discriminator = Discriminator(width)
        # TODO: Get correct input and output widths for generator
        self._generator = Generator(width, width)

        if config.USE_CUDA:
            LOG.info("Using CUDA")
            self.autoencoder = self._autoencoder.cuda()
            self.discriminator = self._discriminator.cuda()
            self.generator = self._generator.cuda()
        else:
            LOG.info("Using CPU")
            self.autoencoder = self._autoencoder
            self.discriminator = self._discriminator
            self.generator = self._generator

    def check_requeue(self, epochs_complete):
        """
        Check and re-queue the training script if it has completed the desired number of training epochs per session
        :param int epochs_complete: Number of epochs completed
        :return: True if the script has been requeued, False if not
        :rtype: bool
        """
        if self.config.REQUEUE_EPOCHS > 0:
            if epochs_complete >= self.config.REQUEUE_EPOCHS:
                # We've completed enough epochs for this instance. We need to kill it and requeue
                LOG.info(
                    "REQUEUE_EPOCHS of {0} met, calling REQUEUE_SCRIPT".format(
                        self.config.REQUEUE_EPOCHS))
                subprocess.call(self.config.REQUEUE_SCRIPT,
                                shell=True,
                                cwd=os.path.dirname(
                                    self.config.REQUEUE_SCRIPT))
                return True  # Requeue performed
        return False  # No requeue needed

    def load_state(self, checkpoint, module, optimiser=None):
        """
        Load the provided checkpoint into the provided module and optimiser.
        This function checks whether the load threw an exception and logs it to the user.
        :param Checkpoint checkpoint: The checkpoint to load
        :param module: The pytorch module to load the checkpoint into.
        :param optimiser: The pytorch optimiser to load the checkpoint into.
        :return: None if the load failed, int number of epochs in the checkpoint if load succeeded
        """
        try:
            module.load_state_dict(checkpoint.module_state)
            if optimiser is not None:
                optimiser.load_state_dict(checkpoint.optimiser_state)
            return checkpoint.epoch
        except RuntimeError as e:
            LOG.exception(
                "Error loading module state. This is most likely an input size mismatch. Please delete the old module saved state, or change the input size"
            )
            return None

    def close(self):
        """
        Close the data loader used by the trainer.
        """
        self.data_loader.close()

    def generate_labels(self, num_samples, pattern):
        """
        Generate labels for the discriminator.
        :param int num_samples: Number of input samples to generate labels for.
        :param list pattern: Pattern to generate. Should be either [1, 0] or [0, 1]
        :return: New labels for the discriminator
        """
        var = torch.FloatTensor([pattern] * num_samples)
        return var.cuda() if self.config.USE_CUDA else var

    def _train_autoencoder(self):
        """
        Main training loop for the autoencoder.
        This function will return False if:
        - Loading the autoencoder succeeded, but the NN model did not load the state dicts correctly.
        - The script needs to be re-queued because the NN has been trained for REQUEUE_EPOCHS
        :return: True if training was completed, False if training needs to continue.
        :rtype: bool
        """

        criterion = nn.SmoothL1Loss()

        optimiser = optim.Adam(self.autoencoder.parameters(),
                               lr=0.00003,
                               betas=(0.5, 0.999))
        checkpoint = Checkpoint("autoencoder")
        epoch = 0
        if checkpoint.load():
            epoch = self.load_state(checkpoint, self.autoencoder, optimiser)
            if epoch is not None and epoch >= self.config.MAX_AUTOENCODER_EPOCHS:
                LOG.info("Autoencoder already trained")
                return True
            else:
                LOG.info(
                    "Autoencoder training beginning from epoch {0}".format(
                        epoch))
        else:
            LOG.info('Autoencoder checkpoint not found. Training from start')

        # Train autoencoder
        self._autoencoder.set_mode(Autoencoder.Mode.AUTOENCODER)

        vis_path = os.path.join(
            os.path.splitext(self.config.FILENAME)[0], "autoencoder",
            str(datetime.now()))
        with Visualiser(vis_path) as vis:
            epochs_complete = 0
            while epoch < self.config.MAX_AUTOENCODER_EPOCHS:

                if self.check_requeue(epochs_complete):
                    return False  # Requeue needed and training not complete

                for step, (data, _, _) in enumerate(self.data_loader):
                    if self.config.USE_CUDA:
                        data = data.cuda()

                    if self.config.ADD_DROPOUT:
                        # Drop out parts of the input, but compute loss on the full input.
                        out = self.autoencoder(nn.functional.dropout(
                            data, 0.5))
                    else:
                        out = self.autoencoder(data)

                    loss = criterion(out.cpu(), data.cpu())
                    self.autoencoder.zero_grad()
                    loss.backward()
                    optimiser.step()

                    vis.step_autoencoder(loss.item())

                    # Report data and save checkpoint
                    fmt = "Epoch [{0}/{1}], Step[{2}/{3}], loss: {4:.4f}"
                    LOG.info(
                        fmt.format(epoch + 1,
                                   self.config.MAX_AUTOENCODER_EPOCHS, step,
                                   len(self.data_loader), loss))

                epoch += 1
                epochs_complete += 1

                checkpoint.set(self.autoencoder.state_dict(),
                               optimiser.state_dict(), epoch).save()

                LOG.info("Plotting autoencoder progress")
                vis.plot_training(epoch)
                data, _, _ = iter(self.data_loader).__next__()
                data = data.cuda() if self.config.USE_CUDA else data
                vis.test_autoencoder(epoch, self.autoencoder, data)

        LOG.info("Autoencoder training complete")
        return True  # Training complete

    def _train_gan(self):
        """
        TODO: Add in autoencoder to perform dimensionality reduction on data
        TODO: Not working yet - trying to work out good autoencoder model first
        :return:
        """

        criterion = nn.BCELoss()

        discriminator_optimiser = optim.Adam(self.discriminator.parameters(),
                                             lr=0.003,
                                             betas=(0.5, 0.999))
        discriminator_scheduler = optim.lr_scheduler.LambdaLR(
            discriminator_optimiser, lambda epoch: 0.97**epoch)
        discriminator_checkpoint = Checkpoint("discriminator")
        discriminator_epoch = 0
        if discriminator_checkpoint.load():
            discriminator_epoch = self.load_state(discriminator_checkpoint,
                                                  self.discriminator,
                                                  discriminator_optimiser)
        else:
            LOG.info('Discriminator checkpoint not found')

        generator_optimiser = optim.Adam(self.generator.parameters(),
                                         lr=0.003,
                                         betas=(0.5, 0.999))
        generator_scheduler = optim.lr_scheduler.LambdaLR(
            generator_optimiser, lambda epoch: 0.97**epoch)
        generator_checkpoint = Checkpoint("generator")
        generator_epoch = 0
        if generator_checkpoint.load():
            generator_epoch = self.load_state(generator_checkpoint,
                                              self.generator,
                                              generator_optimiser)
        else:
            LOG.info('Generator checkpoint not found')

        if discriminator_epoch is None or generator_epoch is None:
            epoch = 0
            LOG.info(
                "Discriminator or generator failed to load, training from start"
            )
        else:
            epoch = min(generator_epoch, discriminator_epoch)
            LOG.info("Generator loaded at epoch {0}".format(generator_epoch))
            LOG.info("Discriminator loaded at epoch {0}".format(
                discriminator_epoch))
            LOG.info("Training from lowest epoch {0}".format(epoch))

        vis_path = os.path.join(
            os.path.splitext(self.config.FILENAME)[0], "gan",
            str(datetime.now()))
        with Visualiser(vis_path) as vis:
            real_labels = None  # all 1s
            fake_labels = None  # all 0s
            epochs_complete = 0
            while epoch < self.config.MAX_EPOCHS:

                if self.check_requeue(epochs_complete):
                    return  # Requeue needed and training not complete

                for step, (data, noise1,
                           noise2) in enumerate(self.data_loader):
                    batch_size = data.size(0)
                    if real_labels is None or real_labels.size(
                            0) != batch_size:
                        real_labels = self.generate_labels(batch_size, [1.0])
                    if fake_labels is None or fake_labels.size(
                            0) != batch_size:
                        fake_labels = self.generate_labels(batch_size, [0.0])

                    if self.config.USE_CUDA:
                        data = data.cuda()
                        noise1 = noise1.cuda()
                        noise2 = noise2.cuda()

                    # ============= Train the discriminator =============
                    # Pass real noise through first - ideally the discriminator will return 1 #[1, 0]
                    d_output_real = self.discriminator(data)
                    # Pass generated noise through - ideally the discriminator will return 0 #[0, 1]
                    d_output_fake1 = self.discriminator(self.generator(noise1))

                    # Determine the loss of the discriminator by adding up the real and fake loss and backpropagate
                    d_loss_real = criterion(
                        d_output_real, real_labels
                    )  # How good the discriminator is on real input
                    d_loss_fake = criterion(
                        d_output_fake1, fake_labels
                    )  # How good the discriminator is on fake input
                    d_loss = d_loss_real + d_loss_fake
                    self.discriminator.zero_grad()
                    d_loss.backward()
                    discriminator_optimiser.step()

                    # =============== Train the generator ===============
                    # Pass in fake noise to the generator and get it to generate "real" noise
                    # Judge how good this noise is with the discriminator
                    d_output_fake2 = self.discriminator(self.generator(noise2))

                    # Determine the loss of the generator using the discriminator and backpropagate
                    g_loss = criterion(d_output_fake2, real_labels)
                    self.discriminator.zero_grad()
                    self.generator.zero_grad()
                    g_loss.backward()
                    generator_optimiser.step()

                    vis.step(d_loss_real.item(), d_loss_fake.item(),
                             g_loss.item())

                    # Report data and save checkpoint
                    fmt = "Epoch [{0}/{1}], Step[{2}/{3}], d_loss_real: {4:.4f}, d_loss_fake: {5:.4f}, g_loss: {6:.4f}"
                    LOG.info(
                        fmt.format(epoch + 1, self.config.MAX_EPOCHS, step + 1,
                                   len(self.data_loader), d_loss_real,
                                   d_loss_fake, g_loss))

                epoch += 1
                epochs_complete += 1

                discriminator_checkpoint.set(
                    self.discriminator.state_dict(),
                    discriminator_optimiser.state_dict(), epoch).save()
                generator_checkpoint.set(self.generator.state_dict(),
                                         generator_optimiser.state_dict(),
                                         epoch).save()
                vis.plot_training(epoch)

                data, noise1, _ = iter(self.data_loader).__next__()
                if self.config.USE_CUDA:
                    data = data.cuda()
                    noise1 = noise1.cuda()
                vis.test(epoch, self.data_loader.get_input_size_first(),
                         self.discriminator, self.generator, noise1, data)

                generator_scheduler.step(epoch)
                discriminator_scheduler.step(epoch)

                LOG.info("Learning rates: d {0} g {1}".format(
                    discriminator_optimiser.param_groups[0]["lr"],
                    generator_optimiser.param_groups[0]["lr"]))

        LOG.info("GAN Training complete")

    def __call__(self):
        """
        Main training loop for the GAN.
        The training process is interruptable; the model and optimiser states are saved to disk each epoch, and the
        latest states are restored when the trainer is resumed.

        If the script is not able to load the generator's saved state, it will attempt to load the pre-trained generator
        autoencoder from the generator_decoder_complete checkpoint (if it exists). If this also fails, the generator is
        pre-trained as an autoencoder. This training is also interruptable, and will produce the
        generator_decoder_complete checkpoint on completion.

        On successfully restoring generator and discriminator state, the trainer will proceed from the earliest restored
        epoch. For example, if the generator is restored from epoch 7 and the discriminator is restored from epoch 5,
        training will proceed from epoch 5.

        Visualisation plots are produced each epoch and stored in
        /path_to_input_file_directory/{gan/generator_auto_encoder}/{timestamp}/{epoch}

        Each time the trainer is run, it creates a new timestamp directory using the current time.
        """

        # Load the autoencoder, and train it if needed.
        if not self._train_autoencoder():
            # Autoencoder training incomplete
            return
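The trainer above depends on a `Checkpoint` helper that is not included in the example. Judging only from how it is called (`load()`, `set(...).save()`, and the `module_state`, `optimiser_state` and `epoch` attributes), it could look roughly like this; the file layout and constructor arguments are assumptions:

import os
import torch

class Checkpoint:
    def __init__(self, name, directory="checkpoints"):
        self.directory = directory
        self.path = os.path.join(directory, name + ".pt")
        self.module_state = None
        self.optimiser_state = None
        self.epoch = 0

    def set(self, module_state, optimiser_state, epoch):
        self.module_state = module_state
        self.optimiser_state = optimiser_state
        self.epoch = epoch
        return self

    def save(self):
        os.makedirs(self.directory, exist_ok=True)
        torch.save({"module": self.module_state,
                    "optimiser": self.optimiser_state,
                    "epoch": self.epoch}, self.path)
        return self

    def load(self):
        if not os.path.isfile(self.path):
            return False
        state = torch.load(self.path)
        self.module_state = state["module"]
        self.optimiser_state = state["optimiser"]
        self.epoch = state["epoch"]
        return True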
Example #5
    def decode(self, z):
        h3 = self.decoder(z)
        return h3

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        # z = z.view(args.batch_size, z_dim, 1, 1)
        z = z.view(args.batch_size, z_dim)
        return self.decode(z), mu, logvar

model = VAE()
discrim = Discriminator()
if use_cuda:
    model.cuda()
    discrim.cuda()
optimizer = optim.Adam(model.parameters(), lr = args.lr)
discrim_optimizer = optim.Adam(discrim.parameters(), lr = args.discrim_lr)

# Reconstruction + KL divergence losses summed over all elements and batch
A, B, C = 224, 224, 3
image_size = A * B * C
def loss_function(recon_x, x, mu, logvar):
    # BCE = F.binary_cross_entropy(recon_x.view(-1, image_size), x.view(-1, image_size), size_average=False)
    # BCE = F.binary_cross_entropy(recon_x, x)
    ## define the GAN loss here
    label = Variable(torch.ones(args.batch_size).type(Tensor))
    ## TODO: see if its a variable
    discrim_output = discrim(recon_x)
    BCE = F.binary_cross_entropy(discrim_output, label)
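The `loss_function` above is cut off. For orientation only, a typical VAE-GAN objective combines the adversarial term sketched there with the usual KL divergence; a minimal version under that assumption (the relative weighting of the terms is arbitrary here):

import torch
import torch.nn.functional as F

def vae_gan_loss(recon_x, x, mu, logvar):
    # adversarial term: the discriminator should rate reconstructions as real
    label = torch.ones(recon_x.size(0), device=recon_x.device)
    adv = F.binary_cross_entropy(discrim(recon_x).view(-1), label)
    # optional pixel-wise reconstruction term (commented out in the original)
    # rec = F.binary_cross_entropy(recon_x.view(-1, image_size), x.view(-1, image_size), reduction='sum')
    # KL divergence between the approximate posterior and a unit Gaussian
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return adv + kld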
Example #6
data_itr_src = get_data_iter("MNIST", train=True)
data_itr_tgt = get_data_iter("USPS", train=True)

pos_labels = Variable(torch.Tensor([1]))
neg_labels = Variable(torch.Tensor([-1]))
g_step = 0

g_loss_durations = []
d_loss_durations = []
c_loss_durations = []

# move variables to CUDA
if use_cuda:
    generator.cuda()
    generator_larger.cuda()
    critic.cuda()
    classifier.cuda()
    pos_labels = pos_labels.cuda()
    neg_labels = neg_labels.cuda()

# main training loop
for epoch in range(params.num_epochs):
    # Train the critic (discriminator)
    # enable gradients for the critic's parameters
    for p in critic.parameters():
        p.requires_grad = True
    # set the number of critic training steps
    if g_step < 25 or g_step % 500 == 0:
        # this helps to start with the critic at optimum
        # even in the first iterations.
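The excerpt ends mid-comment. In the standard WGAN recipe this condition switches to a much larger number of critic updates early in training; a sketch of the conventional schedule (assumed, not shown in the excerpt):

if g_step < 25 or g_step % 500 == 0:
    critic_iters = 100  # train the critic close to optimality early on
else:
    critic_iters = 5    # the usual 5 critic steps per generator step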
Example #7
class Solver(object):
    def __init__(self, data_loader, config):

        self.data_loader = data_loader

        self.noise_n = config.noise_n
        self.G_last_act = last_act(config.G_last_act)
        self.D_out_n = config.D_out_n
        self.D_last_act = last_act(config.D_last_act)

        self.G_lr = config.G_lr
        self.D_lr = config.D_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.epoch = config.epoch
        self.batch_size = config.batch_size
        self.D_train_step = config.D_train_step
        self.save_image_step = config.save_image_step
        self.log_step = config.log_step
        self.model_save_step = config.model_save_step
        self.clip_value = config.clip_value
        self.lambda_gp = config.lambda_gp

        self.model_save_path = config.model_save_path
        self.log_save_path = config.log_save_path
        self.image_save_path = config.image_save_path

        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model
        self.build_model()

        if self.use_tensorboard is not None:
            self.build_tensorboard()
        if self.pretrained_model is not None:
            if len(self.pretrained_model) != 2:
                raise "must have both G and D pretrained parameters, and G is first, D is second"
            self.load_pretrained_model()

    def build_model(self):
        self.G = Generator(self.noise_n, self.G_last_act)
        self.D = Discriminator(self.D_out_n, self.D_last_act)

        self.G_optimizer = torch.optim.Adam(self.G.parameters(), self.G_lr,
                                            [self.beta1, self.beta2])
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), self.D_lr,
                                            [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def build_tensorboard(self):
        from commons.logger import Logger
        self.logger = Logger(self.log_save_path)

    def load_pretrained_model(self):
        self.G.load_state_dict(torch.load(self.pretrained_model[0]))
        self.D.load_state_dict(torch.load(self.pretrained_model[1]))

    def denorm(self, x):
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def reset_grad(self):
        self.G_optimizer.zero_grad()
        self.D_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def train(self):
        print(len(self.data_loader))
        for e in range(self.epoch):
            for i, batch_images in enumerate(self.data_loader):
                batch_size = batch_images.size(0)
                label = torch.FloatTensor(batch_size)
                real_x = self.to_var(batch_images)
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(batch_size, self.noise_n)))
                # train D
                fake_x = self.G(noise_x)
                real_out = self.D(real_x)
                fake_out = self.D(fake_x.detach())

                D_real = -torch.mean(real_out)
                D_fake = torch.mean(fake_out)
                D_loss = D_real + D_fake

                self.reset_grad()
                D_loss.backward()
                self.D_optimizer.step()
                # Log
                loss = {}
                loss['D/loss_real'] = D_real.data[0]
                loss['D/loss_fake'] = D_fake.data[0]
                loss['D/loss'] = D_loss.data[0]

                # choose one of the two options below
                # Option 1: clip the weights of D (vanilla WGAN)
                # for p in self.D.parameters():
                #     p.data.clamp_(-self.clip_value, self.clip_value)
                # Option 2: gradient penalty (WGAN-GP)
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real_x)
                # print(alpha.shape, real_x.shape, fake_x.shape)
                interpolated = Variable(alpha * real_x.data +
                                        (1 - alpha) * fake_x.data,
                                        requires_grad=True)
                gp_out = self.D(interpolated)
                grad = torch.autograd.grad(outputs=gp_out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               gp_out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)
                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.D_optimizer.step()
                # Train G
                if (i + 1) % self.D_train_step == 0:
                    fake_out = self.D(self.G(noise_x))
                    G_loss = -torch.mean(fake_out)
                    self.reset_grad()
                    G_loss.backward()
                    self.G_optimizer.step()
                    loss['G/loss'] = G_loss.data[0]
                # Print log
                if (i + 1) % self.log_step == 0:
                    log = "Epoch: {}/{}, Iter: {}/{}".format(
                        e + 1, self.epoch, i + 1, len(self.data_loader))
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)
                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(
                            tag, value,
                            e * len(self.data_loader) + i + 1)
            # Save images
            if (e + 1) % self.save_image_step == 0:
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(16, self.noise_n)))
                fake_image = self.G(noise_x)
                save_image(
                    self.denorm(fake_image.data),
                    os.path.join(self.image_save_path,
                                 "{}_fake.png".format(e + 1)))
            if (e + 1) % self.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_G.pth".format(e + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_D.pth".format(e + 1)))
Example #8
                                         transform=T.ToTensor(),
                                         remove_alpha=True)
test_dataset = TestFromFolder(os.path.join(all_datasets,
                                           'stage1_test/loc.csv'),
                              transform=T.ToTensor(),
                              remove_alpha=True)
"""
-----------------
----- Model -----
-----------------
"""

generator = UNet(3, 1)
discriminator = Discriminator(4, 1)
generator.cuda()
discriminator.cuda()
# lr = 0.001 seems to work WITHOUT PRETRAINING
g_optim = optim.Adam(generator.parameters(), lr=0.001)
d_optim = optim.Adam(discriminator.parameters(), lr=0.001)
#g_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(g_optim, factor=0.1, verbose=True, patience=5)
#d_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(d_optim, factor=0.1, verbose=True, patience=5)

gan = GAN(
    g=generator,
    d=discriminator,
    g_optim=g_optim,
    d_optim=d_optim,
    g_loss=nn.MSELoss().cuda(),
    d_loss=nn.MSELoss().cuda(),
    #g_scheduler=g_scheduler, d_scheduler=d_scheduler
)
Example #9
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=0, help='starting epoch')
    parser.add_argument('--n_epochs',
                        type=int,
                        default=400,
                        help='number of epochs of training')
    parser.add_argument('--batchSize',
                        type=int,
                        default=10,
                        help='size of the batches')
    parser.add_argument('--dataroot',
                        type=str,
                        default='datasets/genderchange/',
                        help='root directory of the dataset')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='initial learning rate')
    parser.add_argument(
        '--decay_epoch',
        type=int,
        default=100,
        help='epoch to start linearly decaying the learning rate to 0')
    parser.add_argument('--size',
                        type=int,
                        default=256,
                        help='size of the data crop (squared assumed)')
    parser.add_argument('--input_nc',
                        type=int,
                        default=3,
                        help='number of channels of input data')
    parser.add_argument('--output_nc',
                        type=int,
                        default=3,
                        help='number of channels of output data')
    parser.add_argument('--cuda',
                        action='store_true',
                        help='use GPU computation')
    parser.add_argument(
        '--n_cpu',
        type=int,
        default=8,
        help='number of cpu threads to use during batch generation')
    opt = parser.parse_args()
    print(opt)

    if torch.cuda.is_available() and not opt.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    ###### Definition of variables ######
    # Networks
    netG_A2B = Generator(opt.input_nc, opt.output_nc)
    netG_B2A = Generator(opt.output_nc, opt.input_nc)
    netD_A = Discriminator(opt.input_nc)
    netD_B = Discriminator(opt.output_nc)

    if opt.cuda:
        netG_A2B.cuda()
        netG_B2A.cuda()
        netD_A.cuda()
        netD_B.cuda()

    netG_A2B.apply(weights_init_normal)
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

    # Losses
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    # Optimizers & LR schedulers
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                                   netG_B2A.parameters()),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    # Inputs & targets memory allocation
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
    target_real = Variable(Tensor(opt.batchSize).fill_(1.0),
                           requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize).fill_(0.0),
                           requires_grad=False)

    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Dataset loader
    transforms_ = [
        transforms.Resize(int(opt.size * 1.2), Image.BICUBIC),
        transforms.CenterCrop(opt.size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]

    dataloader = DataLoader(ImageDataset(opt.dataroot,
                                         transforms_=transforms_,
                                         unaligned=True),
                            batch_size=opt.batchSize,
                            shuffle=True,
                            num_workers=opt.n_cpu,
                            drop_last=True)

    # Plot Loss and Images in Tensorboard
    experiment_dir = 'logs/{}@{}'.format(
        opt.dataroot.split('/')[1],
        datetime.now().strftime("%d.%m.%Y-%H:%M:%S"))
    os.makedirs(experiment_dir, exist_ok=True)
    writer = SummaryWriter(os.path.join(experiment_dir, "tb"))

    metric_dict = defaultdict(list)
    n_iters_total = 0

    ###################################
    ###### Training ######
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(
                same_B, real_B) * 5.0  # [batchSize, 3, ImgSize, ImgSize]

            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(
                same_A, real_A) * 5.0  # [batchSize, 3, ImgSize, ImgSize]

            # GAN loss
            fake_B = netG_A2B(real_A)
            pred_fake = netD_B(fake_B).view(-1)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)  # [batchSize]

            fake_A = netG_B2A(real_B)
            pred_fake = netD_A(fake_A).view(-1)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)  # [batchSize]

            # Cycle loss
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(
                recovered_A, real_A) * 10.0  # [batchSize, 3, ImgSize, ImgSize]

            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(
                recovered_B, real_B) * 10.0  # [batchSize, 3, ImgSize, ImgSize]

            # Total loss
            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB

            loss_G.backward()
            optimizer_G.step()
            ###################################

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = netD_A(real_A).view(-1)
            loss_D_real = criterion_GAN(pred_real, target_real)  # [batchSize]

            # Fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach()).view(-1)
            loss_D_fake = criterion_GAN(pred_fake, target_fake)  # [batchSize]

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()

            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = netD_B(real_B).view(-1)
            loss_D_real = criterion_GAN(pred_real, target_real)  # [batchSize]

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach()).view(-1)
            loss_D_fake = criterion_GAN(pred_fake, target_fake)  # [batchSize]

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()

            optimizer_D_B.step()
            ###################################

            metric_dict['loss_G'].append(loss_G.item())
            metric_dict['loss_G_identity'].append(loss_identity_A.item() +
                                                  loss_identity_B.item())
            metric_dict['loss_G_GAN'].append(loss_GAN_A2B.item() +
                                             loss_GAN_B2A.item())
            metric_dict['loss_G_cycle'].append(loss_cycle_ABA.item() +
                                               loss_cycle_BAB.item())
            metric_dict['loss_D'].append(loss_D_A.item() + loss_D_B.item())

            for title, value in metric_dict.items():
                writer.add_scalar('train/{}'.format(title), value[-1],
                                  n_iters_total)

            n_iters_total += 1

        print("""
        -----------------------------------------------------------
        Epoch : {} Finished
        Loss_G : {}
        Loss_G_identity : {}
        Loss_G_GAN : {}
        Loss_G_cycle : {}
        Loss_D : {}
        -----------------------------------------------------------
        """.format(epoch, loss_G, loss_identity_A + loss_identity_B,
                   loss_GAN_A2B + loss_GAN_B2A,
                   loss_cycle_ABA + loss_cycle_BAB, loss_D_A + loss_D_B))

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        # Save models checkpoints

        if loss_G.item() < 2.5:
            os.makedirs(os.path.join(experiment_dir, str(epoch)),
                        exist_ok=True)
            torch.save(netG_A2B.state_dict(),
                       '{}/{}/netG_A2B.pth'.format(experiment_dir, epoch))
            torch.save(netG_B2A.state_dict(),
                       '{}/{}/netG_B2A.pth'.format(experiment_dir, epoch))
            torch.save(netD_A.state_dict(),
                       '{}/{}/netD_A.pth'.format(experiment_dir, epoch))
            torch.save(netD_B.state_dict(),
                       '{}/{}/netD_B.pth'.format(experiment_dir, epoch))
        elif epoch > 100 and epoch % 40 == 0:
            os.makedirs(os.path.join(experiment_dir, str(epoch)),
                        exist_ok=True)
            torch.save(netG_A2B.state_dict(),
                       '{}/{}/netG_A2B.pth'.format(experiment_dir, epoch))
            torch.save(netG_B2A.state_dict(),
                       '{}/{}/netG_B2A.pth'.format(experiment_dir, epoch))
            torch.save(netD_A.state_dict(),
                       '{}/{}/netD_A.pth'.format(experiment_dir, epoch))
            torch.save(netD_B.state_dict(),
                       '{}/{}/netD_B.pth'.format(experiment_dir, epoch))

        for title, value in metric_dict.items():
            writer.add_scalar("train/{}_epoch".format(title), np.mean(value),
                              epoch)
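The script assumes the usual CycleGAN utilities `LambdaLR`, `weights_init_normal`, `ReplayBuffer` and `ImageDataset`, none of which are shown. Sketches of the first two in their conventional form (an assumption about the helpers, not part of this example):

import torch.nn as nn

class LambdaLR:
    def __init__(self, n_epochs, offset, decay_start_epoch):
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        # constant LR until decay_start_epoch, then linear decay towards 0
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (
            self.n_epochs - self.decay_start_epoch)

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)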
Example #10
def main(args):
    with open(args.params, "r") as f:
        params = json.load(f)

    generator = Generator(params["dim_latent"])
    discriminator = Discriminator()

    if args.device is not None:
        generator = generator.cuda(args.device)
        discriminator = discriminator.cuda(args.device)

    # dataloading
    train_dataset = datasets.MNIST(root=args.datadir,
                                   transform=transforms.ToTensor(),
                                   train=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=params["batch_size"],
                              num_workers=4,
                              shuffle=True)

    # optimizer
    betas = (params["beta_1"], params["beta_2"])
    optimizer_G = optim.Adam(generator.parameters(),
                             lr=params["learning_rate"],
                             betas=betas)
    optimizer_D = optim.Adam(discriminator.parameters(),
                             lr=params["learning_rate"],
                             betas=betas)

    if not os.path.exists(args.modeldir): os.mkdir(args.modeldir)
    if not os.path.exists(args.logdir): os.mkdir(args.logdir)
    writer = SummaryWriter(args.logdir)

    steps_per_epoch = len(train_loader)

    msg = ["\t{0}: {1}".format(key, val) for key, val in params.items()]
    print("hyperparameters: \n" + "\n".join(msg))

    # main training loop
    for n in range(params["num_epochs"]):
        loader = iter(train_loader)

        print("epoch: {0}/{1}".format(n + 1, params["num_epochs"]))
        for i in tqdm.trange(steps_per_epoch):
            batch, _ = next(loader)
            if args.device is not None: batch = batch.cuda(args.device)

            loss_D = update_discriminator(batch, discriminator, generator,
                                          optimizer_D, params)
            loss_G = update_generator(discriminator, generator, optimizer_G,
                                      params, args.device)

            writer.add_scalar("loss_discriminator/train", loss_D,
                              i + n * steps_per_epoch)
            writer.add_scalar("loss_generator/train", loss_G,
                              i + n * steps_per_epoch)

        torch.save(generator.state_dict(),
                   args.o + ".generator." + str(n) + ".tmp")
        torch.save(discriminator.state_dict(),
                   args.o + ".discriminator." + str(n) + ".tmp")

        # eval
        with torch.no_grad():
            latent = torch.randn(args.num_fake_samples_eval,
                                 params["dim_latent"])
            if args.device is not None:
                latent = latent.cuda(args.device)
            imgs_fake = generator(latent)

            writer.add_images("generated fake images", imgs_fake, n)
            del latent, imgs_fake

    writer.close()

    torch.save(generator.state_dict(), args.o + ".generator.pt")
    torch.save(discriminator.state_dict(), args.o + ".discriminator.pt")
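`update_discriminator` and `update_generator` are not shown. A plausible minimal pair for a standard BCE GAN, written purely as an assumption about their contracts (a batch of real images in, scalar losses out; the discriminator is assumed to emit one probability per sample):

import torch
import torch.nn.functional as F

def update_discriminator(batch, discriminator, generator, optimizer_D, params):
    batch_size = batch.size(0)
    real_labels = torch.ones(batch_size, 1, device=batch.device)
    fake_labels = torch.zeros(batch_size, 1, device=batch.device)

    latent = torch.randn(batch_size, params["dim_latent"], device=batch.device)
    fake = generator(latent).detach()

    loss = (F.binary_cross_entropy(discriminator(batch), real_labels) +
            F.binary_cross_entropy(discriminator(fake), fake_labels))
    optimizer_D.zero_grad()
    loss.backward()
    optimizer_D.step()
    return loss.item()

def update_generator(discriminator, generator, optimizer_G, params, device):
    latent = torch.randn(params["batch_size"], params["dim_latent"])
    if device is not None:
        latent = latent.cuda(device)
    real_labels = torch.ones(params["batch_size"], 1, device=latent.device)
    loss = F.binary_cross_entropy(discriminator(generator(latent)), real_labels)
    optimizer_G.zero_grad()
    loss.backward()
    optimizer_G.step()
    return loss.item()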
Example #11
else:
    print('Loading target model from {}'.format(args.model_target_path))
    model_target = torch.load(args.model_target_path)

if not (args.resume and os.path.isfile(args.model_source_path)):
    print('Creating new source model')
    model_source = VAE2()
else:
    print('Loading source model from {}'.format(args.model_source_path))
    model_source = torch.load(args.model_source_path)
discriminator_model = Discriminator(20, 20)

if args.cuda:
    model_target.cuda()
    model_source.cuda()
    discriminator_model.cuda()

# target_optimizer_encoder_params = [{'params': model_target.fc1.parameters()}, {'params': model_target.fc2.parameters()}]
target_optimizer = optim.Adam(model_target.parameters(), lr=args.lr)
# target_optimizer_encoder = optim.Adam(target_optimizer_encoder_params, lr=args.lr)
source_optimizer = optim.Adam(model_source.parameters(), lr=args.lr)
d_optimizer = optim.Adam(discriminator_model.parameters(), lr=args.lr)

criterion = nn.BCELoss()

if args.source == 'mnist':
    tests = Tests(model_source, model_target, classifyMNIST, 'mnist',
                  'fashionMnist', args, graph)
elif args.source == 'fashionMnist':
    tests = Tests(model_source, model_target, classifyMNIST, 'fashionMnist',
                  'mnist', args, graph)
Example #12
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda',
                        default=False,
                        action='store_true',
                        help='Enable CUDA')
    args = parser.parse_args()
    use_cuda = True if args.cuda and torch.cuda.is_available() else False

    netG = Generator(VOCAB_SIZE, G_EMB_SIZE, G_HIDDEN_SIZE, use_cuda)
    netD = Discriminator(VOCAB_SIZE, D_EMB_SIZE, D_NUM_CLASSES, D_FILTER_SIZES,
                         D_NUM_FILTERS, DROPOUT, use_cuda)
    oracle = Oracle(VOCAB_SIZE, G_EMB_SIZE, G_HIDDEN_SIZE, use_cuda)

    if use_cuda:
        netG, netD, oracle = netG.cuda(), netD.cuda(), oracle.cuda()

    netG.create_optim(G_LR)
    netD.create_optim(D_LR, D_L2_REG)

    # generating synthetic data
    print('Generating data...')
    generate_samples(oracle, BATCH_SIZE, GENERATED_NUM, REAL_FILE)

    # pretrain generator
    gen_set = GeneratorDataset(REAL_FILE)
    genloader = DataLoader(dataset=gen_set,
                           batch_size=BATCH_SIZE,
                           shuffle=True)

    print('\nPretraining generator...\n')
    for epoch in range(PRE_G_EPOCHS):
        loss = netG.pretrain(genloader)
        print('Epoch {} pretrain generator training loss: {}'.format(
            epoch + 1, loss))

        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        val_set = GeneratorDataset(EVAL_FILE)
        valloader = DataLoader(dataset=val_set,
                               batch_size=BATCH_SIZE,
                               shuffle=True)
        loss = oracle.val(valloader)
        print('Epoch {} pretrain generator val loss: {}'.format(
            epoch + 1, loss))

    # pretrain discriminator
    print('\nPretraining discriminator...\n')
    for epoch in range(PRE_D_EPOCHS):
        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, FAKE_FILE)
        dis_set = DiscriminatorDataset(REAL_FILE, FAKE_FILE)
        disloader = DataLoader(dataset=dis_set,
                               batch_size=BATCH_SIZE,
                               shuffle=True)

        for k_step in range(K_STEPS):
            loss = netD.dtrain(disloader)
            print(
                'Epoch {} K-step {} pretrain discriminator training loss: {}'.
                format(epoch + 1, k_step + 1, loss))

    print('\nStarting adversarial training...')
    for epoch in range(TOTAL_EPOCHS):

        nets = [copy.deepcopy(netG) for _ in range(POPULATION_SIZE)]
        population = [(net, evaluate(net, netD)) for net in nets]
        for g_step in range(G_STEPS):
            t_start = time.time()
            population.sort(key=lambda p: p[1], reverse=True)
            rewards = [p[1] for p in population[:PARENTS_COUNT]]
            reward_mean = np.mean(rewards)
            reward_max = np.max(rewards)
            reward_std = np.std(rewards)
            print(
                "Epoch %d step %d: reward_mean=%.2f, reward_max=%.2f, reward_std=%.2f, time=%.2f s"
                % (epoch, g_step, reward_mean, reward_max, reward_std,
                   time.time() - t_start))

            elite = population[0]
            # generate next population
            prev_population = population
            population = [elite]
            for _ in range(POPULATION_SIZE - 1):
                parent_idx = np.random.randint(0, PARENTS_COUNT)
                parent = prev_population[parent_idx][0]
                net = mutate_net(parent, use_cuda)
                fitness = evaluate(net, netD)  # score the mutated child, not its parent
                population.append((net, fitness))

        netG = elite[0]

        for d_step in range(D_STEPS):
            # train discriminator
            generate_samples(netG, BATCH_SIZE, GENERATED_NUM, FAKE_FILE)
            dis_set = DiscriminatorDataset(REAL_FILE, FAKE_FILE)
            disloader = DataLoader(dataset=dis_set,
                                   batch_size=BATCH_SIZE,
                                   shuffle=True)

            for k_step in range(K_STEPS):
                loss = netD.dtrain(disloader)
                print(
                    'D_step {}, K-step {} adversarial discriminator training loss: {}'
                    .format(d_step + 1, k_step + 1, loss))

        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        val_set = GeneratorDataset(EVAL_FILE)
        valloader = DataLoader(dataset=val_set,
                               batch_size=BATCH_SIZE,
                               shuffle=True)
        loss = oracle.val(valloader)
        print('Epoch {} adversarial generator val loss: {}'.format(
            epoch + 1, loss))
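The evolutionary loop above depends on mutate_net and evaluate, which are defined elsewhere. As a sketch of the standard evolution-strategies recipe it appears to follow, mutation copies the parent generator and perturbs every parameter with Gaussian noise, and fitness is the discriminator's mean score on freshly generated samples. The noise scale, sample count, and the net.sample interface below are assumptions, not part of this code.

import copy

import torch

MUTATION_STD = 0.01     # assumed noise scale
EVAL_BATCH = 64         # assumed number of sequences scored per fitness call


def mutate_net(parent, use_cuda):
    # Return a perturbed copy of the parent generator; the parent itself is untouched.
    child = copy.deepcopy(parent)
    for p in child.parameters():
        p.data.add_(torch.randn_like(p) * MUTATION_STD)
    return child.cuda() if use_cuda else child


def evaluate(net, netD):
    # Fitness: average probability the discriminator assigns to the generator's
    # samples being real (assumes net.sample(batch_size) and a netD that returns
    # one real/fake score per sequence, neither of which is shown in this excerpt).
    with torch.no_grad():
        samples = net.sample(EVAL_BATCH)
        return netD(samples).mean().item()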
Example #13
0
def train(config):
    gpu_manage(config)

    train_dataset = Dataset(config.train_dir)
    val_dataset = Dataset(config.val_dir)
    training_data_loader = DataLoader(dataset=train_dataset,
                                      num_workers=config.threads,
                                      batch_size=config.batchsize,
                                      shuffle=True)
    val_data_loader = DataLoader(dataset=val_dataset,
                                 num_workers=config.threads,
                                 batch_size=config.test_batchsize,
                                 shuffle=False)

    gen = UNet(in_ch=config.in_ch, out_ch=config.out_ch, gpu_ids=config.gpu_ids)
    if config.gen_init is not None:
        param = torch.load(config.gen_init)
        gen.load_state_dict(param)
        print('load {} as pretrained model'.format(config.gen_init))

    dis = Discriminator(in_ch=config.in_ch, out_ch=config.out_ch, gpu_ids=config.gpu_ids)
    if config.dis_init is not None:
        param = torch.load(config.dis_init)
        dis.load_state_dict(param)
        print('load {} as pretrained model'.format(config.dis_init))

    opt_gen = optim.Adam(gen.parameters(), lr=config.lr, betas=(config.beta1, 0.999), weight_decay=0.00001)
    opt_dis = optim.Adam(dis.parameters(), lr=config.lr, betas=(config.beta1, 0.999), weight_decay=0.00001)

    criterionL1 = nn.L1Loss()
    criterionMSE = nn.MSELoss()
    criterionSoftplus = nn.Softplus()

    # real_a / real_b come straight from each mini-batch inside the loop below,
    # so no pre-allocated buffers or Variable wrappers are needed.
    device = torch.device("cuda" if config.cuda and torch.cuda.is_available() else "cpu")

    if config.cuda:
        gen = gen.cuda(0)
        dis = dis.cuda(0)
        criterionL1 = criterionL1.cuda(0)
        criterionMSE = criterionMSE.cuda(0)
        criterionSoftplus = criterionSoftplus.cuda(0)

    logreport = LogReport(log_dir=config.out_dir)
    testreport = TestReport(log_dir=config.out_dir)

    for epoch in range(1, config.epoch + 1):
        print('Epoch', epoch, datetime.now())
        for iteration, batch in enumerate(tqdm(training_data_loader)):
            real_a, real_b = batch[0], batch[1]
            real_a = F.interpolate(real_a, size=256).to(device)
            real_b = F.interpolate(real_b, size=256).to(device)
            fake_b = gen.forward(real_a)

            # Update D
            opt_dis.zero_grad()

            fake_ab = torch.cat((real_a, fake_b), 1)
            pred_fake = dis.forward(fake_ab.detach())
            batchsize, _, h, w = pred_fake.size()

            real_ab = torch.cat((real_a, real_b), 1)
            pred_real = dis.forward(real_ab)

            loss_d_fake = torch.sum(criterionSoftplus(pred_fake)) / batchsize / w / h
            loss_d_real = torch.sum(criterionSoftplus(-pred_real)) / batchsize / w / h
            loss_d = loss_d_fake + loss_d_real
            loss_d.backward()

            if epoch % config.minimax == 0:
                opt_dis.step()

            # Update G
            opt_gen.zero_grad()
            fake_ab = torch.cat((real_a, fake_b), 1)
            pred_fake = dis.forward(fake_ab)

            loss_g_gan = torch.sum(criterionSoftplus(-pred_fake)) / batchsize / w / h
            loss_g = loss_g_gan + criterionL1(fake_b, real_b) * config.lamb
            loss_g.backward()

            opt_gen.step()

            if iteration % 100 == 0:
                logreport({
                    'epoch': epoch,
                    'iteration': len(training_data_loader) * (epoch - 1) + iteration,
                    'gen/loss': loss_g.item(),
                    'dis/loss': loss_d.item(),
                })

        with torch.no_grad():
            log_test = test(config, val_data_loader, gen, criterionMSE, epoch)
            testreport(log_test)

        if epoch % config.snapshot_interval == 0:
            checkpoint(config, epoch, gen, dis)

        logreport.save_lossgraph()
        testreport.save_lossgraph()

    print('Done', datetime.now())
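One remark on the loss terms in this last example: applying Softplus to raw discriminator outputs gives exactly the logistic GAN losses, because softplus(-x) = -log(sigmoid(x)). So loss_d is the usual discriminator cross-entropy on logits, and loss_g_gan is the non-saturating generator loss, with the L1 term weighted by config.lamb added on top. A quick numerical check of that identity:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 1, 30, 30)   # stand-in for patch-wise discriminator outputs
assert torch.allclose(F.softplus(-logits), -F.logsigmoid(logits), atol=1e-6)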