Ejemplo n.º 1
0
class ModuleTrain:
    """Trainer for a GAN built from the project's ``Generator`` / ``Discriminator``.

    The constructor builds both networks, restores their weights from
    ``opt.netd_path`` / ``opt.netg_path`` when those files exist, creates the
    image data loader, the two Adam optimizers and the BCE criterion, and
    moves everything to the GPU when ``opt.use_gpu`` is set and CUDA is
    actually available.
    """

    def __init__(self, opt, best_loss=0.2):
        """Build networks, data pipeline, optimizers and fixed tensors.

        Args:
            opt: configuration object; this class reads ``netd_path``,
                ``netg_path``, ``img_size``, ``data_path``, ``batch_size``,
                ``num_workers``, ``lr1``, ``lr2``, ``beta1``, ``nz`` and
                ``use_gpu`` from it here, and further fields in the other
                methods.
            best_loss: the "best" checkpoints are only written once the
                epoch's averaged loss drops below this value.
        """
        self.opt = opt
        self.best_loss = best_loss

        self.netd = Discriminator(self.opt)
        self.netg = Generator(self.opt)
        # Resolved once: GPU is used only if requested AND available.
        self.use_gpu = bool(self.opt.use_gpu and torch.cuda.is_available())

        # Restore previously saved weights when present.
        if os.path.exists(self.opt.netd_path):
            self.load_netd(self.opt.netd_path)
        else:
            print('[Load model] error: %s not exist !!!' % self.opt.netd_path)
        if os.path.exists(self.opt.netg_path):
            self.load_netg(self.opt.netg_path)
        else:
            print('[Load model] error: %s not exist !!!' % self.opt.netg_path)

        # Data loader: resize, convert to tensor, normalize each channel
        # with mean 0.5 / std 0.5.
        self.transform_train = T.Compose([
            T.Resize((self.opt.img_size, self.opt.img_size)),
            T.ToTensor(),
            T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
        ])
        train_dataset = ImageFolder(root=self.opt.data_path,
                                    transform=self.transform_train)
        # drop_last=True keeps every batch at exactly opt.batch_size, which
        # the fixed-size label/noise tensors below rely on.
        self.train_loader = DataLoader(dataset=train_dataset,
                                       batch_size=self.opt.batch_size,
                                       shuffle=True,
                                       num_workers=self.opt.num_workers,
                                       drop_last=True)

        # Optimizers and loss function.
        self.optimizer_g = optim.Adam(self.netg.parameters(),
                                      lr=self.opt.lr1,
                                      betas=(self.opt.beta1, 0.999))
        self.optimizer_d = optim.Adam(self.netd.parameters(),
                                      lr=self.opt.lr2,
                                      betas=(self.opt.beta1, 0.999))
        self.criterion = torch.nn.BCELoss()

        # Pre-allocated targets and noise buffers, reused on every step.
        noise_shape = (self.opt.batch_size, self.opt.nz, 1, 1)
        self.true_labels = torch.ones(self.opt.batch_size)
        self.fake_labels = torch.zeros(self.opt.batch_size)
        self.fix_noises = torch.randn(*noise_shape)
        self.noises = torch.randn(*noise_shape)

        if self.use_gpu:
            print('[use gpu] ...')
            self.netd.cuda()
            self.netg.cuda()
            self.criterion.cuda()
            self.true_labels = self.true_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            self.fix_noises = self.fix_noises.cuda()
            self.noises = self.noises.cuda()
        else:
            print('[use cpu] ...')

    def _resample_noise(self):
        """Refill the reusable noise buffer in place with fresh N(0, 1) samples."""
        self.noises.data.copy_(
            torch.randn(self.opt.batch_size, self.opt.nz, 1, 1))

    def train(self, save_best=True):
        """Run ``opt.max_epoch`` epochs of alternating D / G updates.

        The discriminator is updated every ``opt.d_every`` batches and the
        generator every ``opt.g_every`` batches. After each epoch the averaged
        losses are printed; every fifth epoch ``image_gan`` is called. The
        final weights are always saved to ``opt.netd_path`` /
        ``opt.netg_path``.

        Args:
            save_best: when exactly ``True``, also save both networks to
                ``opt.best_netd_path`` / ``opt.best_netg_path`` whenever the
                mean of the two epoch losses improves on ``self.best_loss``.
        """
        print('[train] epoch: %d' % self.opt.max_epoch)
        num_batches = len(self.train_loader)
        for epoch_i in range(self.opt.max_epoch):
            loss_netd = 0.0
            loss_netg = 0.0

            print('================================================')
            for ii, (img, _) in enumerate(self.train_loader):
                # Use the resolved flag: opt.use_gpu may be set on a machine
                # without CUDA, in which case the networks stayed on the CPU.
                real_img = img.cuda() if self.use_gpu else img

                # --- discriminator step ---
                if (ii + 1) % self.opt.d_every == 0:
                    self.optimizer_d.zero_grad()
                    # Push real images towards label 1.
                    output = self.netd(real_img)
                    error_d_real = self.criterion(output, self.true_labels)
                    error_d_real.backward()

                    # Push generated images towards label 0. No gradient is
                    # needed through the generator for this step.
                    self._resample_noise()
                    with torch.no_grad():
                        fake_img = self.netg(self.noises)
                    fake_output = self.netd(fake_img)
                    error_d_fake = self.criterion(fake_output,
                                                  self.fake_labels)
                    error_d_fake.backward()
                    self.optimizer_d.step()

                    loss_netd += (error_d_real.item() + error_d_fake.item())

                # --- generator step ---
                if (ii + 1) % self.opt.g_every == 0:
                    self.optimizer_g.zero_grad()
                    self._resample_noise()
                    fake_img = self.netg(self.noises)
                    fake_output = self.netd(fake_img)
                    # The generator wants the discriminator to output 1.
                    error_g = self.criterion(fake_output, self.true_labels)
                    error_g.backward()
                    self.optimizer_g.step()

                    # .item() detaches the value; accumulating the tensor
                    # itself would keep every step's graph alive.
                    loss_netg += error_g.item()

            loss_netd /= (num_batches * 2)
            loss_netg /= num_batches
            print('[Train] Epoch: {} \tNetD Loss: {:.6f} \tNetG Loss: {:.6f}'.
                  format(epoch_i, loss_netd, loss_netg))
            if save_best is True:
                mean_loss = (loss_netg + loss_netd) / 2
                if mean_loss < self.best_loss:
                    self.best_loss = mean_loss
                    self.save(self.netd, self.opt.best_netd_path)
                    self.save(self.netg, self.opt.best_netg_path)
                    print('[save best] ...')

            if (epoch_i + 1) % 5 == 0:
                self.image_gan()

        # Always persist the final weights.
        self.save(self.netd, self.opt.netd_path)
        self.save(self.netg, self.opt.netg_path)

    def vis(self):
        """Send generator output for the fixed noise batch to visdom."""
        # The fixed noise lives on this trainer (built and device-placed in
        # __init__), not on the options object.
        with torch.no_grad():
            fix_fake_imgs = self.netg(self.fix_noises)
        # * 0.5 + 0.5 undoes the mean/std 0.5 normalization used for training.
        visdom.images(fix_fake_imgs.data.cpu().numpy()[:64] * 0.5 + 0.5,
                      win='fixfake')

    def image_gan(self):
        """Generate candidates, keep the top-scored ones, and save them.

        Draws ``opt.gen_search_num`` noise vectors from
        N(``opt.gen_mean``, ``opt.gen_std``), scores the generated images
        with the discriminator, and writes the ``opt.gen_num`` highest-scoring
        images to ``opt.gen_img``.
        """
        noises = torch.randn(self.opt.gen_search_num, self.opt.nz, 1,
                             1).normal_(self.opt.gen_mean, self.opt.gen_std)
        if self.use_gpu:
            noises = noises.cuda()

        # Pure inference: keep both forward passes out of autograd.
        with torch.no_grad():
            fake_img = self.netg(noises)
            scores = self.netd(fake_img)

        indexs = scores.topk(self.opt.gen_num)[1]
        best_images = torch.stack([fake_img[ii] for ii in indexs])

        # NOTE(review): newer torchvision releases appear to rename `range`
        # to `value_range` — confirm against the installed version.
        torchvision.utils.save_image(best_images,
                                     self.opt.gen_img,
                                     normalize=True,
                                     range=(-1, 1))

    def test(self):
        """Evaluate a classifier on a test loader and return its accuracy.

        Relies on ``self.test_loader``, ``self.model`` and ``self.loss``,
        none of which are created by ``__init__``; they must be attached to
        the instance before this method is called.

        Returns:
            Fraction of correctly classified samples as a float.
        """
        test_loss = 0.0
        correct = 0

        time_start = time.time()
        with torch.no_grad():
            for data, target in self.test_loader:
                if self.use_gpu:
                    data = data.cuda()
                    target = target.cuda()

                output = self.model(data)
                # Sum up the per-batch loss.
                test_loss += self.loss(output, target).item()

                predict = torch.argmax(output, 1)
                correct += (predict == target).sum().item()

        time_end = time.time()
        num_samples = float(len(self.test_loader.dataset))
        time_avg = float(time_end - time_start) / num_samples
        test_loss /= len(self.test_loader)
        acc = float(correct) / num_samples

        print('[Test] set: Test loss: {:.6f}\t Acc: {:.6f}\t time: {:.6f} \n'.
              format(test_loss, acc, time_avg))
        return acc

    def load_netd(self, name):
        """Load discriminator weights from the state-dict file ``name``."""
        print('[Load model netd] %s ...' % name)
        # map_location='cpu' lets GPU-saved weights load on CPU-only hosts;
        # load_state_dict copies them onto whatever device the net lives on.
        self.netd.load_state_dict(torch.load(name, map_location='cpu'))

    def load_netg(self, name):
        """Load generator weights from the state-dict file ``name``."""
        print('[Load model netg] %s ...' % name)
        self.netg.load_state_dict(torch.load(name, map_location='cpu'))

    def save(self, model, name):
        """Write ``model``'s state dict to the file ``name``."""
        print('[Save model] %s ...' % name)
        torch.save(model.state_dict(), name)
