示例#1
0
def test_adapt():
    """Run domain adaptation between a source and a target dataset.

    Loads a pretrained feature extractor into both a source and a target
    network, adapts the target one against a discriminator, then saves the
    adapted target extractor and the discriminator to disk.
    """
    args = parse()
    print(args)

    def _seed_worker(wid):
        # Give each dataloader worker a distinct, reproducible numpy seed.
        np.random.seed(np.uint32(torch.initial_seed() + wid))

    def _make_loader(poses, images):
        # Shuffled single-sample loader over a pose/image dataset.
        ds = SepeDataset(poses, images, coor_layer_flag=False)
        return DataLoader(ds, batch_size=1, shuffle=True, num_workers=1,
                          drop_last=True, worker_init_fn=_seed_worker)

    dataloader = _make_loader(args.poses_train, args.images_train)
    dataloader_tgt = _make_loader(args.poses_target, args.images_target)

    # Source and target extractors start from the same pretrained weights.
    src_extractor = DVOFeature()
    tgt_extractor = DVOFeature()
    src_extractor.load_state_dict(torch.load(args.feature_model))
    tgt_extractor.load_state_dict(torch.load(args.feature_model))

    dvo_discriminator = Discriminator(500, 500, 2)
    adapt(src_extractor, tgt_extractor, dvo_discriminator,
          dataloader, dataloader_tgt, args)

    suffix = args.tag + str(args.epoch)
    torch.save(tgt_extractor.state_dict(), 'tgt_feature_' + suffix + '.pt')
    torch.save(dvo_discriminator.state_dict(), 'dis_' + suffix + '.pt')
示例#2
0
class Solver(object):
    """Trainer for a simple GAN.

    Builds a generator/discriminator pair from ``config``, then runs the
    standard adversarial loop: the discriminator is trained on real and
    generated batches every iteration, the generator every
    ``D_train_step`` iterations. Periodically logs losses, saves sample
    images and checkpoints both networks.
    """

    def __init__(self, data_loader, config):
        """Store configuration, build the models and optional extras.

        :param data_loader: iterable yielding batches of real images.
        :param config: namespace carrying hyper-parameters and paths.
        """
        self.data_loader = data_loader

        # Network output configuration.
        self.noise_n = config.noise_n
        self.G_last_act = last_act(config.G_last_act)
        self.D_out_n = config.D_out_n
        self.D_last_act = last_act(config.D_last_act)

        # Optimization hyper-parameters.
        self.G_lr = config.G_lr
        self.D_lr = config.D_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.epoch = config.epoch
        self.batch_size = config.batch_size
        self.D_train_step = config.D_train_step
        self.save_image_step = config.save_image_step
        self.log_step = config.log_step
        self.model_save_step = config.model_save_step

        # Output locations.
        self.model_save_path = config.model_save_path
        self.log_save_path = config.log_save_path
        self.image_save_path = config.image_save_path

        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model
        self.build_model()

        # Fix: only build the logger when tensorboard is actually enabled.
        # The previous ``is not None`` test also fired when the flag was
        # False, which disagreed with the truthiness check used in train().
        if self.use_tensorboard:
            self.build_tensorboard()
        if self.pretrained_model is not None:
            if len(self.pretrained_model) != 2:
                # Fix: raising a bare string is a TypeError in Python 3;
                # raise a proper exception instead.
                raise ValueError(
                    "must have both G and D pretrained parameters, and G is first, D is second"
                )
            self.load_pretrained_model()

    def build_model(self):
        """Create G and D with their Adam optimizers; move to GPU if any."""
        self.G = Generator(self.noise_n, self.G_last_act)
        self.D = Discriminator(self.D_out_n, self.D_last_act)

        self.G_optimizer = torch.optim.Adam(self.G.parameters(), self.G_lr,
                                            [self.beta1, self.beta2])
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), self.D_lr,
                                            [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def build_tensorboard(self):
        """Create the tensorboard logger writing to ``log_save_path``."""
        from commons.logger import Logger
        self.logger = Logger(self.log_save_path)

    def load_pretrained_model(self):
        """Load G then D weights from ``pretrained_model`` (G first, D second)."""
        self.G.load_state_dict(torch.load(self.pretrained_model[0]))
        self.D.load_state_dict(torch.load(self.pretrained_model[1]))

    def reset_grad(self):
        """Zero the gradients of both optimizers."""
        self.G_optimizer.zero_grad()
        self.D_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        """Wrap a tensor in a (GPU) Variable.

        NOTE: ``Variable``/``volatile`` are pre-0.4 PyTorch idioms; kept
        for compatibility with the rest of this code base.
        """
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def train(self):
        """Run the adversarial training loop for ``self.epoch`` epochs."""
        bce_loss = nn.BCELoss()

        print(len(self.data_loader))
        for e in range(self.epoch):
            for i, batch_images in enumerate(self.data_loader):
                batch_size = batch_images.size(0)
                real_x = self.to_var(batch_images)
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(batch_size, self.noise_n)))
                real_label = self.to_var(
                    torch.FloatTensor(batch_size).fill_(1.))
                fake_label = self.to_var(
                    torch.FloatTensor(batch_size).fill_(0.))
                # Train D: real batch toward 1, generated batch toward 0.
                # detach() keeps D's update from back-propagating into G.
                fake_x = self.G(noise_x)
                real_out = self.D(real_x)
                fake_out = self.D(fake_x.detach())

                D_real = bce_loss(real_out, real_label)
                D_fake = bce_loss(fake_out, fake_label)
                D_loss = D_real + D_fake

                self.reset_grad()
                D_loss.backward()
                self.D_optimizer.step()
                # Log (``.data[0]`` is the pre-0.4 equivalent of ``.item()``)
                loss = {}
                loss['D/loss_real'] = D_real.data[0]
                loss['D/loss_fake'] = D_fake.data[0]
                loss['D/loss'] = D_loss.data[0]

                # Train G every D_train_step iterations: fool D into
                # classifying generated samples as real.
                if (i + 1) % self.D_train_step == 0:
                    fake_out = self.D(self.G(noise_x))
                    G_loss = bce_loss(fake_out, real_label)
                    self.reset_grad()
                    G_loss.backward()
                    self.G_optimizer.step()
                    loss['G/loss'] = G_loss.data[0]
                # Print log
                if (i + 1) % self.log_step == 0:
                    log = "Epoch: {}/{}, Iter: {}/{}".format(
                        e + 1, self.epoch, i + 1, len(self.data_loader))
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)
                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(
                            tag, value,
                            e * len(self.data_loader) + i + 1)
            # Save a grid of generated samples every save_image_step epochs
            if (e + 1) % self.save_image_step == 0:
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(32, self.noise_n)))
                fake_image = self.G(noise_x)
                save_image(
                    fake_image.data,
                    os.path.join(self.image_save_path,
                                 "{}_fake.png".format(e + 1)))
            # Checkpoint both networks every model_save_step epochs
            if (e + 1) % self.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_G.pth".format(e + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_D.pth".format(e + 1)))
示例#3
0
def main_worker(gpu, ngpus_per_node, args):
    """Per-process entry point for multi-scale SinGAN-style training.

    Builds the generator/discriminator pyramid, optionally wraps them for
    (multi-GPU) distributed training, restores a checkpoint, then trains
    and validates one scale ("stage") at a time, growing the networks and
    re-creating the optimizers after each stage.

    :param gpu: local GPU index of this process.
    :param ngpus_per_node: number of GPUs on this node.
    :param args: global configuration namespace (mutated in place).
    """
    # Single-GPU configs collapse to device 0; otherwise use this
    # process's local GPU index. (assumes args.gpu arrives as a
    # sequence/string of GPU ids — TODO confirm against the launcher)
    if len(args.gpu) == 1:
        args.gpu = 0
    else:
        args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.multiprocessing_distributed:
            # Global rank = node rank * gpus-per-node + local gpu index.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend='nccl',
                                init_method='tcp://127.0.0.1:' + args.port,
                                world_size=args.world_size,
                                rank=args.rank)

    ################
    # Define model #
    ################
    # 4/3 : scale factor in the paper
    scale_factor = 4 / 3
    tmp_scale = args.img_size_max / args.img_size_min
    # Number of pyramid scales needed to get from min to max image size.
    args.num_scale = int(np.round(np.log(tmp_scale) / np.log(scale_factor)))
    args.size_list = [
        int(args.img_size_min * scale_factor**i)
        for i in range(args.num_scale + 1)
    ]

    discriminator = Discriminator()
    generator = Generator(args.img_size_min, args.num_scale, scale_factor)

    networks = [discriminator, generator]

    # Move the networks to the right device(s) and wrap for DDP/DP.
    if args.distributed:
        if args.gpu is not None:
            print('Distributed to', args.gpu)
            torch.cuda.set_device(args.gpu)
            networks = [x.cuda(args.gpu) for x in networks]
            # NOTE(review): batch_size/workers are divided by
            # ngpus_per_node here AND again in the checkpoint-restore
            # branch and in the per-stage loop below, so they shrink
            # repeatedly on those paths — verify this is intended.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            networks = [
                torch.nn.parallel.DistributedDataParallel(
                    x, device_ids=[args.gpu], output_device=args.gpu)
                for x in networks
            ]
        else:
            networks = [x.cuda() for x in networks]
            networks = [
                torch.nn.parallel.DistributedDataParallel(x) for x in networks
            ]

    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        networks = [x.cuda(args.gpu) for x in networks]
    else:
        networks = [torch.nn.DataParallel(x).cuda() for x in networks]

    discriminator, generator, = networks

    ######################
    # Loss and Optimizer #
    ######################
    # Optimizers only cover the coarsest (stage-0) sub-networks; they are
    # re-created for the finest scale after each stage in the loop below.
    if args.distributed:
        d_opt = torch.optim.Adam(
            discriminator.module.sub_discriminators[0].parameters(), 5e-4,
            (0.5, 0.999))
        g_opt = torch.optim.Adam(
            generator.module.sub_generators[0].parameters(), 5e-4,
            (0.5, 0.999))
    else:
        d_opt = torch.optim.Adam(
            discriminator.sub_discriminators[0].parameters(), 5e-4,
            (0.5, 0.999))
        g_opt = torch.optim.Adam(generator.sub_generators[0].parameters(),
                                 5e-4, (0.5, 0.999))

    ##############
    # Load model #
    ##############
    args.stage = 0
    if args.load_model is not None:
        # The last line of checkpoint.txt names the checkpoint to restore.
        # NOTE(review): check_load is never closed.
        check_load = open(os.path.join(args.log_dir, "checkpoint.txt"), 'r')
        to_restore = check_load.readlines()[-1].strip()
        load_file = os.path.join(args.log_dir, to_restore)
        if os.path.isfile(load_file):
            print("=> loading checkpoint '{}'".format(load_file))
            checkpoint = torch.load(load_file, map_location='cpu')
            # Grow the pyramids to the checkpointed stage before loading.
            # NOTE(review): at this point the networks may already be
            # DDP/DP wrappers (wrapped above), so progress() would need to
            # go through .module in those modes — confirm.
            for _ in range(int(checkpoint['stage'])):
                generator.progress()
                discriminator.progress()
            networks = [discriminator, generator]

            # NOTE(review): this re-wraps networks that were already
            # wrapped above (DDP-of-DDP in distributed mode) and divides
            # batch_size/workers a second time — confirm intended.
            if args.distributed:
                if args.gpu is not None:
                    print('Distributed to', args.gpu)
                    torch.cuda.set_device(args.gpu)
                    networks = [x.cuda(args.gpu) for x in networks]
                    args.batch_size = int(args.batch_size / ngpus_per_node)
                    args.workers = int(args.workers / ngpus_per_node)
                    networks = [
                        torch.nn.parallel.DistributedDataParallel(
                            x, device_ids=[args.gpu], output_device=args.gpu)
                        for x in networks
                    ]
                else:
                    networks = [x.cuda() for x in networks]
                    networks = [
                        torch.nn.parallel.DistributedDataParallel(x)
                        for x in networks
                    ]

            elif args.gpu is not None:
                torch.cuda.set_device(args.gpu)
                networks = [x.cuda(args.gpu) for x in networks]
            else:
                networks = [torch.nn.DataParallel(x).cuda() for x in networks]

            discriminator, generator, = networks

            args.stage = checkpoint['stage']
            args.img_to_use = checkpoint['img_to_use']
            # NOTE(review): state dict keys must match the wrapped nets
            # (e.g. 'module.' prefixes) — depends on how they were saved.
            discriminator.load_state_dict(checkpoint['D_state_dict'])
            generator.load_state_dict(checkpoint['G_state_dict'])
            d_opt.load_state_dict(checkpoint['d_optimizer'])
            g_opt.load_state_dict(checkpoint['g_optimizer'])
            print("=> loaded checkpoint '{}' (stage {})".format(
                load_file, checkpoint['stage']))
        else:
            print("=> no checkpoint found at '{}'".format(args.log_dir))

    # Let cuDNN pick the fastest kernels for the (fixed) input sizes.
    cudnn.benchmark = True

    ###########
    # Dataset #
    ###########
    train_dataset, _ = get_dataset(args.dataset, args)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    ######################
    # Validate and Train #
    ######################
    # Fixed reconstruction noise: random at the coarsest scale, zeros at
    # all finer scales, each padded by 5 px on every side.
    z_fix_list = [
        F.pad(torch.randn(args.batch_size, 3, args.size_list[0],
                          args.size_list[0]), [5, 5, 5, 5],
              value=0)
    ]
    zero_list = [
        F.pad(torch.zeros(args.batch_size, 3, args.size_list[zeros_idx],
                          args.size_list[zeros_idx]), [5, 5, 5, 5],
              value=0) for zeros_idx in range(1, args.num_scale + 1)
    ]
    z_fix_list = z_fix_list + zero_list

    # Validation / test modes only run evaluation once and exit.
    if args.validation:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return

    elif args.test:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return

    # Only the (single) main process writes the run bookkeeping files.
    if not args.multiprocessing_distributed or (
            args.multiprocessing_distributed
            and args.rank % ngpus_per_node == 0):
        check_list = open(os.path.join(args.log_dir, "checkpoint.txt"), "a+")
        record_txt = open(os.path.join(args.log_dir, "record.txt"), "a+")
        record_txt.write('DATASET\t:\t{}\n'.format(args.dataset))
        record_txt.write('GANTYPE\t:\t{}\n'.format(args.gantype))
        record_txt.write('IMGTOUSE\t:\t{}\n'.format(args.img_to_use))
        record_txt.close()

    # Train one scale at a time, growing the pyramids between stages.
    for stage in range(args.stage, args.num_scale + 1):
        if args.distributed:
            train_sampler.set_epoch(stage)

        trainSinGAN(train_loader, networks, {
            "d_opt": d_opt,
            "g_opt": g_opt
        }, stage, args, {"z_rec": z_fix_list})
        validateSinGAN(train_loader, networks, stage, args,
                       {"z_rec": z_fix_list})

        # Add the next-scale sub-networks to both pyramids.
        if args.distributed:
            discriminator.module.progress()
            generator.module.progress()
        else:
            discriminator.progress()
            generator.progress()

        networks = [discriminator, generator]

        # NOTE(review): in distributed mode this wraps the already-wrapped
        # DDP models again and divides batch_size/workers once more per
        # stage — verify against the intended re-distribution behavior.
        if args.distributed:
            if args.gpu is not None:
                print('Distributed', args.gpu)
                torch.cuda.set_device(args.gpu)
                networks = [x.cuda(args.gpu) for x in networks]
                args.batch_size = int(args.batch_size / ngpus_per_node)
                args.workers = int(args.workers / ngpus_per_node)
                networks = [
                    torch.nn.parallel.DistributedDataParallel(
                        x, device_ids=[args.gpu], output_device=args.gpu)
                    for x in networks
                ]
            else:
                networks = [x.cuda() for x in networks]
                networks = [
                    torch.nn.parallel.DistributedDataParallel(x)
                    for x in networks
                ]

        elif args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            networks = [x.cuda(args.gpu) for x in networks]
        else:
            networks = [torch.nn.DataParallel(x).cuda() for x in networks]

        discriminator, generator, = networks

        # Update the networks at finest scale: freeze every already-trained
        # sub-network and build fresh optimizers for the newest scale only.
        if args.distributed:
            for net_idx in range(generator.module.current_scale):
                for param in generator.module.sub_generators[
                        net_idx].parameters():
                    param.requires_grad = False
                for param in discriminator.module.sub_discriminators[
                        net_idx].parameters():
                    param.requires_grad = False

            # NOTE(review): `discriminator.current_scale` /
            # `generator.current_scale` are read on the DDP wrapper here
            # (no `.module`), unlike the surrounding accesses — confirm
            # whether this relies on DDP attribute forwarding.
            d_opt = torch.optim.Adam(
                discriminator.module.sub_discriminators[
                    discriminator.current_scale].parameters(), 5e-4,
                (0.5, 0.999))
            g_opt = torch.optim.Adam(
                generator.module.sub_generators[
                    generator.current_scale].parameters(), 5e-4, (0.5, 0.999))
        else:
            for net_idx in range(generator.current_scale):
                for param in generator.sub_generators[net_idx].parameters():
                    param.requires_grad = False
                for param in discriminator.sub_discriminators[
                        net_idx].parameters():
                    param.requires_grad = False

            d_opt = torch.optim.Adam(
                discriminator.sub_discriminators[
                    discriminator.current_scale].parameters(), 5e-4,
                (0.5, 0.999))
            g_opt = torch.optim.Adam(
                generator.sub_generators[generator.current_scale].parameters(),
                5e-4, (0.5, 0.999))

        ##############
        # Save model #
        ##############
        # Main process checkpoints after every stage; the checkpoint file
        # list is closed once the final stage is saved.
        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            if stage == 0:
                check_list = open(os.path.join(args.log_dir, "checkpoint.txt"),
                                  "a+")
            save_checkpoint(
                {
                    'stage': stage + 1,
                    'D_state_dict': discriminator.state_dict(),
                    'G_state_dict': generator.state_dict(),
                    'd_optimizer': d_opt.state_dict(),
                    'g_optimizer': g_opt.state_dict(),
                    'img_to_use': args.img_to_use
                }, check_list, args.log_dir, stage + 1)
            if stage == args.num_scale:
                check_list.close()
示例#4
0
def main_worker(args):
    """Single-process SinGAN-style training loop (MindSpore port).

    Builds the generator/discriminator pyramids, optionally restores a
    checkpoint, then trains and validates one scale ("stage") at a time,
    growing the networks and rebuilding the optimizers between stages.
    Lines marked "MPS" are known porting TODOs from the PyTorch original.

    :param args: global configuration namespace (mutated in place).
    """
    ################
    # Define model #
    ################
    # 4/3 : scale factor in the paper
    scale_factor = 4 / 3
    tmp_scale = args.img_size_max / args.img_size_min
    # Number of pyramid scales needed to get from min to max image size.
    args.num_scale = int(np.round(np.log(tmp_scale) / np.log(scale_factor)))
    args.size_list = [
        int(args.img_size_min * scale_factor**i)
        for i in range(args.num_scale + 1)
    ]

    discriminator = Discriminator()
    generator = Generator(args.img_size_min, args.num_scale, scale_factor)
    # Fix: define `networks` before the validation/test early-exit branches
    # below; previously it was only assigned after them, so those branches
    # raised NameError.
    networks = [discriminator, generator]

    ######################
    # Loss and Optimizer #
    ######################
    # Optimizers only cover the coarsest (stage-0) sub-networks; they are
    # rebuilt for the newest scale after each stage in the loop below.
    d_opt = mindspore.nn.Adam(
        discriminator.sub_discriminators[0].get_parameters(), 5e-4, 0.5, 0.999)
    g_opt = mindspore.nn.Adam(generator.sub_generators[0].get_parameters(),
                              5e-4, 0.5, 0.999)

    ##############
    # Load model #
    ##############
    args.stage = 0
    if args.load_model is not None:
        # The last line of checkpoint.txt names the checkpoint to restore.
        check_load = open(os.path.join(args.log_dir, "checkpoint.txt"), 'r')
        to_restore = check_load.readlines()[-1].strip()
        load_file = os.path.join(args.log_dir, to_restore)
        if os.path.isfile(load_file):
            print("=> loading checkpoint '{}'".format(load_file))
            checkpoint = mindspore.load_checkpoint(
                load_file)  # MPS map_location='cpu'#
            # NOTE(review): mindspore.load_checkpoint returns a dict of
            # Parameters; the ['stage']/['img_to_use'] indexing and the
            # load_state_dict calls below are unported PyTorch idioms —
            # confirm against the MindSpore checkpoint API.
            for _ in range(int(checkpoint['stage'])):
                generator.progress()
                discriminator.progress()
            args.stage = checkpoint['stage']
            args.img_to_use = checkpoint['img_to_use']
            discriminator.load_state_dict(checkpoint['D_state_dict'])
            generator.load_state_dict(checkpoint['G_state_dict'])
            # MPS: does Adam.load_state_dict exist in MindSpore?
            d_opt.load_state_dict(checkpoint['d_optimizer'])
            g_opt.load_state_dict(checkpoint['g_optimizer'])
            print("=> loaded checkpoint '{}' (stage {})".format(
                load_file, checkpoint['stage']))
        else:
            print("=> no checkpoint found at '{}'".format(args.log_dir))

    ###########
    # Dataset #
    ###########
    train_dataset, _ = get_dataset(args.dataset, args)

    train_loader = mindspore.DatasetHelper(train_dataset)  # MPS: may need tuning

    ######################
    # Validate and Train #
    ######################
    # Fixed reconstruction noise: random at the coarsest scale, zeros at
    # all finer scales, padded by 5 px on every side.
    # NOTE(review): ops.Pad/StandardNormal/Zeros are instantiated here with
    # shape-like arguments but never called on inputs — confirm usage
    # against the MindSpore ops API.
    op1 = mindspore.ops.Pad(((5, 5), (5, 5)))
    op2 = mindspore.ops.Pad(((5, 5), (5, 5)))
    z_fix_list = [op1(mindspore.ops.StandardNormal(3, args.size_list[0]))]
    zero_list = [
        op2(mindspore.ops.Zeros(3, args.size_list[zeros_idx]))
        for zeros_idx in range(1, args.num_scale + 1)
    ]
    z_fix_list = z_fix_list + zero_list

    # Validation / test modes only run evaluation once and exit.
    if args.validation:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return

    elif args.test:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return

    # Run bookkeeping files.
    check_list = open(os.path.join(args.log_dir, "checkpoint.txt"), "a+")
    record_txt = open(os.path.join(args.log_dir, "record.txt"), "a+")
    record_txt.write('DATASET\t:\t{}\n'.format(args.dataset))
    record_txt.write('GANTYPE\t:\t{}\n'.format(args.gantype))
    record_txt.write('IMGTOUSE\t:\t{}\n'.format(args.img_to_use))
    record_txt.close()

    # Train one scale at a time, growing the pyramids between stages.
    for stage in range(args.stage, args.num_scale + 1):

        trainSinGAN(train_loader, networks, {
            "d_opt": d_opt,
            "g_opt": g_opt
        }, stage, args, {"z_rec": z_fix_list})
        validateSinGAN(train_loader, networks, stage, args,
                       {"z_rec": z_fix_list})
        # Add the next-scale sub-networks to both pyramids.
        discriminator.progress()
        generator.progress()

        # Update the networks at finest scale: fresh optimizers covering
        # only the newest sub-networks.
        # Fix: use get_parameters() for consistency with the MindSpore API
        # used when the optimizers were first built above.
        d_opt = mindspore.nn.Adam(
            discriminator.sub_discriminators[
                discriminator.current_scale].get_parameters(), 5e-4, 0.5,
            0.999)
        g_opt = mindspore.nn.Adam(
            generator.sub_generators[
                generator.current_scale].get_parameters(), 5e-4, 0.5, 0.999)
        ##############
        # Save model #
        ##############
        # Fix: check_list is already open (opened unconditionally above);
        # the old `if stage == 0: open(...)` re-open leaked a file handle.
        save_checkpoint(
            {
                'stage': stage + 1,
                'D_state_dict': discriminator.state_dict(),
                'G_state_dict': generator.state_dict(),
                'd_optimizer': d_opt.state_dict(),
                'g_optimizer': g_opt.state_dict(),
                'img_to_use': args.img_to_use
            }, check_list, args.log_dir, stage + 1)
        if stage == args.num_scale:
            check_list.close()