Ejemplo n.º 2
0
class Solver(object):
    """Training driver for a generator ``G`` / discriminator ``D`` pair.

    Copies its hyper-parameters from ``config``, builds both networks and
    their Adam optimizers, optionally attaches a tensorboard logger and
    optionally restores pretrained weights.
    """

    def __init__(self, data_loader, config):
        """Store settings from ``config`` and build the models.

        Args:
            data_loader: iterable of image batches; must support ``len()``.
            config: object providing the attributes read below.

        Raises:
            ValueError: if ``config.pretrained_model`` is given but does not
                contain exactly two entries (G path first, D path second).
        """
        self.data_loader = data_loader

        self.noise_n = config.noise_n
        self.G_last_act = last_act(config.G_last_act)
        self.D_out_n = config.D_out_n
        self.D_last_act = last_act(config.D_last_act)

        self.G_lr = config.G_lr
        self.D_lr = config.D_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.epoch = config.epoch
        self.batch_size = config.batch_size
        self.D_train_step = config.D_train_step
        self.save_image_step = config.save_image_step
        self.log_step = config.log_step
        self.model_save_step = config.model_save_step

        self.model_save_path = config.model_save_path
        self.log_save_path = config.log_save_path
        self.image_save_path = config.image_save_path

        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model
        self.build_model()

        # Same truthiness test as in train(), so a False flag does not
        # create a logger that is never used.
        if self.use_tensorboard:
            self.build_tensorboard()
        if self.pretrained_model is not None:
            if len(self.pretrained_model) != 2:
                # Raising a bare string is itself a TypeError in Python 3.
                raise ValueError(
                    "must have both G and D pretrained parameters, and G is first, D is second"
                )
            self.load_pretrained_model()

    def build_model(self):
        """Create G, D and their Adam optimizers; move nets to GPU if available."""
        self.G = Generator(self.noise_n, self.G_last_act)
        self.D = Discriminator(self.D_out_n, self.D_last_act)

        self.G_optimizer = torch.optim.Adam(self.G.parameters(), self.G_lr,
                                            [self.beta1, self.beta2])
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), self.D_lr,
                                            [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def build_tensorboard(self):
        """Create the scalar logger writing to ``self.log_save_path``."""
        from commons.logger import Logger
        self.logger = Logger(self.log_save_path)

    def load_pretrained_model(self):
        """Load G then D weights from the two paths in ``self.pretrained_model``."""
        # map_location='cpu' lets GPU-saved weights load on CPU-only hosts;
        # load_state_dict copies them onto the networks' own device.
        self.G.load_state_dict(
            torch.load(self.pretrained_model[0], map_location='cpu'))
        self.D.load_state_dict(
            torch.load(self.pretrained_model[1], map_location='cpu'))

    def reset_grad(self):
        """Clear accumulated gradients of both optimizers."""
        self.G_optimizer.zero_grad()
        self.D_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        """Return ``x`` moved to the GPU when CUDA is available.

        Args:
            x: tensor to place.
            volatile: accepted for backward compatibility only. The flag has
                no effect on current PyTorch; wrap inference code in
                ``torch.no_grad()`` instead.
        """
        if torch.cuda.is_available():
            x = x.cuda()
        return x

    def train(self):
        """Run the adversarial training loop for ``self.epoch`` epochs.

        D is updated on every batch; G is updated every ``D_train_step``
        batches. Losses are printed every ``log_step`` batches and sent to
        the logger when tensorboard is enabled. Sample images and model
        weights are written every ``save_image_step`` / ``model_save_step``
        epochs.
        """
        bce_loss = nn.BCELoss()
        num_batches = len(self.data_loader)

        print(num_batches)
        for e in range(self.epoch):
            for i, batch_images in enumerate(self.data_loader):
                batch_size = batch_images.size(0)
                real_x = self.to_var(batch_images)
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(batch_size, self.noise_n)))
                real_label = self.to_var(torch.ones(batch_size))
                fake_label = self.to_var(torch.zeros(batch_size))

                # --- train D ---
                fake_x = self.G(noise_x)
                real_out = self.D(real_x)
                # detach(): the D step must not backpropagate into G.
                fake_out = self.D(fake_x.detach())

                D_real = bce_loss(real_out, real_label)
                D_fake = bce_loss(fake_out, fake_label)
                D_loss = D_real + D_fake

                self.reset_grad()
                D_loss.backward()
                self.D_optimizer.step()

                # .item() extracts the Python float from a 0-dim loss tensor
                # (indexing it with [0] fails on current PyTorch).
                loss = {
                    'D/loss_real': D_real.item(),
                    'D/loss_fake': D_fake.item(),
                    'D/loss': D_loss.item(),
                }

                # --- train G ---
                if (i + 1) % self.D_train_step == 0:
                    # Re-score with the freshly updated discriminator.
                    fake_out = self.D(self.G(noise_x))
                    G_loss = bce_loss(fake_out, real_label)
                    self.reset_grad()
                    G_loss.backward()
                    self.G_optimizer.step()
                    loss['G/loss'] = G_loss.item()

                # Print log
                if (i + 1) % self.log_step == 0:
                    log = "Epoch: {}/{}, Iter: {}/{}".format(
                        e + 1, self.epoch, i + 1, num_batches)
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)
                if self.use_tensorboard:
                    global_step = e * num_batches + i + 1
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, global_step)

            # Save sample images
            if (e + 1) % self.save_image_step == 0:
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(32, self.noise_n)))
                # Sampling only: no autograd graph needed.
                with torch.no_grad():
                    fake_image = self.G(noise_x)
                save_image(
                    fake_image.data,
                    os.path.join(self.image_save_path,
                                 "{}_fake.png".format(e + 1)))
            # Save model weights
            if (e + 1) % self.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_G.pth".format(e + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_D.pth".format(e + 1)))
Ejemplo n.º 3
0
def main_worker(args):
    """Build the SinGAN networks and run validation, testing or training.

    This variant uses MindSpore APIs. It derives the scale pyramid from
    ``args.img_size_min`` / ``args.img_size_max``, optionally restores the
    last checkpoint listed in ``<log_dir>/checkpoint.txt``, and then either
    validates once (``args.validation`` / ``args.test``) or trains every
    remaining stage, saving a checkpoint after each one.

    Args:
        args: run configuration; ``num_scale``, ``size_list``, ``stage`` and
            (on restore) ``img_to_use`` are written back onto it.
    """

    ################
    # Define model #
    ################
    # 4/3 : scale factor in the paper
    scale_factor = 4 / 3
    tmp_scale = args.img_size_max / args.img_size_min
    args.num_scale = int(np.round(np.log(tmp_scale) / np.log(scale_factor)))
    args.size_list = [
        int(args.img_size_min * scale_factor**i)
        for i in range(args.num_scale + 1)
    ]

    discriminator = Discriminator()
    generator = Generator(args.img_size_min, args.num_scale, scale_factor)
    # Defined up front: the validation/test branch below needs it before the
    # training section is reached.
    networks = [discriminator, generator]

    ######################
    # Loss and Optimizer #
    ######################
    d_opt = mindspore.nn.Adam(
        discriminator.sub_discriminators[0].get_parameters(), 5e-4, 0.5, 0.999)
    g_opt = mindspore.nn.Adam(generator.sub_generators[0].get_parameters(),
                              5e-4, 0.5, 0.999)

    ##############
    # Load model #
    ##############
    args.stage = 0
    if args.load_model is not None:
        # The last line of checkpoint.txt names the checkpoint to restore.
        with open(os.path.join(args.log_dir, "checkpoint.txt"),
                  'r') as check_load:
            to_restore = check_load.readlines()[-1].strip()
        load_file = os.path.join(args.log_dir, to_restore)
        if os.path.isfile(load_file):
            print("=> loading checkpoint '{}'".format(load_file))
            # NOTE(review): the original port note asks whether an
            # equivalent of torch's map_location='cpu' is needed here.
            checkpoint = mindspore.load_checkpoint(load_file)
            # Grow both networks to the stage stored in the checkpoint.
            for _ in range(int(checkpoint['stage'])):
                generator.progress()
                discriminator.progress()
            args.stage = checkpoint['stage']
            args.img_to_use = checkpoint['img_to_use']
            discriminator.load_state_dict(checkpoint['D_state_dict'])
            generator.load_state_dict(checkpoint['G_state_dict'])
            # NOTE(review): the original port note questions whether
            # Adam.load_state_dict exists in MindSpore — confirm.
            d_opt.load_state_dict(checkpoint['d_optimizer'])
            g_opt.load_state_dict(checkpoint['g_optimizer'])
            print("=> loaded checkpoint '{}' (stage {})".format(
                load_file, checkpoint['stage']))
        else:
            print("=> no checkpoint found at '{}'".format(args.log_dir))

    ###########
    # Dataset #
    ###########
    train_dataset, _ = get_dataset(args.dataset, args)

    # NOTE(review): original port note — DatasetHelper may need further
    # arguments.
    train_loader = mindspore.DatasetHelper(train_dataset)

    ######################
    # Validate and Train #
    ######################
    # NOTE(review): these op constructions were carried over unchanged from
    # the port; confirm the arguments against the MindSpore op signatures.
    op1 = mindspore.ops.Pad(((5, 5), (5, 5)))
    op2 = mindspore.ops.Pad(((5, 5), (5, 5)))
    z_fix_list = [op1(mindspore.ops.StandardNormal(3, args.size_list[0]))]
    zero_list = [
        op2(mindspore.ops.Zeros(3, args.size_list[zeros_idx]))
        for zeros_idx in range(1, args.num_scale + 1)
    ]
    z_fix_list = z_fix_list + zero_list

    # Validation and test modes run the same single evaluation pass.
    if args.validation or args.test:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return

    # One append handle for the checkpoint index, closed on every exit path.
    check_list = open(os.path.join(args.log_dir, "checkpoint.txt"), "a+")
    try:
        with open(os.path.join(args.log_dir, "record.txt"),
                  "a+") as record_txt:
            record_txt.write('DATASET\t:\t{}\n'.format(args.dataset))
            record_txt.write('GANTYPE\t:\t{}\n'.format(args.gantype))
            record_txt.write('IMGTOUSE\t:\t{}\n'.format(args.img_to_use))

        for stage in range(args.stage, args.num_scale + 1):

            trainSinGAN(train_loader, networks, {
                "d_opt": d_opt,
                "g_opt": g_opt
            }, stage, args, {"z_rec": z_fix_list})
            validateSinGAN(train_loader, networks, stage, args,
                           {"z_rec": z_fix_list})
            discriminator.progress()
            generator.progress()

            # Update the networks at finest scale. get_parameters() matches
            # how the initial optimizers above obtain their parameters.
            d_opt = mindspore.nn.Adam(
                discriminator.sub_discriminators[
                    discriminator.current_scale].get_parameters(), 5e-4, 0.5,
                0.999)
            g_opt = mindspore.nn.Adam(
                generator.sub_generators[
                    generator.current_scale].get_parameters(), 5e-4, 0.5,
                0.999)

            ##############
            # Save model #
            ##############
            save_checkpoint(
                {
                    'stage': stage + 1,
                    'D_state_dict': discriminator.state_dict(),
                    'G_state_dict': generator.state_dict(),
                    'd_optimizer': d_opt.state_dict(),
                    'g_optimizer': g_opt.state_dict(),
                    'img_to_use': args.img_to_use
                }, check_list, args.log_dir, stage + 1)
    finally:
        check_list.close()
Ejemplo n.º 4
0
def _unwrap_network(net):
    """Return the bare model inside any DataParallel / DistributedDataParallel wrappers."""
    wrappers = (torch.nn.DataParallel,
                torch.nn.parallel.DistributedDataParallel)
    while isinstance(net, wrappers):
        net = net.module
    return net


def _place_networks(networks, args):
    """Move bare networks to their device and wrap them for the active parallel mode."""
    if args.distributed:
        if args.gpu is not None:
            print('Distributed to', args.gpu)
            torch.cuda.set_device(args.gpu)
            return [
                torch.nn.parallel.DistributedDataParallel(
                    x.cuda(args.gpu),
                    device_ids=[args.gpu],
                    output_device=args.gpu) for x in networks
            ]
        return [
            torch.nn.parallel.DistributedDataParallel(x.cuda())
            for x in networks
        ]
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        return [x.cuda(args.gpu) for x in networks]
    return [torch.nn.DataParallel(x).cuda() for x in networks]