class GanTrainer(Trainer):
    def __init__(self, train_loader, test_loader, valid_loader, general_args,
                 trainer_args):
        """Assemble the GAN trainer: networks, optimizers, schedulers, losses.

        :param train_loader: training data loader.
        :param test_loader: test data loader.
        :param valid_loader: validation data loader.
        :param general_args: shared architecture/configuration arguments.
        :param trainer_args: trainer-specific arguments (paths, learning
            rates, scheduler settings, loss weights).
        """
        super(GanTrainer, self).__init__(train_loader, test_loader,
                                         valid_loader, general_args)
        # Checkpoint locations.
        self.loadpath = trainer_args.loadpath
        self.savepath = trainer_args.savepath

        # Optional frozen auto-encoder providing an embedding-space loss.
        self.use_autoencoder = False
        autoencoder_path = trainer_args.autoencoder_path
        if autoencoder_path and os.path.exists(autoencoder_path):
            self.use_autoencoder = True
            self.autoencoder = AutoEncoder(general_args=general_args).to(
                self.device)
            self.load_pretrained_autoencoder(autoencoder_path)
            self.autoencoder.eval()

        # Generator, optionally warm-started from a saved checkpoint.
        self.generator = Generator(general_args=general_args).to(self.device)
        generator_path = trainer_args.generator_path
        if generator_path and os.path.exists(generator_path):
            self.load_pretrained_generator(generator_path)

        self.discriminator = Discriminator(general_args=general_args).to(
            self.device)

        # One Adam optimizer and one step-decay scheduler per network.
        self.generator_optimizer = torch.optim.Adam(
            params=self.generator.parameters(), lr=trainer_args.generator_lr)
        self.discriminator_optimizer = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=trainer_args.discriminator_lr)
        self.generator_scheduler = lr_scheduler.StepLR(
            optimizer=self.generator_optimizer,
            step_size=trainer_args.generator_scheduler_step,
            gamma=trainer_args.generator_scheduler_gamma)
        self.discriminator_scheduler = lr_scheduler.StepLR(
            optimizer=self.discriminator_optimizer,
            step_size=trainer_args.discriminator_scheduler_step,
            gamma=trainer_args.discriminator_scheduler_gamma)

        # Resume the full trainer state when a checkpoint is present.
        if os.path.exists(self.loadpath):
            self.load()

        # Adversarial criterion plus the L2 reconstruction criteria.
        self.adversarial_criterion = nn.BCEWithLogitsLoss()
        self.generator_time_criterion = nn.MSELoss()
        self.generator_frequency_criterion = nn.MSELoss()
        self.generator_autoencoder_criterion = nn.MSELoss()

        # Discriminator target labels.
        self.real_label = 1
        self.generated_label = 0

        # Relative weights of the generator loss terms.
        self.lambda_adv = trainer_args.lambda_adversarial
        self.lambda_freq = trainer_args.lambda_freq
        self.lambda_autoencoder = trainer_args.lambda_autoencoder

        # Waveform -> normalized spectrogram converter.
        self.spectrogram = Spectrogram(normalized=True).to(self.device)

        # Whether the current state should be checkpointed.
        self.need_saving = True

        # Whether the generator receives adversarial feedback at all.
        self.use_adversarial = trainer_args.use_adversarial

    def load_pretrained_generator(self, generator_path):
        """Warm-start the generator from a checkpoint on disk.

        Useful to stabilize adversarial training.
        :param generator_path: path to the saved generator checkpoint (string).
        :return: None
        """
        state = torch.load(generator_path, map_location=self.device)
        self.generator.load_state_dict(state['generator_state_dict'])

    def load_pretrained_autoencoder(self, autoencoder_path):
        """Restore the auto-encoder weights from a checkpoint on disk.

        The loaded model is used for inference (embedding-space loss) only.
        :param autoencoder_path: path to the saved auto-encoder checkpoint (string).
        :return: None
        """
        state = torch.load(autoencoder_path, map_location=self.device)
        self.autoencoder.load_state_dict(state['autoencoder_state_dict'])

    def train(self, epochs):
        """
        Trains the GAN for a given number of pseudo-epochs.

        Per batch: (1) the discriminator is updated on a real batch
        (target 1) and a detached generated batch (target 0); (2) the
        generator is updated with an optional adversarial term plus L2
        losses in time, frequency and (optionally) auto-encoder embedding
        space. After each pseudo-epoch the model is evaluated, saved, and
        the LR schedulers step.

        :param epochs: Number of time to iterate over a part of the dataset (int).
        :return: None
        """
        for epoch in range(epochs):
            for i in range(self.train_batches_per_epoch):
                self.generator.train()
                self.discriminator.train()
                # Transfer to GPU
                local_batch = next(self.train_loader_iter)
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)
                batch_size = input_batch.shape[0]

                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                # Train the discriminator with real data
                # NOTE(review): self.real_label is the int 1, so torch.full
                # may produce an integer tensor on recent PyTorch — confirm
                # BCEWithLogitsLoss receives float targets on this version.
                self.discriminator_optimizer.zero_grad()
                label = torch.full((batch_size, ),
                                   self.real_label,
                                   device=self.device)
                output = self.discriminator(target_batch)

                # Compute and store the discriminator loss on real data
                loss_discriminator_real = self.adversarial_criterion(
                    output, torch.unsqueeze(label, dim=1))
                self.train_losses['discriminator_adversarial']['real'].append(
                    loss_discriminator_real.item())
                loss_discriminator_real.backward()

                # Train the discriminator with fake data; detach() keeps
                # this step's gradient out of the generator.
                generated_batch = self.generator(input_batch)
                label.fill_(self.generated_label)
                output = self.discriminator(generated_batch.detach())

                # Compute and store the discriminator loss on fake data
                loss_discriminator_generated = self.adversarial_criterion(
                    output, torch.unsqueeze(label, dim=1))
                self.train_losses['discriminator_adversarial']['fake'].append(
                    loss_discriminator_generated.item())
                loss_discriminator_generated.backward()

                # Update the discriminator weights (gradients accumulated
                # from both the real and the fake backward passes)
                self.discriminator_optimizer.step()

                ############################
                # Update G network: maximize log(D(G(z)))
                ###########################
                self.generator_optimizer.zero_grad()

                # Get the spectrogram of both target and generated audio
                specgram_target_batch = self.spectrogram(target_batch)
                specgram_fake_batch = self.spectrogram(generated_batch)

                # Fake labels are real for the generator cost
                label.fill_(self.real_label)
                output = self.discriminator(generated_batch)

                # Compute the generator loss on fake data
                # Get the adversarial loss (zero when adversarial feedback
                # is disabled, so the combined loss below is unaffected)
                loss_generator_adversarial = torch.zeros(size=[1],
                                                         device=self.device)
                if self.use_adversarial:
                    loss_generator_adversarial = self.adversarial_criterion(
                        output, torch.unsqueeze(label, dim=1))
                self.train_losses['generator_adversarial'].append(
                    loss_generator_adversarial.item())

                # Get the L2 loss in time domain
                loss_generator_time = self.generator_time_criterion(
                    generated_batch, target_batch)
                self.train_losses['time_l2'].append(loss_generator_time.item())

                # Get the L2 loss in frequency domain
                loss_generator_frequency = self.generator_frequency_criterion(
                    specgram_fake_batch, specgram_target_batch)
                self.train_losses['freq_l2'].append(
                    loss_generator_frequency.item())

                # Get the L2 loss in embedding space (placeholder zero when
                # no auto-encoder is configured)
                loss_generator_autoencoder = torch.zeros(size=[1],
                                                         device=self.device,
                                                         requires_grad=True)
                if self.use_autoencoder:
                    # Get the embeddings of both batches from the frozen
                    # auto-encoder
                    _, embedding_target_batch = self.autoencoder(target_batch)
                    _, embedding_generated_batch = self.autoencoder(
                        generated_batch)
                    loss_generator_autoencoder = self.generator_autoencoder_criterion(
                        embedding_generated_batch, embedding_target_batch)
                    self.train_losses['autoencoder_l2'].append(
                        loss_generator_autoencoder.item())

                # Combine the different losses with their scaling factors
                loss_generator = self.lambda_adv * loss_generator_adversarial + loss_generator_time + \
                                 self.lambda_freq * loss_generator_frequency + \
                                 self.lambda_autoencoder * loss_generator_autoencoder

                # Back-propagate and update the generator weights
                loss_generator.backward()
                self.generator_optimizer.step()

                # Print a progress message every 10 batches
                if not (i % 10):
                    message = 'Batch {}: \n' \
                              '\t Generator: \n' \
                              '\t\t Time: {} \n' \
                              '\t\t Frequency: {} \n' \
                              '\t\t Autoencoder {} \n' \
                              '\t\t Adversarial: {} \n' \
                              '\t Discriminator: \n' \
                              '\t\t Real {} \n' \
                              '\t\t Fake {} \n'.format(i,
                                                       loss_generator_time.item(),
                                                       loss_generator_frequency.item(),
                                                       loss_generator_autoencoder.item(),
                                                       loss_generator_adversarial.item(),
                                                       loss_discriminator_real.item(),
                                                       loss_discriminator_generated.item())
                    print(message)

            # Evaluate the model on the validation set
            with torch.no_grad():
                self.eval()

            # Save the trainer state
            self.save()
            # if self.need_saving:
            #     self.save()

            # Increment epoch counter and decay the learning rates
            self.epoch += 1
            self.generator_scheduler.step()
            self.discriminator_scheduler.step()

    def eval(self):
        """
        Runs the generator over the validation set and records the time- and
        frequency-domain L2 losses.

        NOTE: the caller is expected to wrap this call in ``torch.no_grad()``
        and to restore train mode afterwards (see the training loop).
        :return: None
        """
        self.generator.eval()
        self.discriminator.eval()
        batch_losses = {'time_l2': [], 'freq_l2': []}
        for _ in range(self.valid_batches_per_epoch):
            # Transfer to GPU
            local_batch = next(self.valid_loader_iter)
            input_batch, target_batch = local_batch[0].to(
                self.device), local_batch[1].to(self.device)

            generated_batch = self.generator(input_batch)

            # Get the spectrograms of target and generated signals
            specgram_target_batch = self.spectrogram(target_batch)
            specgram_generated_batch = self.spectrogram(generated_batch)

            # L2 loss in time domain
            loss_generator_time = self.generator_time_criterion(
                generated_batch, target_batch)
            batch_losses['time_l2'].append(loss_generator_time.item())

            # L2 loss in frequency domain
            loss_generator_frequency = self.generator_frequency_criterion(
                specgram_generated_batch, specgram_target_batch)
            batch_losses['freq_l2'].append(loss_generator_frequency.item())

        # Average over the validation batches, store once, display once.
        # Fix: the original applied np.mean twice in the log message, which is
        # redundant on an already-scalar value.
        mean_time_l2 = np.mean(batch_losses['time_l2'])
        mean_freq_l2 = np.mean(batch_losses['freq_l2'])
        self.valid_losses['time_l2'].append(mean_time_l2)
        self.valid_losses['freq_l2'].append(mean_freq_l2)

        # Display validation losses
        message = 'Epoch {}: \n' \
                  '\t Time: {} \n' \
                  '\t Frequency: {} \n'.format(self.epoch,
                                               mean_time_l2,
                                               mean_freq_l2)
        print(message)

        # Check if the loss is decreasing
        self.check_improvement()

    def save(self):
        """
        Serializes the full trainer state — both models, their optimizers and
        schedulers, the epoch counter and all loss histories — to
        ``self.savepath``.
        :return: None
        """
        checkpoint = {
            'epoch': self.epoch,
            'generator_state_dict': self.generator.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict(),
            'generator_optimizer_state_dict':
                self.generator_optimizer.state_dict(),
            'discriminator_optimizer_state_dict':
                self.discriminator_optimizer.state_dict(),
            'generator_scheduler_state_dict':
                self.generator_scheduler.state_dict(),
            'discriminator_scheduler_state_dict':
                self.discriminator_scheduler.state_dict(),
            'train_losses': self.train_losses,
            'test_losses': self.test_losses,
            'valid_losses': self.valid_losses,
        }
        torch.save(checkpoint, self.savepath)

    def load(self):
        """
        Restores the trainer state — both models, their optimizers and
        schedulers, the epoch counter and all loss histories — from
        ``self.loadpath``.
        :return: None
        """
        checkpoint = torch.load(self.loadpath, map_location=self.device)

        # Every component with load_state_dict support, keyed by its
        # checkpoint entry.
        stateful = {
            'generator_state_dict': self.generator,
            'discriminator_state_dict': self.discriminator,
            'generator_optimizer_state_dict': self.generator_optimizer,
            'discriminator_optimizer_state_dict': self.discriminator_optimizer,
            'generator_scheduler_state_dict': self.generator_scheduler,
            'discriminator_scheduler_state_dict': self.discriminator_scheduler,
        }
        for key, component in stateful.items():
            component.load_state_dict(checkpoint[key])

        # Plain attributes
        self.epoch = checkpoint['epoch']
        self.train_losses = checkpoint['train_losses']
        self.test_losses = checkpoint['test_losses']
        self.valid_losses = checkpoint['valid_losses']

    def evaluate_metrics(self, n_batches):
        """
        Measures reconstruction quality with the SNR and LSD metrics over a
        given number of test batches.
        :param n_batches: number of test batches to process
        :return: (SNR mean, SNR std, LSD mean, LSD std)
        """
        with torch.no_grad():
            generator = self.generator.eval()
            snr_values = []
            lsd_values = []
            for _ in range(n_batches):
                # Fetch a batch and move it to the computation device
                batch = next(self.test_loader_iter)
                input_batch = batch[0].to(self.device)
                target_batch = batch[1].to(self.device)

                # Reconstruct the signals
                generated_batch = generator(input_batch)

                # Accumulate both metrics for this batch
                generated = generated_batch.squeeze()
                reference = target_batch.squeeze()
                snr_values.append(snr(x=generated, x_ref=reference))
                lsd_values.append(lsd(x=generated, x_ref=reference))

            snr_array = torch.cat(snr_values).cpu().numpy()
            lsd_array = torch.cat(lsd_values).cpu().numpy()

            # All-zero (silent) signals produce infinities through the
            # logarithm; mask them out so they do not poison the statistics.
            snr_array[np.isinf(snr_array)] = np.nan
            lsd_array[np.isinf(lsd_array)] = np.nan
        return (np.nanmean(snr_array), np.nanstd(snr_array),
                np.nanmean(lsd_array), np.nanstd(lsd_array))
示例#6
0
def train():
    """
    Trains a cartoonization GAN end to end with mixed precision (AMP).

    The discriminator is trained to tell real cartoon frames from both
    edge-smoothed cartoons and generator outputs; the generator is trained
    against an adversarial BCE term plus a content loss w.r.t. the input
    photo. Sample images and checkpoints are written on fixed iteration
    cadences under ./images and ./checkpoints.
    """
    torch.manual_seed(1337)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Config
    batch_size = 32
    image_size = 256
    learning_rate = 1e-4
    beta1, beta2 = (.5, .99)
    weight_decay = 1e-4
    epochs = 1000

    # Models
    netD = Discriminator().to(device)
    netG = Generator().to(device)
    # Here you should load the pretrained G
    # NOTE(review): calling .state_dict() on the result of torch.load assumes
    # the checkpoint is a pickled full model; if it stores a plain state_dict
    # this line raises — confirm how pretrained_netG.pth was saved.
    netG.load_state_dict(torch.load("./checkpoints/pretrained_netG.pth").state_dict())

    optimizerD = AdamW(netD.parameters(), lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay)
    optimizerG = AdamW(netG.parameters(), lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay)

    # One gradient scaler shared by both optimizers for AMP training
    scaler = torch.cuda.amp.GradScaler()

    # Labels — presumably sized to the discriminator's patch-output
    # resolution of image_size // 4; TODO confirm against Discriminator.
    cartoon_labels = torch.ones (batch_size, 1, image_size // 4, image_size // 4).to(device)
    fake_labels    = torch.zeros(batch_size, 1, image_size // 4, image_size // 4).to(device)

    # Loss functions
    content_loss = ContentLoss().to(device)
    adv_loss     = AdversialLoss(cartoon_labels, fake_labels).to(device)
    BCE_loss     = nn.BCEWithLogitsLoss().to(device)

    # Dataloaders. Cartoon batches appear to be (cartoon | edge-smoothed)
    # pairs concatenated along the width axis — they are split at
    # x == image_size below.
    real_dataloader    = get_dataloader("./datasets/real_images/flickr30k_images/",           size = image_size, bs = batch_size)
    cartoon_dataloader = get_dataloader("./datasets/cartoon_images_smoothed/Studio Ghibli",   size = image_size, bs = batch_size, trfs=get_pair_transforms(image_size))

    # --------------------------------------------------------------------------------------------- #
    # Training Loop

    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0

    # Fixed batch of real photos used to visualize generator progress
    tracked_images = next(iter(real_dataloader)).to(device)

    print("Starting Training Loop...")
    # For each epoch.
    for epoch in range(epochs):
        print("training epoch ", epoch)
        # For each batch in the dataloader. zip() stops at the shorter loader.
        for i, (cartoon_edge_data, real_data) in enumerate(zip(cartoon_dataloader, real_dataloader)):

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################

            # Reset Discriminator gradient and re-enable its parameters
            # (they are frozen below for the generator update).
            netD.zero_grad()
            for param in netD.parameters():
                param.requires_grad = True

            # Format batch: split the concatenated (cartoon | edge) image.
            cartoon_data   = cartoon_edge_data[:, :, :, :image_size].to(device)
            edge_data      = cartoon_edge_data[:, :, :, image_size:].to(device)
            real_data      = real_data.to(device)

            with torch.cuda.amp.autocast():
                # Generate image
                generated_data = netG(real_data)

                # Forward pass all batches through D.
                cartoon_pred   = netD(cartoon_data)      #.view(-1)
                edge_pred      = netD(edge_data)         #.view(-1)
                # detach() keeps D's backward pass from reaching into G
                generated_pred = netD(generated_data.detach())    #.view(-1)

                # Calculate discriminator loss on all batches.
                errD = adv_loss(cartoon_pred, generated_pred, edge_pred)

            # Calculate gradients for D in backward pass (loss scaled for AMP)
            scaler.scale(errD).backward()
            D_x = cartoon_pred.mean().item() # Should be close to 1

            # Update D
            scaler.step(optimizerD)


            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################

            # Reset Generator gradient and freeze D so G's backward pass
            # does not accumulate gradients into the discriminator.
            netG.zero_grad()
            for param in netD.parameters():
                param.requires_grad = False

            with torch.cuda.amp.autocast():
                # Since we just updated D, perform another forward pass of all-fake batch through D
                generated_pred = netD(generated_data) #.view(-1)

                # Calculate G's loss based on this output
                errG = BCE_loss(generated_pred, cartoon_labels) + content_loss(generated_data, real_data)

            # Calculate gradients for G
            scaler.scale(errG).backward()

            D_G_z2 = generated_pred.mean().item() # Should be close to 1

            # Update G
            scaler.step(optimizerG)

            # Single scaler.update() per iteration, after both optimizer
            # steps — the documented AMP pattern for multiple optimizers.
            scaler.update()

            # ---------------------------------------------------------------------------------------- #

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Check how the generator is doing by saving G's output on tracked_images
            # (every 200 iterations; the running loss lists are flushed here).
            if iters % 200 == 0:
                with torch.no_grad():
                    fake = netG(tracked_images)
                vutils.save_image(unnormalize(fake), f"images/{epoch}_{i}.png", padding=2)
                with open("images/log.txt", "a+") as f:
                    f.write(f"{datetime.now().isoformat(' ', 'seconds')}\tD: {np.mean(D_losses)}\tG: {np.mean(G_losses)}\n")
                D_losses = []
                G_losses = []

            # Checkpoint both networks every 1000 iterations.
            if iters % 1000 == 0:
                torch.save(netG.state_dict(), f"checkpoints/netG_e{epoch}_i{iters}_l{errG.item()}.pth")
                torch.save(netD.state_dict(), f"checkpoints/netD_e{epoch}_i{iters}_l{errG.item()}.pth")

            iters += 1
示例#7
0
def main():
    """
    Adversarial imitation-learning loop for a dialog environment: rolls out
    the actor, scores each step with the discriminator as an IRL reward, and
    alternately trains the discriminator and the actor/critic.

    Relies on module-level ``args`` and ``device`` plus the project's model
    and helper functions; transcripts and checkpoints are written under the
    current working directory.
    """
    env = DialogEnvironment()
    experiment_name = args.logdir.split('/')[1]  # model name

    torch.manual_seed(args.seed)

    # Build the models.
    # Fix: the discriminator previously received num_layers=args.hidden_size
    # (copy-paste error); it now uses args.num_layers like actor and critic.
    actor = Actor(hidden_size=args.hidden_size, num_layers=args.num_layers,
                  device='cuda', input_size=args.input_size,
                  output_size=args.input_size)
    critic = Critic(hidden_size=args.hidden_size, num_layers=args.num_layers,
                    input_size=args.input_size, seq_len=args.seq_len)
    discrim = Discriminator(hidden_size=args.hidden_size,
                            num_layers=args.num_layers,
                            input_size=args.input_size, seq_len=args.seq_len)

    actor.to(device), critic.to(device), discrim.to(device)

    actor_optim = optim.Adam(actor.parameters(), lr=args.learning_rate)
    critic_optim = optim.Adam(critic.parameters(), lr=args.learning_rate,
                              weight_decay=args.l2_rate)
    discrim_optim = optim.Adam(discrim.parameters(), lr=args.learning_rate)

    writer = SummaryWriter(args.logdir)

    # Optionally resume from a saved checkpoint
    if args.load_model is not None:
        saved_ckpt_path = os.path.join(os.getcwd(), 'save_model', str(args.load_model))
        ckpt = torch.load(saved_ckpt_path)

        actor.load_state_dict(ckpt['actor'])
        critic.load_state_dict(ckpt['critic'])
        discrim.load_state_dict(ckpt['discrim'])

    episodes = 0
    train_discrim_flag = True

    for iteration in range(args.max_iter_num):
        # ---- Rollout phase: collect transitions with the current policy ----
        actor.eval(), critic.eval()
        memory = deque()

        steps = 0
        # Fix: the original re-initialized these lists at the top of every
        # while-loop pass, so the reported averages only covered the last
        # episode; they now accumulate over all episodes of this iteration.
        scores = []
        similarity_scores = []
        while steps < args.total_sample_size:
            state, expert_action, raw_state, raw_expert_action = env.reset()
            score = 0
            similarity_score = 0
            state = state[:args.seq_len, :]
            expert_action = expert_action[:args.seq_len, :]
            state = state.to(device)
            expert_action = expert_action.to(device)
            for _ in range(10000):

                steps += 1

                # NOTE(review): Tensor.resize is deprecated — presumably a
                # view/reshape to (1, seq_len, input_size); confirm and
                # replace when the tensors are known to be contiguous.
                mu, std = actor(state.resize(1, args.seq_len, args.input_size))
                action = get_action(mu.cpu(), std.cpu())[0]
                # Zero out the action from the first all-zero (padding) row
                # of the expert sequence onward.
                for i in range(5):
                    emb_sum = expert_action[i, :].sum().cpu().item()
                    if emb_sum == 0:
                        action[i:, :] = 0  # manual padding
                        break

                done = env.step(action)
                irl_reward = get_reward(discrim, state, action, args)
                mask = 0 if done else 1

                memory.append([state, torch.from_numpy(action).to(device),
                               irl_reward, mask, expert_action])
                score += irl_reward
                similarity_score += get_cosine_sim(
                    expert=expert_action, action=action.squeeze(), seq_len=5)
                if done:
                    break

            episodes += 1
            scores.append(score)
            similarity_scores.append(similarity_score)

        score_avg = np.mean(scores)
        similarity_score_avg = np.mean(similarity_scores)
        print('{}:: {} episode score is {:.2f}'.format(iteration, episodes, score_avg))
        print('{}:: {} episode similarity score is {:.2f}'.format(
            iteration, episodes, similarity_score_avg))

        # ---- Update phase ----
        actor.train(), critic.train(), discrim.train()
        if train_discrim_flag:
            expert_acc, learner_acc = train_discrim(discrim, memory, discrim_optim, args)
            print("Expert: %.2f%% | Learner: %.2f%%" % (expert_acc * 100, learner_acc * 100))
            writer.add_scalar('log/expert_acc', float(expert_acc), iteration)
            writer.add_scalar('log/learner_acc', float(learner_acc), iteration)
            writer.add_scalar('log/avg_acc', float(learner_acc + expert_acc) / 2, iteration)
            # Stop training the discriminator once it separates expert from
            # learner well enough (only when thresholds are configured).
            if args.suspend_accu_exp is not None:
                if expert_acc > args.suspend_accu_exp and learner_acc > args.suspend_accu_gen:
                    train_discrim_flag = False

        train_actor_critic(actor, critic, memory, actor_optim, critic_optim, args)
        writer.add_scalar('log/score', float(score_avg), iteration)
        writer.add_scalar('log/similarity_score', float(similarity_score_avg), iteration)
        writer.add_text('log/raw_state', raw_state[0], iteration)
        raw_action = get_raw_action(action)
        writer.add_text('log/raw_action', raw_action, iteration)
        writer.add_text('log/raw_expert_action', raw_expert_action, iteration)

        # Periodically dump a sample transcript and a checkpoint.
        # Fix: the original condition `if iter % 100:` fired on every
        # iteration EXCEPT multiples of 100; `== 0` matches the intent.
        if iteration % 100 == 0:
            score_avg = int(score_avg)
            # Append the latest transcript to the experiment's result file
            with open(experiment_name + '.txt', 'a') as file_object:
                result_str = (str(iteration) + '|' + raw_state[0] + '|'
                              + raw_action + '|' + raw_expert_action + '\n')
                file_object.write(result_str)

            model_path = os.path.join(os.getcwd(), 'save_model')
            if not os.path.isdir(model_path):
                os.makedirs(model_path)

            ckpt_path = os.path.join(
                model_path,
                experiment_name + '_ckpt_' + str(score_avg) + '.pth.tar')

            save_checkpoint({
                'actor': actor.state_dict(),
                'critic': critic.state_dict(),
                'discrim': discrim.state_dict(),
                'args': args,
                'score': score_avg,
            }, filename=ckpt_path)
示例#8
0
class Model(object):
    """
    Conditional GAN wrapper bundling a generator, a discriminator, an
    optional pretrained regressor, the associated losses and checkpoint
    helpers.
    """

    @staticmethod
    def _num_parameters(module):
        # Total count of scalar parameters in a module.
        return sum(p.numel() for p in module.parameters())

    def __init__(self, opt):
        super(Model, self).__init__()

        # Generator
        self.gen = Generator(opt).cuda(opt.gpu_id)
        self.gen_params = self.gen.parameters()
        print(self.gen)
        print(self._num_parameters(self.gen))

        # Discriminator
        self.dis = Discriminator(opt).cuda(opt.gpu_id)
        self.dis_params = self.dis.parameters()
        print(self.dis)
        print(self._num_parameters(self.dis))

        # Regressor — only loaded when the MSE term is enabled
        if opt.mse_weight:
            self.reg = torch.load('data/utils/classifier.pth').cuda(
                opt.gpu_id).eval()
        else:
            self.reg = None

        # Losses
        self.criterion_gan = GANLoss(opt, self.dis)
        self.criterion_mse = lambda x, y: l1_loss(x, y) * opt.mse_weight

        self.loss_mse = Variable(torch.zeros(1).cuda())
        self.loss_adv = Variable(torch.zeros(1).cuda())
        self.loss = Variable(torch.zeros(1).cuda())

        self.path = opt.experiments_dir + opt.experiment_name + '/checkpoints/'
        self.gpu_id = opt.gpu_id
        # Channels left over for noise after the conditioning input channels
        self.noise_channels = opt.in_channels - len(opt.input_idx.split(','))

    def forward(self, inputs):
        """Move an (input, input_orig, target) batch to the GPU and generate."""
        net_input, net_input_orig, net_target = inputs

        self.input = Variable(net_input.cuda(self.gpu_id))
        self.input_orig = Variable(net_input_orig.cuda(self.gpu_id))
        self.target = Variable(net_target.cuda(self.gpu_id))

        # Concatenate the conditioning channels with fresh noise channels
        noise = Variable(
            torch.randn(self.input.size(0),
                        self.noise_channels).cuda(self.gpu_id))

        self.fake = self.gen(torch.cat([self.input, noise], 1))

    def backward_G(self):
        """Accumulate generator gradients (regressor MSE + adversarial)."""
        # Regressor loss — the fake should regress back to the original input
        if self.reg is not None:
            self.loss_mse = self.criterion_mse(self.reg(self.fake),
                                               self.input_orig)

        # GAN loss
        loss_adv, _ = self.criterion_gan(self.fake)

        (self.loss_mse + loss_adv).backward()

    def backward_D(self):
        """Accumulate discriminator gradients on the real/fake pair."""
        loss_adv, self.loss_adv = self.criterion_gan(self.target, self.fake)
        loss_adv.backward()

    def train(self):
        """Switch both networks to training mode."""
        self.gen.train()
        self.dis.train()

    def eval(self):
        """Switch both networks to evaluation mode."""
        self.gen.eval()
        self.dis.eval()

    def save_checkpoint(self, epoch):
        """Save both networks' weights to '<path><epoch>.pkl'."""
        state = {
            'epoch': epoch,
            'gen_state_dict': self.gen.state_dict(),
            'dis_state_dict': self.dis.state_dict(),
        }
        torch.save(state, self.path + '%d.pkl' % epoch)

    def load_checkpoint(self, path, pretrained=True):
        """Load both networks' weights from a checkpoint file.

        ``pretrained`` is accepted for interface compatibility but unused.
        """
        weights = torch.load(path)
        self.gen.load_state_dict(weights['gen_state_dict'])
        self.dis.load_state_dict(weights['dis_state_dict'])
示例#9
0
class Solver(object):
    """
    WGAN-GP training harness: builds a Generator/Discriminator pair, trains
    the critic with the Wasserstein loss plus a gradient penalty, trains the
    generator every ``D_train_step`` critic steps, and periodically logs,
    saves sample images and saves model weights.

    NOTE(review): the code uses the legacy ``Variable``/``volatile``/
    ``.data[0]`` API (PyTorch <= 0.3 era); ``.data[0]`` raises on modern
    PyTorch (use ``.item()``) — kept as-is to match the file's style.
    """

    def __init__(self, data_loader, config):

        self.data_loader = data_loader

        # Network construction parameters
        self.noise_n = config.noise_n
        self.G_last_act = last_act(config.G_last_act)
        self.D_out_n = config.D_out_n
        self.D_last_act = last_act(config.D_last_act)

        # Optimization hyper-parameters
        self.G_lr = config.G_lr
        self.D_lr = config.D_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.epoch = config.epoch
        self.batch_size = config.batch_size
        self.D_train_step = config.D_train_step
        self.save_image_step = config.save_image_step
        self.log_step = config.log_step
        self.model_save_step = config.model_save_step
        self.clip_value = config.clip_value  # only used by the (disabled) weight clipping
        self.lambda_gp = config.lambda_gp  # gradient-penalty coefficient

        self.model_save_path = config.model_save_path
        self.log_save_path = config.log_save_path
        self.image_save_path = config.image_save_path

        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model
        self.build_model()

        if self.use_tensorboard is not None:
            self.build_tensorboard()
        if self.pretrained_model is not None:
            if len(self.pretrained_model) != 2:
                # Fix: raising a plain string is a TypeError in Python 3;
                # raise a real exception with the same message instead.
                raise ValueError(
                    "must have both G and D pretrained parameters, and G is first, D is second"
                )
            self.load_pretrained_model()

    def build_model(self):
        """Instantiates G and D with their optimizers; moves them to GPU."""
        self.G = Generator(self.noise_n, self.G_last_act)
        self.D = Discriminator(self.D_out_n, self.D_last_act)

        self.G_optimizer = torch.optim.Adam(self.G.parameters(), self.G_lr,
                                            [self.beta1, self.beta2])
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), self.D_lr,
                                            [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.G.cuda()
            self.D.cuda()

    def build_tensorboard(self):
        """Creates the scalar-summary logger (import deferred until needed)."""
        from commons.logger import Logger
        self.logger = Logger(self.log_save_path)

    def load_pretrained_model(self):
        """Loads pretrained weights; expects (G path, D path) in that order."""
        self.G.load_state_dict(torch.load(self.pretrained_model[0]))
        self.D.load_state_dict(torch.load(self.pretrained_model[1]))

    def denorm(self, x):
        """Maps images from [-1, 1] back to [0, 1], clamping outliers."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def reset_grad(self):
        """Zeroes the gradients of both optimizers."""
        self.G_optimizer.zero_grad()
        self.D_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        """Wraps a tensor in a (legacy) Variable, on GPU when available."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def train(self):
        """
        Main training loop: per batch, one critic update (Wasserstein loss),
        one gradient-penalty update, and a generator update every
        ``D_train_step`` batches; images/weights saved on epoch cadences.
        """
        print(len(self.data_loader))
        for e in range(self.epoch):
            for i, batch_images in enumerate(self.data_loader):
                batch_size = batch_images.size(0)
                label = torch.FloatTensor(batch_size)  # NOTE(review): unused
                real_x = self.to_var(batch_images)
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(batch_size, self.noise_n)))
                # train D: critic scores for real and (detached) fake batches
                fake_x = self.G(noise_x)
                real_out = self.D(real_x)
                fake_out = self.D(fake_x.detach())

                D_real = -torch.mean(real_out)
                D_fake = torch.mean(fake_out)
                D_loss = D_real + D_fake

                self.reset_grad()
                D_loss.backward()
                self.D_optimizer.step()
                # Log (.data[0] is legacy indexing; .item() on modern torch)
                loss = {}
                loss['D/loss_real'] = D_real.data[0]
                loss['D/loss_fake'] = D_fake.data[0]
                loss['D/loss'] = D_loss.data[0]

                # choose one in below two
                # Clip weights of D
                # for p in self.D.parameters():
                #     p.data.clamp_(-self.clip_value, self.clip_value)
                # Gradient penalty (WGAN-GP): penalize the critic's gradient
                # norm on points interpolated between real and fake samples.
                alpha = torch.rand(real_x.size(0), 1, 1,
                                   1).cuda().expand_as(real_x)
                interpolated = Variable(alpha * real_x.data +
                                        (1 - alpha) * fake_x.data,
                                        requires_grad=True)
                gp_out = self.D(interpolated)
                grad = torch.autograd.grad(outputs=gp_out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(
                                               gp_out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad**2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1)**2)
                # Backward + Optimize (applied as a separate critic step)
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.D_optimizer.step()
                # Train G every D_train_step critic iterations
                if (i + 1) % self.D_train_step == 0:
                    fake_out = self.D(self.G(noise_x))
                    G_loss = -torch.mean(fake_out)
                    self.reset_grad()
                    G_loss.backward()
                    self.G_optimizer.step()
                    loss['G/loss'] = G_loss.data[0]
                # Print log ('G/loss' only present on generator-update iters)
                if (i + 1) % self.log_step == 0:
                    log = "Epoch: {}/{}, Iter: {}/{}".format(
                        e + 1, self.epoch, i + 1, len(self.data_loader))
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)
                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(
                            tag, value,
                            e * len(self.data_loader) + i + 1)
            # Save a grid of 16 generated samples
            if (e + 1) % self.save_image_step == 0:
                noise_x = self.to_var(
                    torch.FloatTensor(noise_vector(16, self.noise_n)))
                fake_image = self.G(noise_x)
                save_image(
                    self.denorm(fake_image.data),
                    os.path.join(self.image_save_path,
                                 "{}_fake.png".format(e + 1)))
            # Save model weights
            if (e + 1) % self.model_save_step == 0:
                torch.save(
                    self.G.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_G.pth".format(e + 1)))
                torch.save(
                    self.D.state_dict(),
                    os.path.join(self.model_save_path,
                                 "{}_D.pth".format(e + 1)))
示例#10
0
class WGanTrainer(Trainer):
    """
    Trainer implementing the Wasserstein GAN framework on top of the generic
    Trainer base class. Supports both the original WGAN (weight clipping,
    ``use_penalty=False``) and WGAN-GP (gradient penalty, ``use_penalty=True``).

    The generator is additionally trained with a time-domain L2 loss; the
    adversarial term is only coupled in from ``coupling_epoch`` onwards.
    """

    def __init__(self, train_loader, test_loader, valid_loader, general_args,
                 trainer_args):
        super(WGanTrainer, self).__init__(train_loader, test_loader,
                                          valid_loader, general_args)
        # Paths
        self.loadpath = trainer_args.loadpath
        self.savepath = trainer_args.savepath

        # Load the generator (optionally warm-started from a pre-trained model
        # to stabilize adversarial training)
        self.generator = Generator(general_args=general_args).to(self.device)
        if trainer_args.generator_path and os.path.exists(
                trainer_args.generator_path):
            self.load_pretrained_generator(trainer_args.generator_path)

        self.discriminator = Discriminator(general_args=general_args).to(
            self.device)

        # Optimizers and schedulers
        self.generator_optimizer = torch.optim.Adam(
            params=self.generator.parameters(), lr=trainer_args.generator_lr)
        self.discriminator_optimizer = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=trainer_args.discriminator_lr)
        self.generator_scheduler = lr_scheduler.StepLR(
            optimizer=self.generator_optimizer,
            step_size=trainer_args.generator_scheduler_step,
            gamma=trainer_args.generator_scheduler_gamma)
        self.discriminator_scheduler = lr_scheduler.StepLR(
            optimizer=self.discriminator_optimizer,
            step_size=trainer_args.discriminator_scheduler_step,
            gamma=trainer_args.discriminator_scheduler_gamma)

        # Load saved states (resumes a previous run if a checkpoint exists)
        if os.path.exists(self.loadpath):
            self.load()

        # Loss function and stored losses
        self.generator_time_criterion = nn.MSELoss()

        # Loss scaling factors
        self.lambda_adv = trainer_args.lambda_adversarial
        self.lambda_time = trainer_args.lambda_time

        # Boolean indicating if the model needs to be saved
        self.need_saving = True

        # Overrides losses from parent class
        self.train_losses = {
            'generator': {
                'time_l2': [],
                'adversarial': []
            },
            'discriminator': {
                'penalty': [],
                'adversarial': []
            }
        }
        self.test_losses = {
            'generator': {
                'time_l2': [],
                'adversarial': []
            },
            'discriminator': {
                'penalty': [],
                'adversarial': []
            }
        }
        self.valid_losses = {
            'generator': {
                'time_l2': [],
                'adversarial': []
            },
            'discriminator': {
                'penalty': [],
                'adversarial': []
            }
        }

        # Select either wgan or wgan-gp method
        self.use_penalty = trainer_args.use_penalty
        self.gamma = trainer_args.gamma_wgan_gp
        self.clipping_limit = trainer_args.clipping_limit
        self.n_critic = trainer_args.n_critic
        self.coupling_epoch = trainer_args.coupling_epoch

    def load_pretrained_generator(self, generator_path):
        """
        Loads a pre-trained generator. Can be used to stabilize the training.
        :param generator_path: location of the pre-trained generator (string).
        :return: None
        """
        checkpoint = torch.load(generator_path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])

    def compute_gradient_penalty(self, input_batch, generated_batch):
        """
        Compute the gradient penalty as described in the original paper
        (https://papers.nips.cc/paper/7159-improved-training-of-wasserstein-gans.pdf).
        :param input_batch: batch of input data (torch tensor).
        :param generated_batch: batch of generated data (torch tensor).
        :return: penalty as a scalar (torch tensor).
        """
        batch_size = input_batch.size(0)
        # NOTE(review): the (batch, 1, 1) shape assumes 3-D input tensors
        # (batch, channel, time) — confirm against the data loader.
        epsilon = torch.rand(batch_size, 1, 1)
        epsilon = epsilon.expand_as(input_batch).to(self.device)

        # Interpolate between real and generated samples
        interpolation = epsilon * input_batch.data + (
            1 - epsilon) * generated_batch.data
        interpolation = interpolation.requires_grad_(True).to(self.device)

        # Computes the discriminator's prediction for the interpolated input
        interpolation_logits = self.discriminator(interpolation)

        # Computes a vector of outputs to make it works with 2 output classes if needed
        grad_outputs = torch.ones_like(interpolation_logits).to(
            self.device).requires_grad_(True)

        # Get the gradients and retain the graph so that the penalty can be back-propagated
        gradients = autograd.grad(outputs=interpolation_logits,
                                  inputs=interpolation,
                                  grad_outputs=grad_outputs,
                                  create_graph=True,
                                  retain_graph=True,
                                  only_inputs=True)[0]
        gradients = gradients.view(batch_size, -1)

        # Computes the norm of the gradients and penalizes deviation from 1
        gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1))
        return ((gradients_norm - 1)**2).mean()

    def train_discriminator_step(self, input_batch, target_batch):
        """
        Trains the discriminator for a single step based on the wasserstein gan-gp framework.
        :param input_batch: batch of input data (torch tensor).
        :param target_batch: batch of target data (torch tensor).
        :return: a batch of generated data (torch tensor).
        """
        # Activate gradient tracking for the discriminator
        self.change_discriminator_grad_requirement(requires_grad=True)

        # Set the discriminator's gradients to zero
        self.discriminator_optimizer.zero_grad()

        # Generate a batch and compute the penalty
        generated_batch = self.generator(input_batch)

        # Wasserstein critic loss: E[D(fake)] - E[D(real)]
        loss_d = self.discriminator(generated_batch.detach()).mean(
        ) - self.discriminator(target_batch).mean()
        self.train_losses['discriminator']['adversarial'].append(loss_d.item())
        if self.use_penalty:
            penalty = self.compute_gradient_penalty(input_batch,
                                                    generated_batch.detach())
            self.train_losses['discriminator']['penalty'].append(
                penalty.item())
            loss_d = loss_d + self.gamma * penalty

        # Update the discriminator's weights
        loss_d.backward()
        self.discriminator_optimizer.step()

        # Apply the weight constraint if needed (original WGAN formulation)
        if not self.use_penalty:
            for p in self.discriminator.parameters():
                p.data.clamp_(min=-self.clipping_limit,
                              max=self.clipping_limit)

        # Return the generated batch to avoid redundant computation
        return generated_batch

    def train_generator_step(self, target_batch, generated_batch):
        """
        Trains the generator for a single step based on the wasserstein gan-gp framework.
        :param target_batch: batch of target data (torch tensor).
        :param generated_batch: batch of generated data (torch tensor).
        :return: None
        """
        # Deactivate gradient tracking for the discriminator
        self.change_discriminator_grad_requirement(requires_grad=False)

        # Set generator's gradients to zero
        self.generator_optimizer.zero_grad()

        # Get the generator losses
        loss_g_adversarial = -self.discriminator(generated_batch).mean()
        loss_g_time = self.generator_time_criterion(generated_batch,
                                                    target_batch)

        # Combine the different losses; the adversarial term only kicks in
        # once the coupling epoch is reached
        loss_g = loss_g_time
        if self.epoch >= self.coupling_epoch:
            loss_g = loss_g + self.lambda_adv * loss_g_adversarial

        # Back-propagate and update the generator weights
        loss_g.backward()
        self.generator_optimizer.step()

        # Store the losses
        self.train_losses['generator']['time_l2'].append(loss_g_time.item())
        self.train_losses['generator']['adversarial'].append(
            loss_g_adversarial.item())

    def change_discriminator_grad_requirement(self, requires_grad):
        """
        Changes the requires_grad flag of discriminator's parameters. This action is not absolutely needed as the
        discriminator's optimizer is not called after the generators update, but it reduces the computational cost.
        :param requires_grad: flag indicating if the discriminator's parameter require gradient tracking (boolean).
        :return: None
        """
        for p in self.discriminator.parameters():
            p.requires_grad_(requires_grad)

    def train(self, epochs):
        """
        Trains the WGAN-GP for a given number of pseudo-epochs.
        :param epochs: Number of time to iterate over a part of the dataset (int).
        :return: None
        """
        self.generator.train()
        self.discriminator.train()
        for epoch in range(epochs):
            for i in range(self.train_batches_per_epoch):
                # Transfer to GPU
                local_batch = next(self.train_loader_iter)
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)

                # Train the discriminator
                generated_batch = self.train_discriminator_step(
                    input_batch, target_batch)

                # Train the generator every n_critic
                if not (i % self.n_critic):
                    self.train_generator_step(target_batch, generated_batch)

                # Print message
                if not (i % 10):
                    # The penalty list is only populated in WGAN-GP mode;
                    # guard the lookup so plain WGAN does not raise IndexError.
                    penalty_losses = self.train_losses['discriminator']['penalty']
                    last_penalty = penalty_losses[-1] if penalty_losses else 0.0
                    message = 'Batch {}: \n' \
                              '\t Generator: \n' \
                              '\t\t Time: {} \n' \
                              '\t\t Adversarial: {} \n' \
                              '\t Discriminator: \n' \
                              '\t\t Penalty: {}\n' \
                              '\t\t Adversarial: {} \n'.format(i,
                                                               self.train_losses['generator']['time_l2'][-1],
                                                               self.train_losses['generator']['adversarial'][-1],
                                                               last_penalty,
                                                               self.train_losses['discriminator']['adversarial'][-1])
                    print(message)

            # Evaluate the model
            with torch.no_grad():
                self.eval()

            # Save the trainer state
            self.save()

            # Increment epoch counter
            self.epoch += 1
            self.generator_scheduler.step()
            self.discriminator_scheduler.step()

    def eval(self):
        """
        Runs one validation pass (time-domain L2 only) and logs the result.
        :return: None
        """
        # Set the models in evaluation mode
        self.generator.eval()
        self.discriminator.eval()
        batch_losses = {'time_l2': []}
        for i in range(self.valid_batches_per_epoch):
            # Transfer to GPU
            local_batch = next(self.valid_loader_iter)
            input_batch, target_batch = local_batch[0].to(
                self.device), local_batch[1].to(self.device)

            generated_batch = self.generator(input_batch)

            loss_g_time = self.generator_time_criterion(
                generated_batch, target_batch)
            batch_losses['time_l2'].append(loss_g_time.item())

        # Store the validation losses
        self.valid_losses['generator']['time_l2'].append(
            np.mean(batch_losses['time_l2']))

        # Display validation losses
        message = 'Epoch {}: \n' \
                  '\t Time: {} \n'.format(self.epoch, np.mean(batch_losses['time_l2']))
        print(message)

        # Set the models back in train mode. Bug fix: the discriminator was
        # previously left in eval mode, corrupting BatchNorm/Dropout behavior
        # for all subsequent training batches.
        self.generator.train()
        self.discriminator.train()

    def save(self):
        """
        Saves the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        # NOTE(review): splitting on '.' assumes the savepath contains exactly
        # one dot (the extension) — paths with dotted directories would break.
        savepath = self.savepath.split('.')[0] + '_' + str(
            self.epoch // 5) + '.' + self.savepath.split('.')[1]
        torch.save(
            {
                'epoch':
                self.epoch,
                'generator_state_dict':
                self.generator.state_dict(),
                'discriminator_state_dict':
                self.discriminator.state_dict(),
                'generator_optimizer_state_dict':
                self.generator_optimizer.state_dict(),
                'discriminator_optimizer_state_dict':
                self.discriminator_optimizer.state_dict(),
                'generator_scheduler_state_dict':
                self.generator_scheduler.state_dict(),
                'discriminator_scheduler_state_dict':
                self.discriminator_scheduler.state_dict(),
                'train_losses':
                self.train_losses,
                'test_losses':
                self.test_losses,
                'valid_losses':
                self.valid_losses
            }, savepath)

    def load(self):
        """
        Loads the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        checkpoint = torch.load(self.loadpath, map_location=self.device)
        self.epoch = checkpoint['epoch']
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        # self.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        self.generator_optimizer.load_state_dict(
            checkpoint['generator_optimizer_state_dict'])
        # self.discriminator_optimizer.load_state_dict(checkpoint['discriminator_optimizer_state_dict'])
        self.generator_scheduler.load_state_dict(
            checkpoint['generator_scheduler_state_dict'])
        self.discriminator_scheduler.load_state_dict(
            checkpoint['discriminator_scheduler_state_dict'])
        self.train_losses = checkpoint['train_losses']
        self.test_losses = checkpoint['test_losses']
        self.valid_losses = checkpoint['valid_losses']

    def evaluate_metrics(self, n_batches):
        """
        Evaluates the quality of the reconstruction with the SNR and LSD metrics on a specified number of batches
        :param: n_batches: number of batches to process
        :return: mean and std for each metric
        """
        with torch.no_grad():
            snrs = []
            lsds = []
            generator = self.generator.eval()
            for k in range(n_batches):
                # Transfer to GPU
                local_batch = next(self.test_loader_iter)
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)

                # Generates a batch
                generated_batch = generator(input_batch)

                # Get the metrics
                snrs.append(
                    snr(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))
                lsds.append(
                    lsd(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))

            snrs = torch.cat(snrs).cpu().numpy()
            lsds = torch.cat(lsds).cpu().numpy()

            # Some signals corresponding to silence will be all zeroes and cause troubles due to the logarithm
            snrs[np.isinf(snrs)] = np.nan
            lsds[np.isinf(lsds)] = np.nan
        return np.nanmean(snrs), np.nanstd(snrs), np.nanmean(lsds), np.nanstd(
            lsds)