def _freeze_coarser_scales(discriminator, generator):
    """Disable gradients for all sub-networks below the generator's current scale."""
    for net_idx in range(generator.current_scale):
        for param in generator.sub_generators[net_idx].parameters():
            param.requires_grad = False
        for param in discriminator.sub_discriminators[net_idx].parameters():
            param.requires_grad = False


def _scale_optimizers(discriminator, generator, d_scale, g_scale):
    """Create the (d_opt, g_opt) Adam pair for the sub-networks at the given scales."""
    d_opt = torch.optim.Adam(
        discriminator.sub_discriminators[d_scale].parameters(), 5e-4,
        (0.5, 0.999))
    g_opt = torch.optim.Adam(generator.sub_generators[g_scale].parameters(),
                             5e-4, (0.5, 0.999))
    return d_opt, g_opt


def main_worker(gpu, ngpus_per_node, args):
    """Per-process entry point: build SinGAN and validate, test or train it.

    Sets up the (optionally distributed) process, derives the scale pyramid
    from ``args.img_size_min`` / ``args.img_size_max``, optionally restores
    the last checkpoint listed in ``<log_dir>/checkpoint.txt``, and then
    either validates once (``args.validation`` / ``args.test``) or trains
    every remaining stage, saving a checkpoint after each one.

    Args:
        gpu: GPU index handed to this worker; used when ``args.gpu`` lists
            more than one device.
        ngpus_per_node: number of GPUs on this node; used to derive the
            global rank and to split batch size / workers per process.
        args: run configuration; several fields (``gpu``, ``rank``,
            ``num_scale``, ``size_list``, ``stage``, ``batch_size``,
            ``workers``, ``img_to_use``) are written back onto it.
    """
    if len(args.gpu) == 1:
        args.gpu = 0
    else:
        args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend='nccl',
                                init_method='tcp://127.0.0.1:' + args.port,
                                world_size=args.world_size,
                                rank=args.rank)

    ################
    # Define model #
    ################
    # 4/3 : scale factor in the paper
    scale_factor = 4 / 3
    tmp_scale = args.img_size_max / args.img_size_min
    args.num_scale = int(np.round(np.log(tmp_scale) / np.log(scale_factor)))
    args.size_list = [
        int(args.img_size_min * scale_factor**i)
        for i in range(args.num_scale + 1)
    ]

    # Keep handles on the bare models: progress(), sub_* lists and
    # current_scale live on them, not on the parallel wrappers.
    bare_d = Discriminator()
    bare_g = Generator(args.img_size_min, args.num_scale, scale_factor)

    if args.distributed and args.gpu is not None:
        # Split the per-node budget across processes exactly once; repeating
        # this on every re-wrap would keep shrinking both values.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.workers = int(args.workers / ngpus_per_node)

    networks = _place_networks([bare_d, bare_g], args)
    discriminator, generator = networks

    ######################
    # Loss and Optimizer #
    ######################
    d_opt, g_opt = _scale_optimizers(bare_d, bare_g, 0, 0)

    ##############
    # Load model #
    ##############
    args.stage = 0
    if args.load_model is not None:
        # The last line of checkpoint.txt names the checkpoint to restore.
        with open(os.path.join(args.log_dir, "checkpoint.txt"),
                  'r') as check_load:
            to_restore = check_load.readlines()[-1].strip()
        load_file = os.path.join(args.log_dir, to_restore)
        if os.path.isfile(load_file):
            print("=> loading checkpoint '{}'".format(load_file))
            checkpoint = torch.load(load_file, map_location='cpu')
            # Grow both networks to the stage stored in the checkpoint.
            for _ in range(int(checkpoint['stage'])):
                bare_g.progress()
                bare_d.progress()

            # Mirror what the training loop does after progress(): freeze
            # the coarser scales, re-place the grown networks, and build the
            # optimizers over the current scale so the restored optimizer
            # state belongs to the parameters it was saved for.
            _freeze_coarser_scales(bare_d, bare_g)
            networks = _place_networks([bare_d, bare_g], args)
            discriminator, generator = networks
            d_opt, g_opt = _scale_optimizers(bare_d, bare_g,
                                             bare_d.current_scale,
                                             bare_g.current_scale)

            args.stage = checkpoint['stage']
            args.img_to_use = checkpoint['img_to_use']
            discriminator.load_state_dict(checkpoint['D_state_dict'])
            generator.load_state_dict(checkpoint['G_state_dict'])
            d_opt.load_state_dict(checkpoint['d_optimizer'])
            g_opt.load_state_dict(checkpoint['g_optimizer'])
            print("=> loaded checkpoint '{}' (stage {})".format(
                load_file, checkpoint['stage']))
        else:
            print("=> no checkpoint found at '{}'".format(args.log_dir))

    cudnn.benchmark = True

    ###########
    # Dataset #
    ###########
    train_dataset, _ = get_dataset(args.dataset, args)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    ######################
    # Validate and Train #
    ######################
    # Reconstruction noise: random at the coarsest scale, zeros above it;
    # every map is zero-padded by 5 pixels on each spatial side.
    z_fix_list = [
        F.pad(torch.randn(args.batch_size, 3, args.size_list[0],
                          args.size_list[0]), [5, 5, 5, 5],
              value=0)
    ]
    z_fix_list += [
        F.pad(torch.zeros(args.batch_size, 3, args.size_list[zeros_idx],
                          args.size_list[zeros_idx]), [5, 5, 5, 5],
              value=0) for zeros_idx in range(1, args.num_scale + 1)
    ]

    # Validation and test modes run the same single evaluation pass.
    if args.validation or args.test:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return

    # Only one process per node writes the record and checkpoint files.
    is_main_process = (not args.multiprocessing_distributed
                       or args.rank % ngpus_per_node == 0)

    check_list = None
    if is_main_process:
        check_list = open(os.path.join(args.log_dir, "checkpoint.txt"), "a+")
    try:
        if is_main_process:
            with open(os.path.join(args.log_dir, "record.txt"),
                      "a+") as record_txt:
                record_txt.write('DATASET\t:\t{}\n'.format(args.dataset))
                record_txt.write('GANTYPE\t:\t{}\n'.format(args.gantype))
                record_txt.write('IMGTOUSE\t:\t{}\n'.format(args.img_to_use))

        for stage in range(args.stage, args.num_scale + 1):
            if args.distributed:
                train_sampler.set_epoch(stage)

            trainSinGAN(train_loader, networks, {
                "d_opt": d_opt,
                "g_opt": g_opt
            }, stage, args, {"z_rec": z_fix_list})
            validateSinGAN(train_loader, networks, stage, args,
                           {"z_rec": z_fix_list})

            # Grow the bare models, then rebuild the wrappers around them so
            # wrappers are never nested and pick up the new sub-networks.
            bare_d.progress()
            bare_g.progress()

            # Freeze before wrapping so the wrappers are built with the
            # final requires_grad flags.
            _freeze_coarser_scales(bare_d, bare_g)
            networks = _place_networks([bare_d, bare_g], args)
            discriminator, generator = networks

            # Update the networks at finest scale
            d_opt, g_opt = _scale_optimizers(bare_d, bare_g,
                                             bare_d.current_scale,
                                             bare_g.current_scale)

            ##############
            # Save model #
            ##############
            if is_main_process:
                save_checkpoint(
                    {
                        'stage': stage + 1,
                        'D_state_dict': discriminator.state_dict(),
                        'G_state_dict': generator.state_dict(),
                        'd_optimizer': d_opt.state_dict(),
                        'g_optimizer': g_opt.state_dict(),
                        'img_to_use': args.img_to_use
                    }, check_list, args.log_dir, stage + 1)
    finally:
        if check_list is not None:
            check_list.close()
Ejemplo n.º 5
0
def main():
    """Run the adversarial imitation training loop for the dialog actor.

    Every outer iteration:

    1. rolls the actor out in ``DialogEnvironment`` until at least
       ``args.total_sample_size`` steps have been collected, scoring each step
       with the discriminator-based reward from ``get_reward``;
    2. trains the discriminator on the collected memory (until both accuracy
       thresholds in ``args`` are exceeded, after which it is frozen);
    3. updates actor and critic with ``train_actor_critic``;
    4. logs scalars/text to the ``SummaryWriter`` and, every 100 iterations,
       appends a sample to ``<experiment_name>.txt`` and saves a checkpoint.

    Relies on the module-level ``args`` and ``device`` objects.
    """
    env = DialogEnvironment()
    experiment_name = args.logdir.split('/')[1]  # model name

    torch.manual_seed(args.seed)

    # TODO
    # NOTE(review): the actor is told device='cuda' while everything else
    # follows the module-level `device` -- confirm these agree on CPU runs.
    actor = Actor(hidden_size=args.hidden_size, num_layers=args.num_layers,
                  device='cuda', input_size=args.input_size,
                  output_size=args.input_size)
    critic = Critic(hidden_size=args.hidden_size, num_layers=args.num_layers,
                    input_size=args.input_size, seq_len=args.seq_len)
    # NOTE(review): num_layers receives args.hidden_size here, unlike the actor
    # and critic which use args.num_layers. This looks like a typo, but it is
    # left untouched because changing it would make previously saved
    # 'discrim' state dicts unloadable -- confirm before fixing.
    discrim = Discriminator(hidden_size=args.hidden_size,
                            num_layers=args.hidden_size,
                            input_size=args.input_size, seq_len=args.seq_len)

    actor.to(device)
    critic.to(device)
    discrim.to(device)

    actor_optim = optim.Adam(actor.parameters(), lr=args.learning_rate)
    critic_optim = optim.Adam(critic.parameters(), lr=args.learning_rate,
                              weight_decay=args.l2_rate)
    discrim_optim = optim.Adam(discrim.parameters(), lr=args.learning_rate)

    # load demonstrations

    writer = SummaryWriter(args.logdir)

    if args.load_model is not None:  # TODO
        saved_ckpt_path = os.path.join(os.getcwd(), 'save_model',
                                       str(args.load_model))
        ckpt = torch.load(saved_ckpt_path)

        actor.load_state_dict(ckpt['actor'])
        critic.load_state_dict(ckpt['critic'])
        discrim.load_state_dict(ckpt['discrim'])

    episodes = 0
    train_discrim_flag = True

    for iteration in range(args.max_iter_num):
        actor.eval()
        critic.eval()
        memory = deque()

        steps = 0
        # Per-episode totals for this iteration. They must NOT be reset inside
        # the sampling loop, otherwise the averages below only ever reflect
        # the final episode.
        scores = []
        similarity_scores = []
        while steps < args.total_sample_size:
            state, expert_action, raw_state, raw_expert_action = env.reset()
            score = 0
            similarity_score = 0
            # Truncate both sequences to the model's sequence length.
            state = state[:args.seq_len, :]
            expert_action = expert_action[:args.seq_len, :]
            state = state.to(device)
            expert_action = expert_action.to(device)
            for _ in range(10000):

                steps += 1

                # TODO: gotta be a better way to resize.
                mu, std = actor(
                    state.resize(1, args.seq_len, args.input_size))
                action = get_action(mu.cpu(), std.cpu())[0]
                # Zero the action from the first all-zero expert row onwards
                # so that it carries the same padding as the expert action.
                for i in range(5):
                    emb_sum = expert_action[i, :].sum().cpu().item()
                    if emb_sum == 0:
                        action[i:, :] = 0  # manual padding
                        break

                done = env.step(action)
                irl_reward = get_reward(discrim, state, action, args)
                mask = 0 if done else 1

                memory.append([state,
                               torch.from_numpy(action).to(device),
                               irl_reward, mask, expert_action])
                score += irl_reward
                similarity_score += get_cosine_sim(expert=expert_action,
                                                   action=action.squeeze(),
                                                   seq_len=5)
                if done:
                    break

            episodes += 1
            scores.append(score)
            similarity_scores.append(similarity_score)

        score_avg = np.mean(scores)
        similarity_score_avg = np.mean(similarity_scores)
        print('{}:: {} episode score is {:.2f}'.format(
            iteration, episodes, score_avg))
        print('{}:: {} episode similarity score is {:.2f}'.format(
            iteration, episodes, similarity_score_avg))

        actor.train()
        critic.train()
        discrim.train()
        if train_discrim_flag:
            expert_acc, learner_acc = train_discrim(discrim, memory,
                                                    discrim_optim, args)
            print("Expert: %.2f%% | Learner: %.2f%%" %
                  (expert_acc * 100, learner_acc * 100))
            writer.add_scalar('log/expert_acc', float(expert_acc), iteration)
            writer.add_scalar('log/learner_acc', float(learner_acc),
                              iteration)
            writer.add_scalar('log/avg_acc',
                              float(learner_acc + expert_acc) / 2, iteration)
            # Thresholds are only checked when one was configured.
            if args.suspend_accu_exp is not None:
                if (expert_acc > args.suspend_accu_exp
                        and learner_acc > args.suspend_accu_gen):
                    train_discrim_flag = False

        train_actor_critic(actor, critic, memory, actor_optim, critic_optim,
                           args)
        writer.add_scalar('log/score', float(score_avg), iteration)
        writer.add_scalar('log/similarity_score',
                          float(similarity_score_avg), iteration)
        # The text samples come from the last step of the last episode.
        writer.add_text('log/raw_state', raw_state[0], iteration)
        raw_action = get_raw_action(action)  # TODO
        writer.add_text('log/raw_action', raw_action, iteration)
        writer.add_text('log/raw_expert_action', raw_expert_action, iteration)

        # Checkpoint every 100th iteration. The previous `if iter % 100:` was
        # inverted and fired on every iteration *except* multiples of 100.
        if iteration % 100 == 0:
            score_avg = int(score_avg)

            result_str = (str(iteration) + '|' + raw_state[0] + '|' +
                          raw_action + '|' + raw_expert_action + '\n')
            # Append the sample to the experiment log; `with` guarantees the
            # handle is closed even if the write fails.
            with open(experiment_name + '.txt', 'a') as file_object:
                file_object.write(result_str)

            model_path = os.path.join(os.getcwd(), 'save_model')
            os.makedirs(model_path, exist_ok=True)

            ckpt_path = os.path.join(
                model_path,
                experiment_name + '_ckpt_' + str(score_avg) + '.pth.tar')

            save_checkpoint({
                'actor': actor.state_dict(),
                'critic': critic.state_dict(),
                'discrim': discrim.state_dict(),
                'args': args,
                'score': score_avg,
            }, filename=ckpt_path)

    # Flush pending events to disk once training is over.
    writer.close()