示例#11
0
class Trainer():
    """
    CycleGAN trainer: learns two generators (G_X: Y->X, G_Y: X->Y) and two
    discriminators (D_X, D_Y) with least-squares adversarial losses, plus
    optional cycle-consistency and identity losses.
    """

    def __init__(self, config):
        self.batch_size = config.batchSize
        self.epochs = config.epochs

        self.use_cycle_loss = config.cycleLoss
        self.cycle_multiplier = config.cycleMultiplier

        self.use_identity_loss = config.identityLoss
        self.identity_multiplier = config.identityMultiplier

        self.load_models = config.loadModels
        self.data_x_loc = config.dataX
        self.data_y_loc = config.dataY

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.init_models()
        self.init_data_loaders()
        # One joint optimizer per role: both generators share g_optimizer,
        # both discriminators share d_optimizer.
        self.g_optimizer = torch.optim.Adam(list(self.G_X.parameters()) +
                                            list(self.G_Y.parameters()),
                                            lr=config.lr)
        self.d_optimizer = torch.optim.Adam(list(self.D_X.parameters()) +
                                            list(self.D_Y.parameters()),
                                            lr=config.lr)
        # NOTE(review): only the generator optimizer is scheduled; the
        # discriminator learning rate stays constant — confirm this is intended.
        self.scheduler_g = torch.optim.lr_scheduler.StepLR(self.g_optimizer,
                                                           step_size=1,
                                                           gamma=0.95)

        self.output_path = "./outputs/"
        self.img_width = 256
        self.img_height = 256

    # Load/Construct the models
    def init_models(self):
        """Builds the four networks, loading saved weights or initializing fresh ones."""
        self.G_X = Generator(3, 3, nn.InstanceNorm2d)
        self.D_X = Discriminator(3)
        self.G_Y = Generator(3, 3, nn.InstanceNorm2d)
        self.D_Y = Discriminator(3)

        if self.load_models:
            self.G_X.load_state_dict(
                torch.load(self.output_path + "models/G_X",
                           map_location='cpu'))
            self.G_Y.load_state_dict(
                torch.load(self.output_path + "models/G_Y",
                           map_location='cpu'))
            self.D_X.load_state_dict(
                torch.load(self.output_path + "models/D_X",
                           map_location='cpu'))
            self.D_Y.load_state_dict(
                torch.load(self.output_path + "models/D_Y",
                           map_location='cpu'))
        else:
            self.G_X.apply(init_func)
            self.G_Y.apply(init_func)
            self.D_X.apply(init_func)
            self.D_Y.apply(init_func)

        self.G_X.to(self.device)
        self.G_Y.to(self.device)
        self.D_X.to(self.device)
        self.D_Y.to(self.device)

    # Initialize data loaders and image transformer
    def init_data_loaders(self):
        """Creates the X- and Y-domain DataLoaders with resize/normalize transforms."""
        transform = transforms.Compose([
            transforms.Resize((self.img_width, self.img_height)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        X_folder = torchvision.datasets.ImageFolder(self.data_x_loc, transform)
        self.X_loader = torch.utils.data.DataLoader(X_folder,
                                                    batch_size=self.batch_size,
                                                    shuffle=True)

        Y_folder = torchvision.datasets.ImageFolder(self.data_y_loc, transform)
        self.Y_loader = torch.utils.data.DataLoader(Y_folder,
                                                    batch_size=self.batch_size,
                                                    shuffle=True)

    def save_models(self):
        """Persists all four model state dicts under the output path."""
        torch.save(self.G_X.state_dict(), self.output_path + "models/G_X")
        torch.save(self.D_X.state_dict(), self.output_path + "models/D_X")
        torch.save(self.G_Y.state_dict(), self.output_path + "models/G_Y")
        torch.save(self.D_Y.state_dict(), self.output_path + "models/D_Y")

    # Reset gradients for all models, needed for between every training
    def reset_gradients(self):
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    # Sample image from training data every %x epoch and save them for judging
    def save_samples(self, epoch):
        """Translates one sample from each domain and saves originals + fakes as PNGs."""
        x_iter = iter(self.X_loader)
        y_iter = iter(self.Y_loader)

        img_data_x, _ = next(x_iter)
        img_data_y, _ = next(y_iter)

        original_x = np.array(img_data_x[0])
        generated_y = np.array(
            self.G_Y(img_data_x[0].view(1, 3, self.img_width,
                                        self.img_height).to(
                                            self.device)).cpu().detach())[0]

        original_y = np.array(img_data_y[0])
        generated_x = np.array(
            self.G_X(img_data_y[0].view(1, 3, self.img_width,
                                        self.img_height).to(
                                            self.device)).cpu().detach())[0]

        def prepare_image(img):
            # CHW -> HWC, then undo the [-1, 1] normalization for display
            img = img.transpose((1, 2, 0))
            return img / 2 + 0.5

        original_x = prepare_image(original_x)
        generated_y = prepare_image(generated_y)

        original_y = prepare_image(original_y)
        generated_x = prepare_image(generated_x)

        plt.imsave('./outputs/samples/original_X_{}.png'.format(epoch),
                   original_x)
        plt.imsave('./outputs/samples/original_Y_{}.png'.format(epoch),
                   original_y)

        plt.imsave('./outputs/samples/generated_X_{}.png'.format(epoch),
                   generated_x)
        plt.imsave('./outputs/samples/generated_Y_{}.png'.format(epoch),
                   generated_y)

    # Training loop
    def train(self):
        """Runs the full CycleGAN training loop and saves models/losses."""
        D_X_losses = []
        D_Y_losses = []

        G_X_losses = []
        G_Y_losses = []

        for epoch in range(self.epochs):
            print("======")
            print("Epoch {}!".format(epoch + 1))

            # Track progress
            if epoch % 5 == 0:
                self.save_samples(epoch)

            # Paper reduces lr after 100 epochs
            if epoch > 100:
                self.scheduler_g.step()

            for (data_X, _), (data_Y, _) in zip(self.X_loader, self.Y_loader):
                data_X = data_X.to(self.device)
                data_Y = data_Y.to(self.device)

                # =====================================
                # Train Discriminators (LSGAN targets: fake -> 0, real -> 1)
                # =====================================

                # Train fake X
                self.reset_gradients()
                fake_X = self.G_X(data_Y)
                out_fake_X = self.D_X(fake_X)
                d_x_f_loss = torch.mean(out_fake_X**2)
                d_x_f_loss.backward()
                self.d_optimizer.step()

                # Train fake Y
                self.reset_gradients()
                fake_Y = self.G_Y(data_X)
                out_fake_Y = self.D_Y(fake_Y)
                d_y_f_loss = torch.mean(out_fake_Y**2)
                d_y_f_loss.backward()
                self.d_optimizer.step()

                # Train true X
                self.reset_gradients()
                out_true_X = self.D_X(data_X)
                d_x_t_loss = torch.mean((out_true_X - 1)**2)
                d_x_t_loss.backward()
                self.d_optimizer.step()

                # Train true Y
                self.reset_gradients()
                out_true_Y = self.D_Y(data_Y)
                d_y_t_loss = torch.mean((out_true_Y - 1)**2)
                d_y_t_loss.backward()
                self.d_optimizer.step()

                D_X_losses.append([
                    d_x_t_loss.cpu().detach().numpy(),
                    d_x_f_loss.cpu().detach().numpy()
                ])
                D_Y_losses.append([
                    d_y_t_loss.cpu().detach().numpy(),
                    d_y_f_loss.cpu().detach().numpy()
                ])

                # =====================================
                # Train GENERATORS
                # =====================================

                # Cycle X -> Y -> X
                self.reset_gradients()

                fake_Y = self.G_Y(data_X)
                out_fake_Y = self.D_Y(fake_Y)

                g_loss1 = torch.mean((out_fake_Y - 1)**2)
                # Bug fix: g_loss2 was undefined when cycle loss is disabled,
                # causing a NameError below; default it to zero.
                g_loss2 = torch.tensor(0.0, device=self.device)
                if self.use_cycle_loss:
                    reconst_X = self.G_X(fake_Y)
                    g_loss2 = self.cycle_multiplier * torch.mean(
                        (data_X - reconst_X)**2)

                G_Y_losses.append([
                    g_loss1.cpu().detach().numpy(),
                    g_loss2.cpu().detach().numpy()
                ])
                g_loss = g_loss1 + g_loss2
                g_loss.backward()
                self.g_optimizer.step()

                # Cycle Y -> X -> Y
                self.reset_gradients()

                fake_X = self.G_X(data_Y)
                out_fake_X = self.D_X(fake_X)

                g_loss1 = torch.mean((out_fake_X - 1)**2)
                g_loss2 = torch.tensor(0.0, device=self.device)
                if self.use_cycle_loss:
                    reconst_Y = self.G_Y(fake_X)
                    g_loss2 = self.cycle_multiplier * torch.mean(
                        (data_Y - reconst_Y)**2)

                G_X_losses.append([
                    g_loss1.cpu().detach().numpy(),
                    g_loss2.cpu().detach().numpy()
                ])
                g_loss = g_loss1 + g_loss2
                g_loss.backward()
                self.g_optimizer.step()

                # =====================================
                # Train image IDENTITY
                # =====================================

                if self.use_identity_loss:
                    self.reset_gradients()

                    # X should be same after G(X)
                    same_X = self.G_X(data_X)
                    g_loss = self.identity_multiplier * torch.mean(
                        (data_X - same_X)**2)
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Y should be same after G(Y). Bug fix: this previously
                    # called self.G_X(data_Y) — the X-domain generator — which
                    # trains the wrong identity mapping; the Y-domain identity
                    # must use G_Y (consistent with fake_Y = G_Y(data_X)).
                    same_Y = self.G_Y(data_Y)
                    g_loss = self.identity_multiplier * torch.mean(
                        (data_Y - same_Y)**2)
                    g_loss.backward()
                    self.g_optimizer.step()

            # Epoch done, save models
            self.save_models()

        # Save losses for analysis
        np.save(self.output_path + 'losses/G_X_losses.npy',
                np.array(G_X_losses))
        np.save(self.output_path + 'losses/G_Y_losses.npy',
                np.array(G_Y_losses))
        np.save(self.output_path + 'losses/D_X_losses.npy',
                np.array(D_X_losses))
        np.save(self.output_path + 'losses/D_Y_losses.npy',
                np.array(D_Y_losses))
示例#12
0
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=0, help='starting epoch')
    parser.add_argument('--n_epochs',
                        type=int,
                        default=400,
                        help='number of epochs of training')
    parser.add_argument('--batchSize',
                        type=int,
                        default=10,
                        help='size of the batches')
    parser.add_argument('--dataroot',
                        type=str,
                        default='datasets/genderchange/',
                        help='root directory of the dataset')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='initial learning rate')
    parser.add_argument(
        '--decay_epoch',
        type=int,
        default=100,
        help='epoch to start linearly decaying the learning rate to 0')
    parser.add_argument('--size',
                        type=int,
                        default=256,
                        help='size of the data crop (squared assumed)')
    parser.add_argument('--input_nc',
                        type=int,
                        default=3,
                        help='number of channels of input data')
    parser.add_argument('--output_nc',
                        type=int,
                        default=3,
                        help='number of channels of output data')
    parser.add_argument('--cuda',
                        action='store_true',
                        help='use GPU computation')
    parser.add_argument(
        '--n_cpu',
        type=int,
        default=8,
        help='number of cpu threads to use during batch generation')
    opt = parser.parse_args()
    print(opt)

    if torch.cuda.is_available() and not opt.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    ###### Definition of variables ######
    # Networks
    netG_A2B = Generator(opt.input_nc, opt.output_nc)
    netG_B2A = Generator(opt.output_nc, opt.input_nc)
    netD_A = Discriminator(opt.input_nc)
    netD_B = Discriminator(opt.output_nc)

    if opt.cuda:
        netG_A2B.cuda()
        netG_B2A.cuda()
        netD_A.cuda()
        netD_B.cuda()

    netG_A2B.apply(weights_init_normal)
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

    # Lossess
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    # Optimizers & LR schedulers
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                                   netG_B2A.parameters()),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    # NOTE(review): this span is the tail of a CycleGAN training routine; the
    # enclosing definition (opt parsing, the four networks, the three criteria,
    # optimizer_G/D_A/D_B and the first two schedulers) starts above this
    # excerpt.
    # Linear LR decay for discriminator B, mirroring the schedulers built
    # above for the generators and discriminator A.
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    # Inputs & targets memory allocation
    # Pre-allocated input tensors are reused every iteration via copy_().
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
    # Constant GAN targets: 1.0 = "real", 0.0 = "fake".
    target_real = Variable(Tensor(opt.batchSize).fill_(1.0),
                           requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize).fill_(0.0),
                           requires_grad=False)

    # History buffers of previously generated images, sampled when training
    # the discriminators (the CycleGAN stabilisation trick).
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Dataset loader
    # Upscale by 1.2x then center-crop back to opt.size, map pixels to [-1, 1].
    transforms_ = [
        transforms.Resize(int(opt.size * 1.2), Image.BICUBIC),
        transforms.CenterCrop(opt.size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]

    dataloader = DataLoader(ImageDataset(opt.dataroot,
                                         transforms_=transforms_,
                                         unaligned=True),
                            batch_size=opt.batchSize,
                            shuffle=True,
                            num_workers=opt.n_cpu,
                            drop_last=True)

    # Plot Loss and Images in Tensorboard
    # Run directory is keyed by dataset name and launch timestamp.
    experiment_dir = 'logs/{}@{}'.format(
        opt.dataroot.split('/')[1],
        datetime.now().strftime("%d.%m.%Y-%H:%M:%S"))
    os.makedirs(experiment_dir, exist_ok=True)
    writer = SummaryWriter(os.path.join(experiment_dir, "tb"))

    # Per-metric history; the epoch summary below averages each list.
    metric_dict = defaultdict(list)
    n_iters_total = 0

    ###################################
    ###### Training ######
    # opt.epoch is the starting epoch (supports resuming).
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input (copied into the pre-allocated tensors above)
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(
                same_B, real_B) * 5.0  # [batchSize, 3, ImgSize, ImgSize]

            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(
                same_A, real_A) * 5.0  # [batchSize, 3, ImgSize, ImgSize]

            # GAN loss: generators try to make the discriminators say "real".
            fake_B = netG_A2B(real_A)
            pred_fake = netD_B(fake_B).view(-1)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)  # [batchSize]

            fake_A = netG_B2A(real_B)
            pred_fake = netD_A(fake_A).view(-1)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)  # [batchSize]

            # Cycle loss: A -> B -> A (and B -> A -> B) must reconstruct.
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(
                recovered_A, real_A) * 10.0  # [batchSize, 3, ImgSize, ImgSize]

            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(
                recovered_B, real_B) * 10.0  # [batchSize, 3, ImgSize, ImgSize]

            # Total loss (identity x5, GAN x1, cycle x10 per the scalars above)
            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB

            loss_G.backward()
            optimizer_G.step()
            ###################################

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = netD_A(real_A).view(-1)
            loss_D_real = criterion_GAN(pred_real, target_real)  # [batchSize]

            # Fake loss: sample from the replay buffer, detach so no gradient
            # flows back into the generator.
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach()).view(-1)
            loss_D_fake = criterion_GAN(pred_fake, target_fake)  # [batchSize]

            # Total loss (halved so D learns at half the rate of G)
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()

            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = netD_B(real_B).view(-1)
            loss_D_real = criterion_GAN(pred_real, target_real)  # [batchSize]

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach()).view(-1)
            loss_D_fake = criterion_GAN(pred_fake, target_fake)  # [batchSize]

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()

            optimizer_D_B.step()
            ###################################

            # Accumulate per-iteration scalars and mirror the latest value to
            # TensorBoard.
            metric_dict['loss_G'].append(loss_G.item())
            metric_dict['loss_G_identity'].append(loss_identity_A.item() +
                                                  loss_identity_B.item())
            metric_dict['loss_G_GAN'].append(loss_GAN_A2B.item() +
                                             loss_GAN_B2A.item())
            metric_dict['loss_G_cycle'].append(loss_cycle_ABA.item() +
                                               loss_cycle_BAB.item())
            metric_dict['loss_D'].append(loss_D_A.item() + loss_D_B.item())

            for title, value in metric_dict.items():
                writer.add_scalar('train/{}'.format(title), value[-1],
                                  n_iters_total)

            n_iters_total += 1

        # Console summary uses the losses of the *last* batch of the epoch.
        print("""
        -----------------------------------------------------------
        Epoch : {} Finished
        Loss_G : {}
        Loss_G_identity : {}
        Loss_G_GAN : {}
        Loss_G_cycle : {}
        Loss_D : {}
        -----------------------------------------------------------
        """.format(epoch, loss_G, loss_identity_A + loss_identity_B,
                   loss_GAN_A2B + loss_GAN_B2A,
                   loss_cycle_ABA + loss_cycle_BAB, loss_D_A + loss_D_B))

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        # Save models checkpoints

        # 2.5 is an empirical "good enough" generator-loss threshold —
        # presumably tuned for this dataset; TODO confirm.
        if loss_G.item() < 2.5:
            os.makedirs(os.path.join(experiment_dir, str(epoch)),
                        exist_ok=True)
            torch.save(netG_A2B.state_dict(),
                       '{}/{}/netG_A2B.pth'.format(experiment_dir, epoch))
            torch.save(netG_B2A.state_dict(),
                       '{}/{}/netG_B2A.pth'.format(experiment_dir, epoch))
            torch.save(netD_A.state_dict(),
                       '{}/{}/netD_A.pth'.format(experiment_dir, epoch))
            torch.save(netD_B.state_dict(),
                       '{}/{}/netD_B.pth'.format(experiment_dir, epoch))
        # Otherwise still checkpoint occasionally late in training.
        elif epoch > 100 and epoch % 40 == 0:
            os.makedirs(os.path.join(experiment_dir, str(epoch)),
                        exist_ok=True)
            torch.save(netG_A2B.state_dict(),
                       '{}/{}/netG_A2B.pth'.format(experiment_dir, epoch))
            torch.save(netG_B2A.state_dict(),
                       '{}/{}/netG_B2A.pth'.format(experiment_dir, epoch))
            torch.save(netD_A.state_dict(),
                       '{}/{}/netD_A.pth'.format(experiment_dir, epoch))
            torch.save(netD_B.state_dict(),
                       '{}/{}/netD_B.pth'.format(experiment_dir, epoch))

        # Epoch-level TensorBoard scalars: mean over the full (growing)
        # history, since metric_dict lists are never reset per epoch.
        for title, value in metric_dict.items():
            writer.add_scalar("train/{}_epoch".format(title), np.mean(value),
                              epoch)
示例#13
0
def main(args):
    """Train a GAN on MNIST and log losses/samples to TensorBoard.

    Expects ``args`` to provide: ``params`` (path to a JSON file of
    hyperparameters), ``device`` (CUDA device index, or None for CPU),
    ``datadir``, ``modeldir``, ``logdir``, ``num_fake_samples_eval`` and
    ``o`` (output checkpoint path prefix).
    """
    with open(args.params, "r") as f:
        params = json.load(f)

    generator = Generator(params["dim_latent"])
    discriminator = Discriminator()

    if args.device is not None:
        generator = generator.cuda(args.device)
        discriminator = discriminator.cuda(args.device)

    # dataloading
    train_dataset = datasets.MNIST(root=args.datadir,
                                   transform=transforms.ToTensor(),
                                   train=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=params["batch_size"],
                              num_workers=4,
                              shuffle=True)

    # optimizers: one Adam per network, shared hyperparameters
    betas = (params["beta_1"], params["beta_2"])
    optimizer_G = optim.Adam(generator.parameters(),
                             lr=params["learning_rate"],
                             betas=betas)
    optimizer_D = optim.Adam(discriminator.parameters(),
                             lr=params["learning_rate"],
                             betas=betas)

    # makedirs(exist_ok=True) also creates missing parents and avoids the
    # exists()/mkdir race of the previous version.
    os.makedirs(args.modeldir, exist_ok=True)
    os.makedirs(args.logdir, exist_ok=True)
    writer = SummaryWriter(args.logdir)

    steps_per_epoch = len(train_loader)

    msg = ["\t{0}: {1}".format(key, val) for key, val in params.items()]
    print("hyperparameters: \n" + "\n".join(msg))

    # main training loop
    for n in range(params["num_epochs"]):
        loader = iter(train_loader)

        print("epoch: {0}/{1}".format(n + 1, params["num_epochs"]))
        for i in tqdm.trange(steps_per_epoch):
            batch, _ = next(loader)
            if args.device is not None: batch = batch.cuda(args.device)

            loss_D = update_discriminator(batch, discriminator, generator,
                                          optimizer_D, params)
            loss_G = update_generator(discriminator, generator, optimizer_G,
                                      params, args.device)

            writer.add_scalar("loss_discriminator/train", loss_D,
                              i + n * steps_per_epoch)
            writer.add_scalar("loss_generator/train", loss_G,
                              i + n * steps_per_epoch)

        # per-epoch snapshot checkpoints (".tmp" suffix)
        torch.save(generator.state_dict(),
                   args.o + ".generator." + str(n) + ".tmp")
        torch.save(discriminator.state_dict(),
                   args.o + ".discriminator." + str(n) + ".tmp")

        # eval: sample fake images for TensorBoard
        with torch.no_grad():
            latent = torch.randn(args.num_fake_samples_eval,
                                 params["dim_latent"])
            # BUG FIX: the latent batch was previously moved with an
            # unconditional .cuda(), which crashed CPU-only runs
            # (args.device is None) and ignored the configured device
            # index. Mirror the device handling used everywhere else.
            if args.device is not None:
                latent = latent.cuda(args.device)
            imgs_fake = generator(latent)

            writer.add_images("generated fake images", imgs_fake, n)
            del latent, imgs_fake

    writer.close()

    # final checkpoints
    torch.save(generator.state_dict(), args.o + ".generator.pt")
    torch.save(discriminator.state_dict(), args.o + ".discriminator.pt")