Ejemplo n.º 6
0
class GanTrainer(Trainer):
    """Trainer that fits a ``Generator`` against a ``Discriminator``.

    The generator objective is the sum of a time-domain MSE, a weighted
    spectrogram MSE, an optional weighted auto-encoder-embedding MSE and an
    optional weighted adversarial (BCE-with-logits) term. Data loaders, the
    device, the epoch counter and the loss containers are expected to be set
    up by the ``Trainer`` base class.
    """

    def __init__(self, train_loader, test_loader, valid_loader, general_args,
                 trainer_args):
        """
        Builds the networks, optimizers, schedulers and loss functions.
        :param train_loader: loader for the training set.
        :param test_loader: loader for the test set.
        :param valid_loader: loader for the validation set.
        :param general_args: arguments shared by all models and trainers.
        :param trainer_args: GAN-specific arguments (paths, learning rates,
            scheduler settings, loss weights, ``use_adversarial``).
        """
        super(GanTrainer, self).__init__(train_loader, test_loader,
                                         valid_loader, general_args)
        # Paths
        self.loadpath = trainer_args.loadpath
        self.savepath = trainer_args.savepath

        # Load the auto-encoder (only if a valid path was given)
        self.use_autoencoder = False
        if trainer_args.autoencoder_path and os.path.exists(
                trainer_args.autoencoder_path):
            self.use_autoencoder = True
            self.autoencoder = AutoEncoder(general_args=general_args).to(
                self.device)
            self.load_pretrained_autoencoder(trainer_args.autoencoder_path)
            self.autoencoder.eval()

        # Load the generator
        self.generator = Generator(general_args=general_args).to(self.device)
        if trainer_args.generator_path and os.path.exists(
                trainer_args.generator_path):
            self.load_pretrained_generator(trainer_args.generator_path)

        self.discriminator = Discriminator(general_args=general_args).to(
            self.device)

        # Optimizers and schedulers
        self.generator_optimizer = torch.optim.Adam(
            params=self.generator.parameters(), lr=trainer_args.generator_lr)
        self.discriminator_optimizer = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=trainer_args.discriminator_lr)
        self.generator_scheduler = lr_scheduler.StepLR(
            optimizer=self.generator_optimizer,
            step_size=trainer_args.generator_scheduler_step,
            gamma=trainer_args.generator_scheduler_gamma)
        self.discriminator_scheduler = lr_scheduler.StepLR(
            optimizer=self.discriminator_optimizer,
            step_size=trainer_args.discriminator_scheduler_step,
            gamma=trainer_args.discriminator_scheduler_gamma)

        # Load saved states. Guard against an unset path the same way the
        # generator/auto-encoder paths are guarded: os.path.exists(None)
        # raises a TypeError.
        if self.loadpath and os.path.exists(self.loadpath):
            self.load()

        # Loss function and stored losses
        self.adversarial_criterion = nn.BCEWithLogitsLoss()
        self.generator_time_criterion = nn.MSELoss()
        self.generator_frequency_criterion = nn.MSELoss()
        self.generator_autoencoder_criterion = nn.MSELoss()

        # Define labels
        self.real_label = 1
        self.generated_label = 0

        # Loss scaling factors
        self.lambda_adv = trainer_args.lambda_adversarial
        self.lambda_freq = trainer_args.lambda_freq
        self.lambda_autoencoder = trainer_args.lambda_autoencoder

        # Spectrogram converter
        self.spectrogram = Spectrogram(normalized=True).to(self.device)

        # Boolean indicating if the model needs to be saved
        self.need_saving = True

        # Boolean if the generator receives the feedback from the discriminator
        self.use_adversarial = trainer_args.use_adversarial

    def load_pretrained_generator(self, generator_path):
        """
        Loads a pre-trained generator. Can be used to stabilize the training.
        :param generator_path: location of the pre-trained generator (string).
        :return: None
        """
        checkpoint = torch.load(generator_path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])

    def load_pretrained_autoencoder(self, autoencoder_path):
        """
        Loads a pre-trained auto-encoder. Its embeddings are used for the
        optional embedding-space loss of the generator.
        :param autoencoder_path: location of the pre-trained auto-encoder (string).
        :return: None
        """
        checkpoint = torch.load(autoencoder_path, map_location=self.device)
        self.autoencoder.load_state_dict(checkpoint['autoencoder_state_dict'])

    def train(self, epochs):
        """
        Trains the GAN for a given number of pseudo-epochs.
        After each pseudo-epoch the model is validated, the trainer state is
        saved and both learning-rate schedulers are stepped.
        :param epochs: Number of time to iterate over a part of the dataset (int).
        :return: None
        """
        for epoch in range(epochs):
            for i in range(self.train_batches_per_epoch):
                self.generator.train()
                self.discriminator.train()
                # Transfer to GPU
                local_batch = next(self.train_loader_iter)
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)
                batch_size = input_batch.shape[0]

                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                # Train the discriminator with real data
                self.discriminator_optimizer.zero_grad()
                # The dtype must be explicit: with an integer fill value
                # torch.full infers an integer tensor, which
                # BCEWithLogitsLoss does not accept as a target.
                label = torch.full((batch_size, ),
                                   self.real_label,
                                   dtype=torch.float,
                                   device=self.device)
                output = self.discriminator(target_batch)

                # Compute and store the discriminator loss on real data
                loss_discriminator_real = self.adversarial_criterion(
                    output, torch.unsqueeze(label, dim=1))
                self.train_losses['discriminator_adversarial']['real'].append(
                    loss_discriminator_real.item())
                loss_discriminator_real.backward()

                # Train the discriminator with fake data; detach() keeps the
                # discriminator update from back-propagating into the
                # generator.
                generated_batch = self.generator(input_batch)
                label.fill_(self.generated_label)
                output = self.discriminator(generated_batch.detach())

                # Compute and store the discriminator loss on fake data
                loss_discriminator_generated = self.adversarial_criterion(
                    output, torch.unsqueeze(label, dim=1))
                self.train_losses['discriminator_adversarial']['fake'].append(
                    loss_discriminator_generated.item())
                loss_discriminator_generated.backward()

                # Update the discriminator weights
                self.discriminator_optimizer.step()

                ############################
                # Update G network: maximize log(D(G(z)))
                ###########################
                self.generator_optimizer.zero_grad()

                # Get the spectrogram
                specgram_target_batch = self.spectrogram(target_batch)
                specgram_fake_batch = self.spectrogram(generated_batch)

                # Fake labels are real for the generator cost
                label.fill_(self.real_label)
                output = self.discriminator(generated_batch)

                # Compute the generator loss on fake data
                # Get the adversarial loss (stays at zero when disabled)
                loss_generator_adversarial = torch.zeros(size=[1],
                                                         device=self.device)
                if self.use_adversarial:
                    loss_generator_adversarial = self.adversarial_criterion(
                        output, torch.unsqueeze(label, dim=1))
                self.train_losses['generator_adversarial'].append(
                    loss_generator_adversarial.item())

                # Get the L2 loss in time domain
                loss_generator_time = self.generator_time_criterion(
                    generated_batch, target_batch)
                self.train_losses['time_l2'].append(loss_generator_time.item())

                # Get the L2 loss in frequency domain
                loss_generator_frequency = self.generator_frequency_criterion(
                    specgram_fake_batch, specgram_target_batch)
                self.train_losses['freq_l2'].append(
                    loss_generator_frequency.item())

                # Get the L2 loss in embedding space (stays at zero when no
                # auto-encoder was loaded)
                loss_generator_autoencoder = torch.zeros(size=[1],
                                                         device=self.device,
                                                         requires_grad=True)
                if self.use_autoencoder:
                    # Get the embeddings
                    _, embedding_target_batch = self.autoencoder(target_batch)
                    _, embedding_generated_batch = self.autoencoder(
                        generated_batch)
                    loss_generator_autoencoder = self.generator_autoencoder_criterion(
                        embedding_generated_batch, embedding_target_batch)
                    self.train_losses['autoencoder_l2'].append(
                        loss_generator_autoencoder.item())

                # Combine the different losses
                loss_generator = self.lambda_adv * loss_generator_adversarial + loss_generator_time + \
                                 self.lambda_freq * loss_generator_frequency + \
                                 self.lambda_autoencoder * loss_generator_autoencoder

                # Back-propagate and update the generator weights
                loss_generator.backward()
                self.generator_optimizer.step()

                # Print message every 10 batches
                if not (i % 10):
                    message = 'Batch {}: \n' \
                              '\t Generator: \n' \
                              '\t\t Time: {} \n' \
                              '\t\t Frequency: {} \n' \
                              '\t\t Autoencoder {} \n' \
                              '\t\t Adversarial: {} \n' \
                              '\t Discriminator: \n' \
                              '\t\t Real {} \n' \
                              '\t\t Fake {} \n'.format(i,
                                                       loss_generator_time.item(),
                                                       loss_generator_frequency.item(),
                                                       loss_generator_autoencoder.item(),
                                                       loss_generator_adversarial.item(),
                                                       loss_discriminator_real.item(),
                                                       loss_discriminator_generated.item())
                    print(message)

            # Evaluate the model
            with torch.no_grad():
                self.eval()

            # Save the trainer state (unconditionally; need_saving is not
            # consulted here)
            self.save()

            # Increment epoch counter
            self.epoch += 1
            self.generator_scheduler.step()
            self.discriminator_scheduler.step()

    def eval(self):
        """
        Computes the time-domain and frequency-domain L2 losses of the
        generator on the validation batches, stores and prints their means,
        then calls check_improvement(). Does not disable gradient tracking
        itself; train() wraps the call in torch.no_grad().
        :return: None
        """
        self.generator.eval()
        self.discriminator.eval()
        batch_losses = {'time_l2': [], 'freq_l2': []}
        for i in range(self.valid_batches_per_epoch):
            # Transfer to GPU
            local_batch = next(self.valid_loader_iter)
            input_batch, target_batch = local_batch[0].to(
                self.device), local_batch[1].to(self.device)

            generated_batch = self.generator(input_batch)

            # Get the spectrogram
            specgram_target_batch = self.spectrogram(target_batch)
            specgram_generated_batch = self.spectrogram(generated_batch)

            loss_generator_time = self.generator_time_criterion(
                generated_batch, target_batch)
            batch_losses['time_l2'].append(loss_generator_time.item())
            loss_generator_frequency = self.generator_frequency_criterion(
                specgram_generated_batch, specgram_target_batch)
            batch_losses['freq_l2'].append(loss_generator_frequency.item())

        # Store the validation losses
        mean_time_loss = np.mean(batch_losses['time_l2'])
        mean_frequency_loss = np.mean(batch_losses['freq_l2'])
        self.valid_losses['time_l2'].append(mean_time_loss)
        self.valid_losses['freq_l2'].append(mean_frequency_loss)

        # Display validation losses
        message = 'Epoch {}: \n' \
                  '\t Time: {} \n' \
                  '\t Frequency: {} \n'.format(self.epoch,
                                               mean_time_loss,
                                               mean_frequency_loss)
        print(message)

        # Check if the loss is decreasing
        self.check_improvement()

    def save(self):
        """
        Saves the model(s), optimizer(s), scheduler(s) and losses
        to self.savepath.
        :return: None
        """
        torch.save(
            {
                'epoch':
                self.epoch,
                'generator_state_dict':
                self.generator.state_dict(),
                'discriminator_state_dict':
                self.discriminator.state_dict(),
                'generator_optimizer_state_dict':
                self.generator_optimizer.state_dict(),
                'discriminator_optimizer_state_dict':
                self.discriminator_optimizer.state_dict(),
                'generator_scheduler_state_dict':
                self.generator_scheduler.state_dict(),
                'discriminator_scheduler_state_dict':
                self.discriminator_scheduler.state_dict(),
                'train_losses':
                self.train_losses,
                'test_losses':
                self.test_losses,
                'valid_losses':
                self.valid_losses
            }, self.savepath)

    def load(self):
        """
        Loads the model(s), optimizer(s), scheduler(s) and losses
        from self.loadpath.
        :return: None
        """
        checkpoint = torch.load(self.loadpath, map_location=self.device)
        self.epoch = checkpoint['epoch']
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        self.discriminator.load_state_dict(
            checkpoint['discriminator_state_dict'])
        self.generator_optimizer.load_state_dict(
            checkpoint['generator_optimizer_state_dict'])
        self.discriminator_optimizer.load_state_dict(
            checkpoint['discriminator_optimizer_state_dict'])
        self.generator_scheduler.load_state_dict(
            checkpoint['generator_scheduler_state_dict'])
        self.discriminator_scheduler.load_state_dict(
            checkpoint['discriminator_scheduler_state_dict'])
        self.train_losses = checkpoint['train_losses']
        self.test_losses = checkpoint['test_losses']
        self.valid_losses = checkpoint['valid_losses']

    def evaluate_metrics(self, n_batches):
        """
        Evaluates the quality of the reconstruction with the SNR and LSD metrics on a specified number of batches
        :param: n_batches: number of batches to process
        :return: mean and std of the SNR, then mean and std of the LSD
            (infinite values are ignored).
        """
        with torch.no_grad():
            snrs = []
            lsds = []
            generator = self.generator.eval()
            for k in range(n_batches):
                local_batch = next(self.test_loader_iter)
                # Transfer to GPU
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)

                # Generates a batch
                generated_batch = generator(input_batch)

                # Get the metrics
                snrs.append(
                    snr(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))
                lsds.append(
                    lsd(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))

            snrs = torch.cat(snrs).cpu().numpy()
            lsds = torch.cat(lsds).cpu().numpy()

            # Some signals corresponding to silence will be all zeroes and cause troubles due to the logarithm
            snrs[np.isinf(snrs)] = np.nan
            lsds[np.isinf(lsds)] = np.nan
        return np.nanmean(snrs), np.nanstd(snrs), np.nanmean(lsds), np.nanstd(
            lsds)