示例#14
0
文件: ei.py 项目: edongdongchen/EI
    def train_ei_adv(self,
                     dataloader,
                     physics,
                     transform,
                     epochs,
                     lr,
                     alpha,
                     ckp_interval,
                     schedule,
                     residual=True,
                     pretrained=None,
                     task='',
                     loss_type='l2',
                     cat=True,
                     report_psnr=False,
                     lr_cos=False):
        """Adversarially train an Equivariant-Imaging reconstruction model.

        A UNet generator is trained against a discriminator using a
        measurement-consistency (MC) loss, an equivariance (EI) loss and a
        GAN loss; losses are logged per epoch and both networks are
        checkpointed periodically.

        Args:
            dataloader: yields training batches consumed by ``closure_ei_adv``.
            physics: forward-operator object (project type).
            transform: group transform used for the EI loss.
            epochs: total number of training epochs.
            lr: dict with keys ``'G'``/``'D'`` (learning rates) and ``'WD'``
                (generator weight decay).
            alpha: EI loss weight(s), forwarded to the closure.
            ckp_interval: save a checkpoint every this many epochs.
            schedule: LR-decay schedule forwarded to ``adjust_learning_rate``.
            residual: build the UNet with residual connections.
            pretrained: optional checkpoint path whose ``'state_dict'``
                initialises the generator.
            task: tag appended to the checkpoint directory name.
            loss_type: ``'l2'`` (MSE) or ``'l1'`` (L1) for the MC/EI losses.
            cat: UNet skip-connection concatenation flag.
            report_psnr: additionally log PSNR/MSE columns.
            lr_cos: use cosine LR decay.

        Raises:
            ValueError: if ``loss_type`` is neither ``'l2'`` nor ``'l1'``.
        """
        save_path = './ckp/{}_ei_adv_{}'.format(get_timestamp(), task)

        os.makedirs(save_path, exist_ok=True)

        generator = UNet(in_channels=self.in_channels,
                         out_channels=self.out_channels,
                         compact=4,
                         residual=residual,
                         circular_padding=True,
                         cat=cat)

        if pretrained:
            checkpoint = torch.load(pretrained)
            generator.load_state_dict(checkpoint['state_dict'])

        discriminator = Discriminator(
            (self.in_channels, self.img_width, self.img_height))

        generator = generator.to(self.device)
        discriminator = discriminator.to(self.device)

        # BUG FIX: the original used two independent `if` statements with no
        # fallback, so an unrecognised loss_type left criterion_mc/criterion_ei
        # undefined and only failed much later with a NameError. Fail fast.
        if loss_type == 'l2':
            criterion_mc = torch.nn.MSELoss().to(self.device)
            criterion_ei = torch.nn.MSELoss().to(self.device)
        elif loss_type == 'l1':
            criterion_mc = torch.nn.L1Loss().to(self.device)
            criterion_ei = torch.nn.L1Loss().to(self.device)
        else:
            raise ValueError(
                "loss_type must be 'l2' or 'l1', got {!r}".format(loss_type))

        # GAN loss is always MSE (LSGAN-style), independent of loss_type.
        criterion_gan = torch.nn.MSELoss().to(self.device)

        optimizer_G = Adam(generator.parameters(),
                           lr=lr['G'],
                           weight_decay=lr['WD'])
        optimizer_D = Adam(discriminator.parameters(),
                           lr=lr['D'],
                           weight_decay=0)

        # CSV-style logger; extra psnr/mse columns only when requested.
        if report_psnr:
            log = LOG(save_path,
                      filename='training_loss',
                      field_name=[
                          'epoch', 'loss_mc', 'loss_ei', 'loss_g', 'loss_G',
                          'loss_D', 'psnr', 'mse'
                      ])
        else:
            log = LOG(save_path,
                      filename='training_loss',
                      field_name=[
                          'epoch', 'loss_mc', 'loss_ei', 'loss_g', 'loss_G',
                          'loss_D'
                      ])

        for epoch in range(epochs):
            adjust_learning_rate(optimizer_G, epoch, lr['G'], lr_cos, epochs,
                                 schedule)
            adjust_learning_rate(optimizer_D, epoch, lr['D'], lr_cos, epochs,
                                 schedule)

            # One full pass over the data; returns the per-epoch loss tuple
            # in the same order as the log field names.
            loss = closure_ei_adv(generator, discriminator, dataloader,
                                  physics, transform, optimizer_G, optimizer_D,
                                  criterion_mc, criterion_ei, criterion_gan,
                                  alpha, self.dtype, self.device, report_psnr)

            log.record(epoch + 1, *loss)

            if report_psnr:
                print(
                    '{}\tEpoch[{}/{}]\tfc={:.4e}\tti={:.4e}\tg={:.4e}\tG={:.4e}\tD={:.4e}\tpsnr={:.4f}\tmse={:.4e}'
                    .format(get_timestamp(), epoch, epochs, *loss))
            else:
                print(
                    '{}\tEpoch[{}/{}]\tfc={:.4e}\tti={:.4e}\tg={:.4e}\tG={:.4e}\tD={:.4e}'
                    .format(get_timestamp(), epoch, epochs, *loss))

            # Checkpoint periodically and always on the final epoch.
            if epoch % ckp_interval == 0 or epoch + 1 == epochs:
                state = {
                    'epoch': epoch,
                    'state_dict_G': generator.state_dict(),
                    'state_dict_D': discriminator.state_dict(),
                    'optimizer_G': optimizer_G.state_dict(),
                    'optimizer_D': optimizer_D.state_dict()
                }
                torch.save(
                    state,
                    os.path.join(save_path, 'ckp_{}.pth.tar'.format(epoch)))
        log.close()