Ejemplo n.º 7
0
class Model(object):
    """Bundle of a generator, a discriminator and their losses on one GPU.

    Holds the networks built from ``opt``, an optional pre-trained regressor
    used for a reconstruction loss on the generator output, and helpers to run
    the forward pass, back-propagate the generator/discriminator losses, and
    save/load checkpoints. Optimizer steps are left to the caller.
    """

    def __init__(self, opt):
        """Build networks and losses and place them on ``opt.gpu_id``.

        Args:
            opt: options object; this constructor reads ``gpu_id``,
                ``mse_weight``, ``experiments_dir``, ``experiment_name``,
                ``in_channels`` and ``input_idx`` (a comma-separated string),
                and forwards ``opt`` to ``Generator``, ``Discriminator`` and
                ``GANLoss``.
        """
        super(Model, self).__init__()

        # Generator
        self.gen = Generator(opt).cuda(opt.gpu_id)

        # NOTE: parameters() returns a one-shot generator.
        self.gen_params = self.gen.parameters()

        num_params = 0
        for p in self.gen.parameters():
            num_params += p.numel()
        print(self.gen)
        print(num_params)

        # Discriminator
        self.dis = Discriminator(opt).cuda(opt.gpu_id)

        self.dis_params = self.dis.parameters()

        num_params = 0
        for p in self.dis.parameters():
            num_params += p.numel()
        print(self.dis)
        print(num_params)

        # Regressor (only needed when the reconstruction loss has a non-zero
        # weight).
        # NOTE(review): torch.load unpickles a whole model object here; the
        # file must come from a trusted source.
        if opt.mse_weight:
            self.reg = torch.load('data/utils/classifier.pth').cuda(
                opt.gpu_id).eval()
        else:
            self.reg = None

        # Losses
        self.criterion_gan = GANLoss(opt, self.dis)
        # Despite the name, this is an L1 loss scaled by opt.mse_weight.
        self.criterion_mse = lambda x, y: l1_loss(x, y) * opt.mse_weight

        # Loss placeholders live on the same GPU as the networks. A bare
        # .cuda() would put them on the default device, and
        # `self.loss_mse + loss_adv` in backward_G would then mix devices
        # whenever gpu_id != 0 and no regressor is used.
        self.loss_mse = Variable(torch.zeros(1).cuda(opt.gpu_id))
        self.loss_adv = Variable(torch.zeros(1).cuda(opt.gpu_id))
        self.loss = Variable(torch.zeros(1).cuda(opt.gpu_id))

        self.path = opt.experiments_dir + opt.experiment_name + '/checkpoints/'
        self.gpu_id = opt.gpu_id
        # Channels of the generator input that are filled with noise: total
        # input channels minus the number of indices listed in opt.input_idx.
        self.noise_channels = opt.in_channels - len(opt.input_idx.split(','))

    def forward(self, inputs):
        """Move a batch to the GPU and generate ``self.fake`` from it.

        Args:
            inputs: a 3-item sequence ``(input, input_orig, target)`` of
                tensors. The generator is fed ``input`` concatenated with
                Gaussian noise along dimension 1.
        """
        net_input, input_orig, target = inputs

        self.input = Variable(net_input.cuda(self.gpu_id))
        self.input_orig = Variable(input_orig.cuda(self.gpu_id))
        self.target = Variable(target.cuda(self.gpu_id))

        noise = Variable(
            torch.randn(self.input.size(0),
                        self.noise_channels).cuda(self.gpu_id))

        self.fake = self.gen(torch.cat([self.input, noise], 1))

    def backward_G(self):
        """Back-propagate the generator loss (regressor loss + GAN loss).

        Requires ``forward`` to have been called first. When no regressor is
        loaded, ``self.loss_mse`` keeps its initial zero value.
        """
        # Regressor loss
        if self.reg is not None:

            fake_input = self.reg(self.fake)

            self.loss_mse = self.criterion_mse(fake_input, self.input_orig)

        # GAN loss
        loss_adv, _ = self.criterion_gan(self.fake)

        loss_G = self.loss_mse + loss_adv
        loss_G.backward()

    def backward_D(self):
        """Back-propagate the discriminator loss on (target, fake).

        The second value returned by the GAN criterion is kept in
        ``self.loss_adv``.
        """
        loss_adv, self.loss_adv = self.criterion_gan(self.target, self.fake)

        loss_D = loss_adv
        loss_D.backward()

    def train(self):
        """Put generator and discriminator in training mode."""
        self.gen.train()
        self.dis.train()

    def eval(self):
        """Put generator and discriminator in evaluation mode."""
        self.gen.eval()
        self.dis.eval()

    def save_checkpoint(self, epoch):
        """Save both state dicts to ``<self.path><epoch>.pkl``.

        Args:
            epoch: integer epoch number; stored in the file and used as its
                name.
        """
        torch.save(
            {
                'epoch': epoch,
                'gen_state_dict': self.gen.state_dict(),
                'dis_state_dict': self.dis.state_dict()
            }, self.path + '%d.pkl' % epoch)

    def load_checkpoint(self, path, pretrained=True):
        """Restore generator and discriminator weights from ``path``.

        Args:
            path: checkpoint file written by ``save_checkpoint``.
            pretrained: accepted for interface compatibility; not used.
        """
        weights = torch.load(path)

        self.gen.load_state_dict(weights['gen_state_dict'])
        self.dis.load_state_dict(weights['dis_state_dict'])
Ejemplo n.º 8
0
class Solver(object):
    """Training driver for a generator/critic pair with a gradient penalty.

    Builds ``Generator`` and ``Discriminator`` with Adam optimizers, optionally
    restores pre-trained weights and a logger, and runs the training loop in
    ``train``: the discriminator is updated on every batch (critic loss, then
    gradient penalty) and the generator every ``D_train_step`` batches.
    """

    def __init__(self, data_loader, config):
        """Copy hyper-parameters from ``config`` and build the models.

        Args:
            data_loader: iterable of image batches; must support ``len()``.
            config: options object providing the attributes read below.

        Raises:
            ValueError: if ``config.pretrained_model`` is given but does not
                contain exactly two entries (G first, D second).
        """
        self.data_loader = data_loader

        self.noise_n = config.noise_n
        self.G_last_act = last_act(config.G_last_act)
        self.D_out_n = config.D_out_n
        self.D_last_act = last_act(config.D_last_act)

        self.G_lr = config.G_lr
        self.D_lr = config.D_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.epoch = config.epoch
        self.batch_size = config.batch_size
        self.D_train_step = config.D_train_step
        self.save_image_step = config.save_image_step
        self.log_step = config.log_step
        self.model_save_step = config.model_save_step
        self.clip_value = config.clip_value
        self.lambda_gp = config.lambda_gp

        self.model_save_path = config.model_save_path
        self.log_save_path = config.log_save_path
        self.image_save_path = config.image_save_path

        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model
        self.build_model()

        # Same truthiness test as in train(), so the logger is created exactly
        # when it is going to be used.
        if self.use_tensorboard:
            self.build_tensorboard()
        if self.pretrained_model is not None:
            if len(self.pretrained_model) != 2:
                # Raising a bare string is itself a TypeError in Python 3.
                raise ValueError(
                    "must have both G and D pretrained parameters, and G is first, D is second"
                )
            self.load_pretrained_model()

    def build_model(self):
        """Create G, D and their Adam optimizers; move models to GPU if any."""
        self.G = Generator(self.noise_n, self.G_last_act)
        self.D = Discriminator(self.D_out_n, self.D_last_act)

        self.G_optimizer = torch.optim.Adam(self.G.parameters(), self.G_lr,
                                            [self.beta1, self.beta2])
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), self.D_lr,
                                            [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def build_tensorboard(self):
        """Create the scalar logger writing to ``self.log_save_path``."""
        # Imported lazily so the dependency is only needed when logging is on.
        from commons.logger import Logger
        self.logger = Logger(self.log_save_path)

    def load_pretrained_model(self):
        """Load G then D state dicts from ``self.pretrained_model``."""
        self.G.load_state_dict(torch.load(self.pretrained_model[0]))
        self.D.load_state_dict(torch.load(self.pretrained_model[1]))

    def denorm(self, x):
        """Map values from [-1, 1] to [0, 1], clamping anything outside."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def reset_grad(self):
        """Clear the gradients tracked by both optimizers."""
        self.G_optimizer.zero_grad()
        self.D_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        """Wrap ``x`` in a ``Variable``, moving it to the GPU when available.

        NOTE(review): ``volatile`` is a legacy autograd flag; on current
        PyTorch it is presumably ignored -- confirm before relying on it.
        """
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def train(self):
        """Run the full training loop.

        Per batch: one critic step on the loss ``-mean(D(real)) +
        mean(D(fake))``, one critic step on the gradient penalty, and, every
        ``D_train_step`` batches, one generator step on ``-mean(D(G(z)))``.
        Losses are printed every ``log_step`` batches and sent to the logger
        when enabled; sample images and model weights are written every
        ``save_image_step`` / ``model_save_step`` epochs.
        """
        print(len(self.data_loader))
        for e in range(self.epoch):
            for i, batch_images in enumerate(self.data_loader):
                batch_size = batch_images.size(0)
                real_x = self.to_var(batch_images)
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(batch_size, self.noise_n)))
                # train D
                fake_x = self.G(noise_x)
                real_out = self.D(real_x)
                # detach(): the critic step must not back-propagate into G.
                fake_out = self.D(fake_x.detach())

                D_real = -torch.mean(real_out)
                D_fake = torch.mean(fake_out)
                D_loss = D_real + D_fake

                self.reset_grad()
                D_loss.backward()
                self.D_optimizer.step()
                # Log. .item() replaces the legacy `.data[0]`, which fails on
                # 0-dim tensors in PyTorch >= 0.4.
                loss = {}
                loss['D/loss_real'] = D_real.item()
                loss['D/loss_fake'] = D_fake.item()
                loss['D/loss'] = D_loss.item()

                # Gradient penalty (WGAN-GP). The alternative, clipping the
                # weights of D to [-clip_value, clip_value], is not used.
                # alpha is created on whatever device the batch lives on so
                # that CPU-only runs work too.
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).to(real_x.device).expand_as(real_x)
                interpolated = Variable(alpha * real_x.data +
                                        (1 - alpha) * fake_x.data,
                                        requires_grad=True)
                gp_out = self.D(interpolated)
                grad = torch.autograd.grad(outputs=gp_out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones_like(
                                               gp_out),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                # Penalize the per-sample gradient norm for deviating from 1.
                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)
                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.D_optimizer.step()
                # Train G
                if (i + 1) % self.D_train_step == 0:
                    fake_out = self.D(self.G(noise_x))
                    G_loss = -torch.mean(fake_out)
                    self.reset_grad()
                    G_loss.backward()
                    self.G_optimizer.step()
                    loss['G/loss'] = G_loss.item()
                # Print log
                if (i + 1) % self.log_step == 0:
                    log = "Epoch: {}/{}, Iter: {}/{}".format(
                        e + 1, self.epoch, i + 1, len(self.data_loader))
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)
                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(
                            tag, value,
                            e * len(self.data_loader) + i + 1)
            # Save images
            if (e + 1) % self.save_image_step == 0:
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(16, self.noise_n)))
                fake_image = self.G(noise_x)
                save_image(
                    self.denorm(fake_image.data),
                    os.path.join(self.image_save_path,
                                 "{}_fake.png".format(e + 1)))
            # Save model weights
            if (e + 1) % self.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_G.pth".format(e + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_D.pth".format(e + 1)))
Ejemplo n.º 9
0
class Trainer:
    """Train a CycleGAN-style model between two unpaired image domains X and Y.

    ``G_X`` turns Y images into X-like images and ``G_Y`` turns X images into
    Y-like images; ``D_X`` and ``D_Y`` score images of their own domain.
    Adversarial terms are least-squares losses: discriminator scores are
    pushed towards 1 for real images and towards 0 for generated ones.
    """

    # Checkpoint file names under ``<output_path>models/``; each entry is also
    # the name of the attribute that holds the corresponding network.
    _MODEL_NAMES = ("G_X", "G_Y", "D_X", "D_Y")

    def __init__(self, config):
        """Build models, data loaders and optimizers from ``config``.

        ``config`` must expose ``batchSize``, ``epochs``, ``cycleLoss``,
        ``cycleMultiplier``, ``identityLoss``, ``identityMultiplier``,
        ``loadModels``, ``dataX``, ``dataY`` and ``lr``.
        """
        self.batch_size = config.batchSize
        self.epochs = config.epochs

        self.use_cycle_loss = config.cycleLoss
        self.cycle_multiplier = config.cycleMultiplier

        self.use_identity_loss = config.identityLoss
        self.identity_multiplier = config.identityMultiplier

        self.load_models = config.loadModels
        self.data_x_loc = config.dataX
        self.data_y_loc = config.dataY

        # These must exist *before* init_models() / init_data_loaders() run:
        # the former reads output_path, the latter reads the image size.
        self.output_path = "./outputs/"
        self.img_width = 256
        self.img_height = 256

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.init_models()
        self.init_data_loaders()

        # One optimizer for both generators, one for both discriminators.
        self.g_optimizer = torch.optim.Adam(
            list(self.G_X.parameters()) + list(self.G_Y.parameters()),
            lr=config.lr)
        self.d_optimizer = torch.optim.Adam(
            list(self.D_X.parameters()) + list(self.D_Y.parameters()),
            lr=config.lr)
        # Only the generator optimizer is scheduled (x0.95 per step() call).
        self.scheduler_g = torch.optim.lr_scheduler.StepLR(self.g_optimizer,
                                                           step_size=1,
                                                           gamma=0.95)

    # Load/Construct the models
    def init_models(self):
        """Create the four networks, then restore or initialise their weights.

        With ``load_models`` set, state dicts are read from
        ``<output_path>models/<name>``; otherwise ``init_func`` is applied to
        every network. All networks end up on ``self.device``.
        """
        self.G_X = Generator(3, 3, nn.InstanceNorm2d)
        self.D_X = Discriminator(3)
        self.G_Y = Generator(3, 3, nn.InstanceNorm2d)
        self.D_Y = Discriminator(3)

        for name in self._MODEL_NAMES:
            network = getattr(self, name)
            if self.load_models:
                # Load onto the CPU first; .to(self.device) below moves it.
                network.load_state_dict(
                    torch.load(self.output_path + "models/" + name,
                               map_location='cpu'))
            else:
                network.apply(init_func)
            network.to(self.device)

    # Initialize data loaders and image transformer
    def init_data_loaders(self):
        """Create shuffled ``ImageFolder`` loaders for both domains.

        Images are resized to ``img_height`` x ``img_width`` and normalised
        with mean 0.5 / std 0.5 per channel.
        """
        transform = transforms.Compose([
            # torchvision expects the size as (height, width).
            transforms.Resize((self.img_height, self.img_width)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        X_folder = torchvision.datasets.ImageFolder(self.data_x_loc, transform)
        self.X_loader = torch.utils.data.DataLoader(X_folder,
                                                    batch_size=self.batch_size,
                                                    shuffle=True)

        Y_folder = torchvision.datasets.ImageFolder(self.data_y_loc, transform)
        self.Y_loader = torch.utils.data.DataLoader(Y_folder,
                                                    batch_size=self.batch_size,
                                                    shuffle=True)

    def save_models(self):
        """Write every network's state dict to ``<output_path>models/<name>``."""
        for name in self._MODEL_NAMES:
            torch.save(getattr(self, name).state_dict(),
                       self.output_path + "models/" + name)

    # Reset gradients for all models, needed for between every training
    def reset_gradients(self):
        """Zero the gradients held by both optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    # Sample image from training data every %x epoch and save them for judging
    def save_samples(self, epoch):
        """Save one original/translated image pair per domain for ``epoch``.

        The first image of a freshly drawn batch from each loader is written
        next to its translation under ``<output_path>samples/``.
        """
        x_iter = iter(self.X_loader)
        y_iter = iter(self.Y_loader)

        img_data_x, _ = next(x_iter)
        img_data_y, _ = next(y_iter)

        sample_x = img_data_x[0]
        sample_y = img_data_y[0]

        # Pure inference: no autograd graph is needed for the snapshots.
        with torch.no_grad():
            generated_y = self.G_Y(
                sample_x.unsqueeze(0).to(self.device))[0].cpu()
            generated_x = self.G_X(
                sample_y.unsqueeze(0).to(self.device))[0].cpu()

        def prepare_image(img):
            # (C, H, W) -> (H, W, C), then undo the mean 0.5 / std 0.5
            # normalisation. Clipping keeps imsave from choking on generator
            # outputs that fall slightly outside the valid [0, 1] range.
            img = img.numpy().transpose((1, 2, 0))
            return np.clip(img / 2 + 0.5, 0.0, 1.0)

        sample_dir = self.output_path + 'samples/'
        plt.imsave(sample_dir + 'original_X_{}.png'.format(epoch),
                   prepare_image(sample_x))
        plt.imsave(sample_dir + 'original_Y_{}.png'.format(epoch),
                   prepare_image(sample_y))

        plt.imsave(sample_dir + 'generated_X_{}.png'.format(epoch),
                   prepare_image(generated_x))
        plt.imsave(sample_dir + 'generated_Y_{}.png'.format(epoch),
                   prepare_image(generated_y))

    def _discriminator_step(self, critic, images, target):
        """Run one least-squares discriminator update; return the loss value.

        ``target`` is the score the critic should output for ``images``
        (1 for real images, 0 for generated ones).
        """
        self.reset_gradients()
        loss = torch.mean((critic(images) - target)**2)
        loss.backward()
        self.d_optimizer.step()
        return loss.cpu().detach().numpy()

    def _generator_step(self, translate, critic, reconstruct, real):
        """Run one generator update for a single translation direction.

        The loss is the adversarial term plus, when enabled, the weighted
        cycle-reconstruction error. Returns ``[adversarial, cycle]`` loss
        values; the cycle entry is 0 when the cycle loss is disabled.
        """
        self.reset_gradients()

        fake = translate(real)
        adversarial = torch.mean((critic(fake) - 1)**2)
        if self.use_cycle_loss:
            cycle = self.cycle_multiplier * torch.mean(
                (real - reconstruct(fake))**2)
        else:
            # Keep the bookkeeping and the sum below valid without the term.
            cycle = torch.zeros_like(adversarial)

        (adversarial + cycle).backward()
        self.g_optimizer.step()
        return [
            adversarial.cpu().detach().numpy(),
            cycle.cpu().detach().numpy()
        ]

    def _identity_step(self, generator, real):
        """Push ``generator`` towards the identity on its own target domain."""
        self.reset_gradients()
        loss = self.identity_multiplier * torch.mean(
            (real - generator(real))**2)
        loss.backward()
        self.g_optimizer.step()

    # Training loop
    def train(self):
        """Run the full training schedule.

        Per epoch: optionally dump sample images (every 5th epoch), decay the
        generator learning rate once past epoch 100, train on paired batches
        from both loaders (``zip`` stops at the shorter loader), then save all
        models. Per-batch losses are written to ``<output_path>losses/`` once
        training finishes.
        """
        D_X_losses = []
        D_Y_losses = []

        G_X_losses = []
        G_Y_losses = []

        for epoch in range(self.epochs):
            print("======")
            print("Epoch {}!".format(epoch + 1))

            # Track progress
            if epoch % 5 == 0:
                self.save_samples(epoch)

            # Paper reduces lr after 100 epochs
            if epoch > 100:
                self.scheduler_g.step()

            for (data_X, _), (data_Y, _) in zip(self.X_loader, self.Y_loader):
                data_X = data_X.to(self.device)
                data_Y = data_Y.to(self.device)

                # =====================================
                # Train Discriminators
                # =====================================
                # Fakes are detached: only the discriminators are updated
                # here, so backpropagating into the generators is wasted work.
                d_x_f_loss = self._discriminator_step(
                    self.D_X, self.G_X(data_Y).detach(), 0)
                d_y_f_loss = self._discriminator_step(
                    self.D_Y, self.G_Y(data_X).detach(), 0)
                d_x_t_loss = self._discriminator_step(self.D_X, data_X, 1)
                d_y_t_loss = self._discriminator_step(self.D_Y, data_Y, 1)

                D_X_losses.append([d_x_t_loss, d_x_f_loss])
                D_Y_losses.append([d_y_t_loss, d_y_f_loss])

                # =====================================
                # Train GENERATORS
                # =====================================

                # Cycle X -> Y -> X
                G_Y_losses.append(
                    self._generator_step(self.G_Y, self.D_Y, self.G_X,
                                         data_X))

                # Cycle Y -> X -> Y
                G_X_losses.append(
                    self._generator_step(self.G_X, self.D_X, self.G_Y,
                                         data_Y))

                # =====================================
                # Train image IDENTITY
                # =====================================

                if self.use_identity_loss:
                    # X should be same after G_X(X)
                    self._identity_step(self.G_X, data_X)
                    # Y should be same after G_Y(Y)
                    self._identity_step(self.G_Y, data_Y)

            # Epoch done, save models
            self.save_models()

        # Save losses for analysis
        np.save(self.output_path + 'losses/G_X_losses.npy',
                np.array(G_X_losses))
        np.save(self.output_path + 'losses/G_Y_losses.npy',
                np.array(G_Y_losses))
        np.save(self.output_path + 'losses/D_X_losses.npy',
                np.array(D_X_losses))
        np.save(self.output_path + 'losses/D_Y_losses.npy',
                np.array(D_Y_losses))
Ejemplo n.º 10
0
class Seq2SeqCycleGAN:
    """Cycle-consistent GAN over token sequences.

    Two generators translate between domain A and domain B (the evaluation
    printout labels A as "Shakespeare" and B as "Modern"), and in training
    mode two discriminators judge each domain. Sequences are handled as 2-D
    tensors with one vocabulary-sized row per token. All modules share a
    single embedding layer and are placed on the GPU.
    """

    def __init__(self,
                 model_config,
                 train_config,
                 vocab,
                 max_len,
                 mode='train'):
        """Build the networks and, in training mode, losses and optimizers.

        Args:
            model_config: mapping; ``'embedding_size'`` is read here and the
                whole mapping is forwarded to the sub-networks.
            train_config: mapping; ``'continue_train'``, ``'which_epoch'``,
                ``'base_lr'`` and ``'batch_size'`` are read here.
            vocab: vocabulary object exposing ``num_words``.
            max_len: forwarded to the generators.
            mode: ``'train'`` builds discriminators/optimizers; any other
                value loads the saved generators and switches to eval mode.
        """
        self.mode = mode

        self.model_config = model_config
        self.train_config = train_config

        self.vocab = vocab
        self.vocab_size = self.vocab.num_words
        self.max_len = max_len

        # A Linear layer over vocabulary-sized vectors stands in for
        # nn.Embedding -- presumably so that the generators' softmax outputs
        # (dense rows, not indices) can be embedded as well; confirm.
        self.embedding_layer = nn.Sequential(
            nn.Linear(self.vocab_size, self.model_config['embedding_size']),
            nn.Sigmoid())

        self.G_AtoB = Generator(self.embedding_layer,
                                self.model_config,
                                self.train_config,
                                self.vocab_size,
                                self.max_len,
                                mode=self.mode).cuda()
        self.G_BtoA = Generator(self.embedding_layer,
                                self.model_config,
                                self.train_config,
                                self.vocab_size,
                                self.max_len,
                                mode=self.mode).cuda()

        if self.mode == 'train':
            self.D_B = Discriminator(self.embedding_layer, self.model_config,
                                     self.train_config).cuda()
            self.D_A = Discriminator(self.embedding_layer, self.model_config,
                                     self.train_config).cuda()

            if self.train_config['continue_train']:
                self._load_weights(
                    ('embedding_layer', 'G_AtoB', 'G_BtoA', 'D_B', 'D_A'))

            self.embedding_layer.train()
            self.G_AtoB.train()
            self.G_BtoA.train()
            self.D_B.train()
            self.D_A.train()

            self.criterionBCE = nn.BCELoss().cuda()
            self.criterionCE = nn.CrossEntropyLoss().cuda()

            # The shared embedding is optimised by both optimizers.
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.embedding_layer.parameters(),
                                self.G_AtoB.parameters(),
                                self.G_BtoA.parameters()),
                lr=train_config['base_lr'],
                betas=(0.9, 0.999))
            self.optimizer_D = torch.optim.Adam(
                itertools.chain(self.embedding_layer.parameters(),
                                self.D_A.parameters(),
                                self.D_B.parameters()),
                lr=train_config['base_lr'],
                betas=(0.9, 0.999))

            # Fixed-size BCE targets for "real" and "fake" predictions.
            self.real_label = torch.ones(
                (train_config['batch_size'], 1)).cuda()
            self.fake_label = torch.zeros(
                (train_config['batch_size'], 1)).cuda()
        else:
            self._load_weights(('embedding_layer', 'G_AtoB', 'G_BtoA'))

            self.embedding_layer.eval()
            self.G_AtoB.eval()
            self.G_BtoA.eval()

    def _load_weights(self, names):
        """Load ``<which_epoch>_<name>.pth`` into ``self.<name>`` per name."""
        prefix = self.train_config['which_epoch']
        for name in names:
            getattr(self, name).load_state_dict(
                torch.load(prefix + '_' + name + '.pth'))

    def backward_D_basic(self, netD, real, real_addn_feats, fake,
                         fake_addn_feats):
        """Backpropagate one discriminator's loss and clip its gradients.

        The loss is the mean of the BCE on the real sample (target 1) and on
        the detached fake sample (target 0). The discriminator's hidden state
        is re-initialised before each of the two passes.

        Returns:
            The (unscaled) discriminator loss tensor.
        """
        netD.hidden = netD.init_hidden()
        pred_real = netD(real, real_addn_feats)
        loss_D_real = self.criterionBCE(pred_real, self.real_label)

        netD.hidden = netD.init_hidden()
        # detach(): the generators must not receive gradients from this loss.
        pred_fake = netD(fake.detach(), fake_addn_feats)
        loss_D_fake = self.criterionBCE(pred_fake, self.fake_label)

        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()

        self.clip_gradient(self.embedding_layer)
        self.clip_gradient(netD)

        return loss_D

    def backward_D_A(self):
        """Backpropagate D_A's loss and store it (x10) in ``loss_D_A``."""
        # The factor 10 is applied after backward(), so it only scales the
        # stored value, not the gradients.
        self.loss_D_A = self.backward_D_basic(
            self.D_A, self.real_A, self.real_A_addn_feats, self.fake_A,
            self.fake_A_addn_feats) * 10

    def backward_D_B(self):
        """Backpropagate D_B's loss and store it (x10) in ``loss_D_B``."""
        self.loss_D_B = self.backward_D_basic(
            self.D_B, self.real_B, self.real_B_addn_feats, self.fake_B,
            self.fake_B_addn_feats) * 10

    def backward_G(self):
        """Compute the combined generator loss and backpropagate it.

        The total is the sum of two adversarial terms (each x10), two
        cycle-reconstruction cross-entropies and two identity
        cross-entropies. Before every cross-entropy the prediction and the
        label sequence are brought to the same length with
        ``update_label_sizes``. Gradients of the embedding layer and of both
        generators are clipped afterwards.
        """
        # Adversarial terms: the generators want the fakes to be judged real.
        self.D_B.hidden = self.D_B.init_hidden()
        self.fake_B_addn_feats = get_addn_feats(self.fake_B, self.vocab).cuda()
        self.loss_G_AtoB = self.criterionBCE(
            self.D_B(self.fake_B, self.fake_B_addn_feats),
            self.real_label) * 10

        self.D_A.hidden = self.D_A.init_hidden()
        self.fake_A_addn_feats = get_addn_feats(self.fake_A, self.vocab).cuda()
        self.loss_G_BtoA = self.criterionBCE(
            self.D_A(self.fake_A, self.fake_A_addn_feats),
            self.real_label) * 10

        # Cycle consistency: A -> B -> A and B -> A -> B.
        if self.rec_A.size(0) != self.real_A_label.size(0):
            self.real_A, self.rec_A, self.real_A_label = self.update_label_sizes(
                self.real_A, self.rec_A, self.real_A_label)
        self.loss_cycle_A = self.criterionCE(self.rec_A,
                                             self.real_A_label)  #* lambda_A

        if self.rec_B.size(0) != self.real_B_label.size(0):
            self.real_B, self.rec_B, self.real_B_label = self.update_label_sizes(
                self.real_B, self.rec_B, self.real_B_label)
        self.loss_cycle_B = self.criterionCE(self.rec_B,
                                             self.real_B_label)  #* lambda_B

        # Identity: feeding a generator a sample of its own target domain
        # should reproduce that sample.
        self.idt_B = self.G_AtoB(self.real_B)
        if self.idt_B.size(0) != self.real_B_label.size(0):
            self.real_B, self.idt_B, self.real_B_label = self.update_label_sizes(
                self.real_B, self.idt_B, self.real_B_label)
        self.loss_idt_B = self.criterionCE(
            self.idt_B, self.real_B_label)  #* lambda_B * lambda_idt

        self.idt_A = self.G_BtoA(self.real_A)
        if self.idt_A.size(0) != self.real_A_label.size(0):
            self.real_A, self.idt_A, self.real_A_label = self.update_label_sizes(
                self.real_A, self.idt_A, self.real_A_label)
        self.loss_idt_A = self.criterionCE(
            self.idt_A, self.real_A_label)  #* lambda_A * lambda_idt

        self.loss_G = (self.loss_G_AtoB + self.loss_G_BtoA +
                       self.loss_cycle_A + self.loss_cycle_B +
                       self.loss_idt_A + self.loss_idt_B)
        self.loss_G.backward()

        self.clip_gradient(self.embedding_layer)
        self.clip_gradient(self.G_AtoB)
        self.clip_gradient(self.G_BtoA)

    def forward(self, real_A, real_A_addn_feats, real_B, real_B_addn_feats):
        """Translate one sample of each domain.

        Stores the inputs, their per-row argmax labels and the softmaxed
        translations on ``self``. In training mode the reconstructions
        ``rec_A`` / ``rec_B`` are computed as well; otherwise the inputs and
        translations are decoded with ``idx_to_sent`` and printed.
        """
        self.real_A = real_A
        self.real_A_addn_feats = real_A_addn_feats
        self.real_A_label = self.real_A.max(dim=1)[1]

        self.real_B = real_B
        self.real_B_addn_feats = real_B_addn_feats
        self.real_B_label = self.real_B.max(dim=1)[1]

        self.fake_B = F.softmax(self.G_AtoB(self.real_A), dim=1)
        self.fake_A = F.softmax(self.G_BtoA(self.real_B), dim=1)

        if self.mode == 'train':
            self.rec_A = self.G_BtoA(self.fake_B)
            self.rec_B = self.G_AtoB(self.fake_A)

        else:
            real_A_list = self.real_A_label.tolist()
            real_B_list = self.real_B_label.tolist()

            fake_B_list = self.fake_B.max(dim=1)[1].tolist()
            fake_A_list = self.fake_A.max(dim=1)[1].tolist()

            print('Input (Shakespeare):', idx_to_sent(real_A_list, self.vocab))
            print('Output (Modern):', idx_to_sent(fake_B_list, self.vocab))
            print('\n')
            print('Input (Modern):', idx_to_sent(real_B_list, self.vocab))
            print('Output (Shakespeare):', idx_to_sent(fake_A_list,
                                                       self.vocab))
            print('\n')

    def optimize_parameters(self):
        """Run one generator update followed by one discriminator update."""
        # Freeze the discriminators while the generators are optimised.
        self.set_requires_grad([self.D_A, self.D_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        self.set_requires_grad([self.D_A, self.D_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_B()
        self.backward_D_A()
        self.optimizer_D.step()

    def update_label_sizes(self, real, rec, real_label):
        """Pad ``rec`` or ``real_label`` so both have the same number of rows.

        A longer prediction gets extra labels of class 0; a shorter
        prediction gets extra rows that are one-hot on class 0 (presumably
        the padding token -- confirm against the vocabulary).

        The current length of ``real_label`` is the reference, not
        ``real.size(0)``: the labels may already have been padded by an
        earlier call, in which case measuring against ``real`` would leave
        the two lengths different. Padding is created on the device of the
        tensor it is appended to.

        Returns:
            ``(real, rec, real_label)`` with ``real`` passed through as is.
        """
        label_len = real_label.size(0)
        rec_len = rec.size(0)

        if rec_len > label_len:
            extra_labels = torch.zeros(rec_len - label_len,
                                       dtype=real_label.dtype,
                                       device=real_label.device)
            real_label = torch.cat((real_label, extra_labels), 0)
        elif rec_len < label_len:
            extra_rows = torch.zeros((label_len - rec_len, self.vocab_size),
                                     dtype=rec.dtype,
                                     device=rec.device)
            extra_rows[:, 0] = 1
            rec = torch.cat((rec, extra_rows), 0)

        return real, rec, real_label

    def indices_to_one_hot(self, idx_tensor):
        """Return a CPU float tensor with one one-hot row per index.

        Args:
            idx_tensor: tensor holding one class index per entry along dim 0.

        Returns:
            Tensor of shape ``(idx_tensor.size(0), vocab_size)``.
        """
        num_rows = idx_tensor.size(0)
        one_hot_tensor = torch.zeros((num_rows, self.vocab_size))
        if num_rows:
            columns = idx_tensor.detach().reshape(num_rows).to(
                device='cpu', dtype=torch.long)
            # One vectorised write instead of a Python loop over rows.
            one_hot_tensor[torch.arange(num_rows), columns] = 1.0

        return one_hot_tensor

    def set_requires_grad(self, nets, requires_grad=False):
        """Set ``requires_grad`` on every parameter of the given network(s).

        ``nets`` may be a single network or a list; ``None`` entries are
        skipped.
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def clip_gradient(self, model):
        """Clip the total gradient norm of ``model``'s parameters to 0.25."""
        nn.utils.clip_grad_norm_(model.parameters(), 0.25)
Ejemplo n.º 11
0
def train(config):
    """Train a conditional GAN (U-Net generator + discriminator).

    Every batch is resized to 256 pixels, the discriminator is trained on
    (input, target) versus (input, generated) pairs with a softplus loss, and
    the generator on the adversarial term plus ``config.lamb`` times the L1
    distance to the target. After each epoch the validation set is evaluated,
    loss graphs are saved, and a checkpoint is written every
    ``config.snapshot_interval`` epochs.

    Args:
        config: settings object. Read here: ``train_dir``, ``val_dir``,
            ``threads``, ``batchsize``, ``test_batchsize``, ``in_ch``,
            ``out_ch``, ``gpu_ids``, ``gen_init``, ``dis_init``, ``lr``,
            ``beta1``, ``cuda``, ``out_dir``, ``epoch``, ``minimax``,
            ``lamb`` and ``snapshot_interval``.
    """
    gpu_manage(config)

    train_dataset = Dataset(config.train_dir)
    val_dataset = Dataset(config.val_dir)
    training_data_loader = DataLoader(dataset=train_dataset,
                                      num_workers=config.threads,
                                      batch_size=config.batchsize,
                                      shuffle=True)
    val_data_loader = DataLoader(dataset=val_dataset,
                                 num_workers=config.threads,
                                 batch_size=config.test_batchsize,
                                 shuffle=False)

    gen = UNet(in_ch=config.in_ch, out_ch=config.out_ch, gpu_ids=config.gpu_ids)
    if config.gen_init is not None:
        param = torch.load(config.gen_init)
        gen.load_state_dict(param)
        print('load {} as pretrained model'.format(config.gen_init))

    dis = Discriminator(in_ch=config.in_ch, out_ch=config.out_ch, gpu_ids=config.gpu_ids)
    if config.dis_init is not None:
        param = torch.load(config.dis_init)
        dis.load_state_dict(param)
        print('load {} as pretrained model'.format(config.dis_init))

    opt_gen = optim.Adam(gen.parameters(), lr=config.lr, betas=(config.beta1, 0.999), weight_decay=0.00001)
    opt_dis = optim.Adam(dis.parameters(), lr=config.lr, betas=(config.beta1, 0.999), weight_decay=0.00001)

    criterionL1 = nn.L1Loss()
    criterionMSE = nn.MSELoss()
    criterionSoftplus = nn.Softplus()

    if config.cuda:
        gen = gen.cuda(0)
        dis = dis.cuda(0)
        criterionL1 = criterionL1.cuda(0)
        criterionMSE = criterionMSE.cuda(0)
        criterionSoftplus = criterionSoftplus.cuda(0)

    # Batches must live wherever the generator's weights live. Deriving the
    # device from CUDA availability alone breaks when a GPU is present but
    # config.cuda is off (model on CPU, data on GPU).
    device = next(gen.parameters()).device

    logreport = LogReport(log_dir=config.out_dir)
    testreport = TestReport(log_dir=config.out_dir)

    for epoch in range(1, config.epoch + 1):
        print('Epoch', epoch, datetime.now())
        for iteration, batch in enumerate(tqdm(training_data_loader)):
            real_a, real_b = batch[0], batch[1]
            real_a = F.interpolate(real_a, size=256).to(device)
            real_b = F.interpolate(real_b, size=256).to(device)
            fake_b = gen.forward(real_a)

            # Update D
            opt_dis.zero_grad()

            # The discriminator sees the input concatenated with an output
            # along the channel axis; the fake pair is detached so that this
            # loss does not reach the generator.
            fake_ab = torch.cat((real_a, fake_b), 1)
            pred_fake = dis.forward(fake_ab.detach())
            batchsize, _, w, h = pred_fake.size()

            real_ab = torch.cat((real_a, real_b), 1)
            pred_real = dis.forward(real_ab)

            # Softplus losses, normalised by batch size and the two trailing
            # (spatial) dimensions of the prediction.
            loss_d_fake = torch.sum(criterionSoftplus(pred_fake)) / batchsize / w / h
            loss_d_real = torch.sum(criterionSoftplus(-pred_real)) / batchsize / w / h
            loss_d = loss_d_fake + loss_d_real
            loss_d.backward()

            # The discriminator is only stepped in epochs that are a multiple
            # of config.minimax; its loss is still computed for logging.
            if epoch % config.minimax == 0:
                opt_dis.step()

            # Update G
            opt_gen.zero_grad()
            fake_ab = torch.cat((real_a, fake_b), 1)
            pred_fake = dis.forward(fake_ab)

            loss_g_gan = torch.sum(criterionSoftplus(-pred_fake)) / batchsize / w / h
            loss_g = loss_g_gan + criterionL1(fake_b, real_b) * config.lamb
            loss_g.backward()

            opt_gen.step()

            if iteration % 100 == 0:
                logreport({
                    'epoch': epoch,
                    'iteration': len(training_data_loader) * (epoch - 1) + iteration,
                    'gen/loss': loss_g.item(),
                    'dis/loss': loss_d.item(),
                })

        with torch.no_grad():
            log_test = test(config, val_data_loader, gen, criterionMSE, epoch)
            testreport(log_test)

        if epoch % config.snapshot_interval == 0:
            checkpoint(config, epoch, gen, dis)

        logreport.save_lossgraph()
        testreport.save_lossgraph()

    print('Done', datetime.now())