示例#15
0
def train(args):
    """Train a relighting generator, optionally with an LSGAN discriminator.

    Expects ``args`` to provide: ``results_path``, ``high_resolution``,
    ``use_gan``, ``use_skip``, ``batch_size`` and ``epochs``. Uses the
    module-level ``device``. A full checkpoint (both networks and both
    optimizers) is written to ``results_path`` after every epoch.
    """
    # check if results path exists, if not create the folder
    check_folder(args.results_path)

    # generator model
    generator = HourglassNet(high_res=args.high_resolution)
    generator.to(device)

    # discriminator model
    discriminator = Discriminator(input_nc=1)
    discriminator.to(device)

    # optimizers (Adam defaults)
    optimizer_g = torch.optim.Adam(generator.parameters())
    optimizer_d = torch.optim.Adam(discriminator.parameters())

    # training parameters
    feature_weight = 0.5
    skip_count = 0
    use_gan = args.use_gan
    print_frequency = 5

    # dataloader
    illum_dataset = IlluminationDataset()
    illum_dataloader = DataLoader(illum_dataset, batch_size=args.batch_size)

    # gan loss based on lsgan that uses squared error
    gan_loss = GANLoss(gan_mode='lsgan')

    # training
    for epoch in range(1, args.epochs + 1):

        for data_idx, data in enumerate(illum_dataloader):
            source_img, source_light, target_img, target_light = data

            # BUG FIX: Tensor.to() is NOT in-place; the original discarded
            # the returned tensors, so the batch silently stayed on the CPU
            # regardless of `device`. Rebind the results.
            source_img = source_img.to(device)
            source_light = source_light.to(device)
            target_img = target_img.to(device)
            target_light = target_light.to(device)

            optimizer_g.zero_grad()

            # if skip connections are required for training, else skip the
            # connections based on the the training scheme for low-res/high-res
            # images
            if args.use_skip:
                skip_count = 0
            else:
                skip_count = 5 if args.high_resolution else 4

            output = generator(source_img, target_light, skip_count,
                               target_img)

            source_face_feats, source_light_pred, target_face_feats, source_relit_pred = output

            img_loss = image_and_light_loss(source_relit_pred, target_img,
                                            source_light_pred, target_light)
            feat_loss = feature_loss(source_face_feats, target_face_feats)

            # if gan loss is used
            if use_gan:
                g_loss = gan_loss(discriminator(source_relit_pred),
                                  target_is_real=True)
            else:
                # placeholder on the same device as img_loss so the sum
                # below cannot fail with a CPU/GPU device mismatch
                g_loss = torch.zeros(1, device=device)

            total_g_loss = img_loss + g_loss + (feature_weight * feat_loss)
            total_g_loss.backward()
            optimizer_g.step()

            # training the discriminator (only when the GAN loss is active)
            if use_gan:
                optimizer_d.zero_grad()
                pred_real = discriminator(target_img)
                pred_fake = discriminator(source_relit_pred.detach())

                loss_real = gan_loss(pred_real, target_is_real=True)
                loss_fake = gan_loss(pred_fake, target_is_real=False)

                d_loss = (loss_real + loss_fake) * 0.5
                d_loss.backward()
                optimizer_d.step()
            else:
                # zero placeholders so the log line below can always .item()
                loss_real = torch.zeros(1)
                loss_fake = torch.zeros(1)

            if data_idx % print_frequency == 0:
                # BUG FIX: the epoch total previously printed as
                # args.epochs + 1 although the loop runs exactly args.epochs
                # epochs.
                print(
                    "Epoch: [{}]/[{}], Iteration: [{}]/[{}], image loss: {}, feature loss: {}, gen fake loss: {}, dis real loss: {}, dis fake loss: {}"
                    .format(epoch, args.epochs, data_idx + 1,
                            len(illum_dataloader), img_loss.item(),
                            feat_loss.item(), g_loss.item(), loss_real.item(),
                            loss_fake.item()))

        # saving model
        checkpoint_path = os.path.join(args.results_path,
                                       'checkpoint_epoch_{}.pth'.format(epoch))
        checkpoint = {
            'generator': generator.state_dict(),
            'discriminator': discriminator.state_dict(),
            'optimizer_g': optimizer_g.state_dict(),
            'optimizer_d': optimizer_d.state_dict()
        }
        torch.save(checkpoint, checkpoint_path)