Example #1
def load_model(self, item_embedding):
    self.model = Model(n_head=self.config.N_HEAD,
                       n_hid=item_embedding.shape[1],
                       n_seq=self.config.MAX_N_SEQ,
                       n_layer=self.config.N_LAYER,
                       item2vec=item_embedding).cuda()
    # Note: clip_grad_norm_ has no effect here because no gradients exist yet;
    # it belongs in the training step, after loss.backward().
    torch_utils.clip_grad_norm_(self.model.parameters(), 5)
    self.optimizer = torch.optim.Adam(self.model.parameters(),
                                      lr=self.config.LR,
                                      eps=self.config.EPS)
    self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        self.optimizer,
        lr_lambda=LambdaLR(
            self.config.MAX_DECAY_STEP,
            self.config.DECAY_STEP).step)  # MAX_DECAY_STEP > DECAY_STEP
    self.warmup_scheduler = warmup.LinearWarmup(
        self.optimizer, warmup_period=self.config.WARMUP_PERIOD)
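
The clip_grad_norm_ call above only becomes meaningful once gradients have been populated, so it belongs in the training step rather than in load_model. A minimal sketch of such a step (the batch handling and loss computation are hypothetical; only the ordering matters):

def train_step(self, batch):
    self.optimizer.zero_grad()
    loss = self.model(batch)  # hypothetical: assumes the model returns a scalar loss
    loss.backward()
    # clip after backward() has filled the gradients, before the update
    torch_utils.clip_grad_norm_(self.model.parameters(), 5)
    self.optimizer.step()
    self.lr_scheduler.step()  # this schedule is keyed on steps (MAX_DECAY_STEP / DECAY_STEP)
    return loss.item()
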
Example #2
def _train_initialize_variables(model_str, model_params, opt_params, cuda):
    """Helper function that just initializes everything at the beginning of the train function"""
    # Params passed in as dict to model.
    model = eval(model_str)(model_params)
    model.train()  # important!

    optimizer = init_optimizer(opt_params, model)
    criterion = get_criterion(model_str)

    if opt_params['lr_scheduler'] is not None:
        if opt_params['lr_scheduler'] == 'plateau':
            scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=.5, patience=1, threshold=1e-3)
        elif opt_params['lr_scheduler'] == 'delayedexpo':
            scheduler = LambdaLR(optimizer, lr_lambda=[lambda epoch: float(epoch<=4) + float(epoch>4)*1.2**(-epoch)])
        else:
            raise NotImplementedError("only 'plateau' and 'delayedexpo' schedulers have been implemented so far")
    else:
        scheduler = None

    if cuda:
        model = model.cuda()
        if 'VAE' in model_str:
            model.is_cuda = True
    return model, criterion, optimizer, scheduler
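
A hypothetical call, assuming a model class (here MyVAE) is resolvable by eval() and that init_optimizer reads its settings from opt_params (the names come from the snippet; all argument values are made up for illustration):

model, criterion, optimizer, scheduler = _train_initialize_variables(
    model_str='MyVAE',                                    # hypothetical model class
    model_params={'hidden_dim': 128},                     # hypothetical constructor kwargs
    opt_params={'lr': 1e-3, 'lr_scheduler': 'plateau'},   # 'plateau' or 'delayedexpo'
    cuda=torch.cuda.is_available())
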
Example #3
    # Optimizers & LR schedulers
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                                   netG_B2A.parameters()),
                                   lr=args.lr,
                                   betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=args.lr,
                                     betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=args.lr,
                                     betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(args.n_epochs, args.epoch, args.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=LambdaLR(args.n_epochs, args.epoch, args.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=LambdaLR(args.n_epochs, args.epoch, args.decay_epoch).step)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    target_real = torch.ones(args.batch_size,
                             dtype=torch.float).unsqueeze(1).to(device)
    target_fake = torch.zeros(args.batch_size,
                              dtype=torch.float).unsqueeze(1).to(device)

    wandb_step = 0
    log_image_step = 50
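
All of the CycleGAN-style snippets on this page pass the .step method of a helper class that is also called LambdaLR (not the torch scheduler of the same name) as lr_lambda. The helper itself never appears in these excerpts; a minimal sketch, assuming it matches the widely used CycleGAN training utility (constant learning rate until decay_start_epoch, then linear decay to zero at n_epochs):

class LambdaLR:
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "decay must start before training ends"
        self.n_epochs = n_epochs
        self.offset = offset  # epoch the run is resumed from
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        # multiplicative factor applied to the initial LR by torch.optim.lr_scheduler.LambdaLR
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)

Under this assumption, LambdaLR(200, 0, 100).step(150) returns 0.5, i.e. the learning rate is halved halfway through the decay phase.
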
Example #4
disc_a_optimizer = torch.optim.Adam(disc_a.parameters(),
                                    lr=args.lr,
                                    betas=(0.5, 0.999))
disc_b_optimizer = torch.optim.Adam(disc_b.parameters(),
                                    lr=args.lr,
                                    betas=(0.5, 0.999))
gen_a_optimizer = torch.optim.Adam(gen_a.parameters(),
                                   lr=args.lr,
                                   betas=(0.5, 0.999))
gen_b_optimizer = torch.optim.Adam(gen_b.parameters(),
                                   lr=args.lr,
                                   betas=(0.5, 0.999))

disc_a_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    disc_a_optimizer,
    lr_lambda=LambdaLR(args.epochs, 0, args.constant_lr_epochs).step)
disc_b_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    disc_b_optimizer,
    lr_lambda=LambdaLR(args.epochs, 0, args.constant_lr_epochs).step)
gen_a_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    gen_a_optimizer,
    lr_lambda=LambdaLR(args.epochs, 0, args.constant_lr_epochs).step)
gen_b_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    gen_b_optimizer,
    lr_lambda=LambdaLR(args.epochs, 0, args.constant_lr_epochs).step)

a_fake_pool = ItemPool()
b_fake_pool = ItemPool()

ckpt_dir = '{}/checkpoints/{}'.format(args.root_dir, args.dataset)
mkdir(ckpt_dir)
Example #5
def main(args):
    torch.manual_seed(0)
    if args.mb_D:
        assert args.batch_size > 1, 'batch size needs to be larger than 1 if mb_D'
        raise NotImplementedError('mb_D not implemented')

    if args.img_norm != 'znorm':
        raise NotImplementedError('{} not implemented'.format(args.img_norm))

    assert args.act in ['relu', 'mish'], 'args.act = {}'.format(args.act)

    modelarch = 'C_{0}_{1}_{2}_{3}_{4}{5}{6}{7}{8}{9}{10}{11}{12}{13}{14}{15}{16}{17}{18}{19}{20}{21}{22}'.format(
        args.size, args.batch_size, args.lr,  args.n_epochs, args.decay_epoch, # 0, 1, 2, 3, 4
        '_G' if args.G_extra else '',  # 5
        '_D' if args.D_extra else '',  # 6
        '_U' if args.upsample else '',  # 7
        '_S' if args.slow_D else '',  # 8
        '_RL{}-{}'.format(args.start_recon_loss_val, args.end_recon_loss_val),  # 9
        '_GL{}-{}'.format(args.start_gan_loss_val, args.end_gan_loss_val),  # 10
        '_prop' if args.keep_prop else '',  # 11
        '_' + args.img_norm,  # 12
        '_WL' if args.wasserstein else '',  # 13
        '_MBD' if args.mb_D else '',  # 14
        '_FM' if args.fm_loss else '',  # 15
        '_BF{}'.format(args.buffer_size) if args.buffer_size != 50 else '',  # 16
        '_N' if args.add_noise else '',  # 17
        '_L{}'.format(args.load_iter) if args.load_iter > 0 else '',  # 18
        '_res{}'.format(args.n_resnet_blocks),  # 19
        '_n{}'.format(args.data_subset) if args.data_subset is not None else '',  # 20
        '_{}'.format(args.optim),  # 21
        '_{}'.format(args.act))  # 22

    samples_path = os.path.join(args.output_dir, modelarch, 'samples')
    safe_mkdirs(samples_path)
    model_path = os.path.join(args.output_dir, modelarch, 'models')
    safe_mkdirs(model_path)
    test_path = os.path.join(args.output_dir, modelarch, 'test')
    safe_mkdirs(test_path)

    # Definition of variables ######
    # Networks
    netG_A2B = Generator(args.input_nc, args.output_nc, img_size=args.size,
                         extra_layer=args.G_extra, upsample=args.upsample,
                         keep_weights_proportional=args.keep_prop,
                         n_residual_blocks=args.n_resnet_blocks,
                         act=args.act)
    netG_B2A = Generator(args.output_nc, args.input_nc, img_size=args.size,
                         extra_layer=args.G_extra, upsample=args.upsample,
                         keep_weights_proportional=args.keep_prop,
                         n_residual_blocks=args.n_resnet_blocks,
                         act=args.act)
    netD_A = Discriminator(args.input_nc, extra_layer=args.D_extra, mb_D=args.mb_D, x_size=args.size)
    netD_B = Discriminator(args.output_nc, extra_layer=args.D_extra, mb_D=args.mb_D, x_size=args.size)

    if args.cuda:
        netG_A2B.cuda()
        netG_B2A.cuda()
        netD_A.cuda()
        netD_B.cuda()

    if args.load_iter == 0:
        netG_A2B.apply(weights_init_normal)
        netG_B2A.apply(weights_init_normal)
        netD_A.apply(weights_init_normal)
        netD_B.apply(weights_init_normal)
    else:
        netG_A2B.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'G_A2B_{}.pth'.format(args.load_iter))))
        netG_B2A.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'G_B2A_{}.pth'.format(args.load_iter))))
        netD_A.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'D_A_{}.pth'.format(args.load_iter))))
        netD_B.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'D_B_{}.pth'.format(args.load_iter))))

        netG_A2B.train()
        netG_B2A.train()
        netD_A.train()
        netD_B.train()

    # Losses
    criterion_GAN = wasserstein_loss if args.wasserstein else torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()
    feat_criterion = torch.nn.HingeEmbeddingLoss()

    # I could also update D only if iters % 2 == 0
    lr_G = args.lr
    lr_D = args.lr / 2 if args.slow_D else args.lr

    # Optimizers & LR schedulers
    if args.optim == 'adam':
        optim = torch.optim.Adam
    elif args.optim == 'radam':
        optim = RAdam
    elif args.optim == 'ranger':
        optim = Ranger
    elif args.optim == 'rangerlars':
        optim = RangerLars
    else:
        raise NotImplementedError('args.optim = {} not implemented'.format(args.optim))

    optimizer_G = optim(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                        lr=lr_G, betas=(0.5, 0.999))
    optimizer_D_A = optim(netD_A.parameters(), lr=lr_D, betas=(0.5, 0.999))
    optimizer_D_B = optim(netD_B.parameters(), lr=lr_D, betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(args.n_epochs, args.load_iter, args.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(args.n_epochs, args.load_iter, args.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(args.n_epochs, args.load_iter, args.decay_epoch).step)

    # Inputs & targets memory allocation
    Tensor = torch.cuda.FloatTensor if args.cuda else torch.Tensor
    input_A = Tensor(args.batch_size, args.input_nc, args.size, args.size)
    input_B = Tensor(args.batch_size, args.output_nc, args.size, args.size)
    target_real = Variable(Tensor(args.batch_size).fill_(1.0), requires_grad=False)
    target_fake = Variable(Tensor(args.batch_size).fill_(0.0), requires_grad=False)

    fake_A_buffer = ReplayBuffer(args.buffer_size)
    fake_B_buffer = ReplayBuffer(args.buffer_size)

    # Transforms and dataloader for training set
    transforms_ = []
    if args.resize_crop:
        transforms_ += [transforms.Resize(int(args.size*1.12), Image.BICUBIC),
                        transforms.RandomCrop(args.size)]
    else:
        transforms_ += [transforms.Resize(args.size, Image.BICUBIC)]

    if args.horizontal_flip:
        transforms_ += [transforms.RandomHorizontalFlip()]

    transforms_ += [transforms.ToTensor()]

    if args.add_noise:
        transforms_ += [transforms.Lambda(lambda x: x + torch.randn_like(x))]

    transforms_norm = []
    if args.img_norm == 'znorm':
        transforms_norm += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    elif 'scale01' in args.img_norm:
        transforms_norm += [transforms.Lambda(lambda x: x.mul(1/255))]  # .mul is element-wise, so the tensor shape is preserved
        if 'flip' in args.img_norm:
            transforms_norm += [transforms.Lambda(lambda x: (x - 1).abs())]  # also element-wise; shape preserved
    else:
        raise ValueError('wrong --img_norm. only znorm|scale01|scale01flip')

    transforms_ += transforms_norm

    dataloader = DataLoader(ImageDataset(args.dataroot, transforms_=transforms_, unaligned=True, n=args.data_subset),
                            batch_size=args.batch_size, shuffle=True, num_workers=args.n_cpu)

    # Transforms and dataloader for test set
    transforms_test_ = [transforms.Resize(args.size, Image.BICUBIC),
                        transforms.ToTensor()]
    transforms_test_ += transforms_norm

    dataloader_test = DataLoader(ImageDataset(args.dataroot, transforms_=transforms_test_, mode='test'),
                                 batch_size=args.batch_size, shuffle=False, num_workers=args.n_cpu)
    # Training ######
    if args.load_iter == 0 and args.load_epoch != 0:
        print('****** NOTE: args.load_iter == 0 and args.load_epoch != 0 ******')

    iter = args.load_iter
    prev_time = time.time()
    n_test = 10e10 if args.n_test is None else args.n_test
    n_sample = 10e10 if args.n_sample is None else args.n_sample

    rl_delta_x = args.n_epochs - args.recon_loss_epoch
    rl_delta_y = args.end_recon_loss_val - args.start_recon_loss_val

    gan_delta_x = args.n_epochs - args.gan_loss_epoch
    gan_delta_y = args.end_gan_loss_val - args.start_gan_loss_val

    for epoch in range(args.load_epoch, args.n_epochs):

        rl_effective_epoch = max(epoch - args.recon_loss_epoch, 0)
        recon_loss_rate = args.start_recon_loss_val + rl_effective_epoch * (rl_delta_y / rl_delta_x)

        gan_effective_epoch = max(epoch - args.gan_loss_epoch, 0)
        gan_loss_rate = args.start_gan_loss_val + gan_effective_epoch * (gan_delta_y / gan_delta_x)

        id_loss_rate = 5.0

        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            # Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B)
            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A)

            # GAN loss
            fake_B = netG_A2B(real_A)
            pred_fake, _ = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)

            fake_A = netG_B2A(real_B)
            pred_fake, _ = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

            # Cycle loss
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A)

            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B)

            # Total loss
            loss_G = (loss_identity_A + loss_identity_B) * id_loss_rate
            loss_G += (loss_GAN_A2B + loss_GAN_B2A) * gan_loss_rate
            loss_G += (loss_cycle_ABA + loss_cycle_BAB) * recon_loss_rate

            loss_G.backward()

            optimizer_G.step()

            # Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real, _ = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake, _ = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            loss_D_A = (loss_D_real + loss_D_fake) * 0.5

            if args.fm_loss:
                pred_real, feats_real = netD_A(real_A)
                pred_fake, feats_fake = netD_A(fake_A.detach())

                fm_loss_A = get_fm_loss(feats_real, feats_fake, feat_criterion, args.cuda)

                loss_D_A = loss_D_A * 0.1 + fm_loss_A * 0.9

            loss_D_A.backward()

            optimizer_D_A.step()

            # Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real, _ = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake, _ = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            loss_D_B = (loss_D_real + loss_D_fake)*0.5

            if args.fm_loss:
                pred_real, feats_real = netD_B(real_B)
                pred_fake, feats_fake = netD_B(fake_B.detach())

                fm_loss_B = get_fm_loss(feats_real, feats_fake, feat_criterion, args.cuda)

                loss_D_B = loss_D_B * 0.1 + fm_loss_B * 0.9

            loss_D_B.backward()

            optimizer_D_B.step()

            if iter % args.log_interval == 0:

                print('---------------------')
                print('GAN loss:', as_np(loss_GAN_A2B), as_np(loss_GAN_B2A))
                print('Identity loss:', as_np(loss_identity_A), as_np(loss_identity_B))
                print('Cycle loss:', as_np(loss_cycle_ABA), as_np(loss_cycle_BAB))
                print('D loss:', as_np(loss_D_A), as_np(loss_D_B))
                if args.fm_loss:
                    print('fm loss:', as_np(fm_loss_A), as_np(fm_loss_B))
                print('recon loss rate:', recon_loss_rate)
                print('time:', time.time() - prev_time)
                prev_time = time.time()

            if iter % args.plot_interval == 0:
                pass

            if iter % args.image_save_interval == 0:
                samples_path_ = os.path.join(samples_path, str(iter // args.image_save_interval))
                safe_mkdirs(samples_path_)

                # New savedir
                test_pth_AB = os.path.join(test_path, str(iter // args.image_save_interval), 'AB')
                test_pth_BA = os.path.join(test_path, str(iter // args.image_save_interval), 'BA')

                safe_mkdirs(test_pth_AB)
                safe_mkdirs(test_pth_BA)

                for j, batch_ in enumerate(dataloader_test):

                    real_A_test = Variable(input_A.copy_(batch_['A']))
                    real_B_test = Variable(input_B.copy_(batch_['B']))

                    fake_AB_test = netG_A2B(real_A_test)
                    fake_BA_test = netG_B2A(real_B_test)

                    if j < n_sample:
                        recovered_ABA_test = netG_B2A(fake_AB_test)
                        recovered_BAB_test = netG_A2B(fake_BA_test)

                        fn = os.path.join(samples_path_, str(j))
                        imageio.imwrite(fn + '.A.jpg', tensor2image(real_A_test[0], args.img_norm))
                        imageio.imwrite(fn + '.B.jpg', tensor2image(real_B_test[0], args.img_norm))
                        imageio.imwrite(fn + '.BA.jpg', tensor2image(fake_BA_test[0], args.img_norm))
                        imageio.imwrite(fn + '.AB.jpg', tensor2image(fake_AB_test[0], args.img_norm))
                        imageio.imwrite(fn + '.ABA.jpg', tensor2image(recovered_ABA_test[0], args.img_norm))
                        imageio.imwrite(fn + '.BAB.jpg', tensor2image(recovered_BAB_test[0], args.img_norm))

                    if j < n_test:
                        fn_A = os.path.basename(batch_['img_A'][0])
                        imageio.imwrite(os.path.join(test_pth_AB, fn_A), tensor2image(fake_AB_test[0], args.img_norm))

                        fn_B = os.path.basename(batch_['img_B'][0])
                        imageio.imwrite(os.path.join(test_pth_BA, fn_B), tensor2image(fake_BA_test[0], args.img_norm))

            if iter % args.model_save_interval == 0:
                # Save models checkpoints
                torch.save(netG_A2B.state_dict(), os.path.join(model_path, 'G_A2B_{}.pth'.format(iter)))
                torch.save(netG_B2A.state_dict(), os.path.join(model_path, 'G_B2A_{}.pth'.format(iter)))
                torch.save(netD_A.state_dict(), os.path.join(model_path, 'D_A_{}.pth'.format(iter)))
                torch.save(netD_B.state_dict(), os.path.join(model_path, 'D_B_{}.pth'.format(iter)))

            iter += 1

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()
Example #6
netD_A.load_state_dict(pretrained_dict)
pretrained_dict = torch.load('/net/cremi/smjoshi/espaces/travail/barcelona/PyTorch-CycleGAN/checkpoints/netD_B.pth')
netD_B.load_state_dict(pretrained_dict)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                                lr=opt['lr'], betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt['lr'], betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt['lr'], betas=(0.5, 0.999))

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(opt['n_epochs'], opt['epoch'], opt['decay_epoch']).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(opt['n_epochs'], opt['epoch'], opt['decay_epoch']).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(opt['n_epochs'],opt['epoch'], opt['decay_epoch']).step)

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if opt['cuda'] else torch.Tensor
input_A = Tensor(opt['batch_size'], opt['input_nc'], opt['size'], opt['size'])
input_B = Tensor(opt['batch_size'], opt['output_nc'], opt['size'], opt['size'])
target_real = Variable(Tensor(opt['batch_size']).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(opt['batch_size']).fill_(0.0), requires_grad=False)

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Dataset loader
Example #7
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                                lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor
input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
target_real = Variable(Tensor(opt.batchSize,1).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(opt.batchSize,1).fill_(0.0), requires_grad=False)

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Dataset loader
transforms_ = [ transforms.Resize(int(opt.size*1.12), Image.BICUBIC), 
Example #8
# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                               netG_B2A.parameters()),
                               lr=opt.lr,
                               betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G,
                                                   lr_lambda=LambdaLR(
                                                       opt.n_epochs, opt.epoch,
                                                       opt.decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A,
    lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B,
    lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
input_A = Tensor(opt.batch_size, opt.input_nc, opt.size, opt.size)
input_B = Tensor(opt.batch_size, opt.output_nc, opt.size, opt.size)
target_real = Variable(Tensor(opt.batch_size, 1).fill_(1.0),
                       requires_grad=False)
target_fake = Variable(Tensor(opt.batch_size, 1).fill_(0.0),
                       requires_grad=False)
Example #9
# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(encoder.parameters(),
                                               decoder_A2B.parameters(),
                                               decoder_B2A.parameters()),
                               lr=lr,
                               betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                 lr=lr,
                                 betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                 lr=lr,
                                 betas=(0.5, 0.999))

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G,
                                                   lr_lambda=LambdaLR(
                                                       n_epochs, start_epoch,
                                                       decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A,
                                                     lr_lambda=LambdaLR(
                                                         n_epochs, start_epoch,
                                                         decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B,
                                                     lr_lambda=LambdaLR(
                                                         n_epochs, start_epoch,
                                                         decay_epoch).step)

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if activate_cuda else torch.Tensor
input_A = Tensor(batch_size, input_nc, image_size, image_size)
input_B = Tensor(batch_size, output_nc, image_size, image_size)
target_real = Variable(Tensor(batch_size).fill_(1.0), requires_grad=False)
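
Variable and the torch.cuda.FloatTensor constructor used throughout these snippets are legacy idioms (Variable was merged into Tensor in PyTorch 0.4). A minimal sketch of the same allocation block in current PyTorch, reusing the names above:

device = torch.device('cuda' if activate_cuda else 'cpu')
input_A = torch.empty(batch_size, input_nc, image_size, image_size, device=device)
input_B = torch.empty(batch_size, output_nc, image_size, image_size, device=device)
target_real = torch.ones(batch_size, device=device)   # targets need no gradient
target_fake = torch.zeros(batch_size, device=device)
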
Example #10
netG_en2zh = Generator(3,3).to(device)
netG_zh2en = Generator(3,3).to(device)
netD_en = Discriminator(3).to(device)
netD_zh = Discriminator(3).to(device)

netG_en2zh.apply(weights_init_normal)
netG_zh2en.apply(weights_init_normal)
netD_en.apply(weights_init_normal)
netD_zh.apply(weights_init_normal)

# optimizers and learning rate schedulers
optimizer_G = Adam(itertools.chain(netG_en2zh.parameters(), netG_zh2en.parameters()), lr=opt.lr, betas=BETAS)
optimizer_D_en = Adam(netD_en.parameters(), lr=opt.lr, betas=BETAS)
optimizer_D_zh = Adam(netD_zh.parameters(), lr=opt.lr, betas=BETAS)

lr_scheduler_G = lr_scheduler.LambdaLR(optimizer_G, lr_lambda = LambdaLR(opt.n_epochs,0,DECAY_EPOCH).step)
lr_scheduler_D_en = lr_scheduler.LambdaLR(optimizer_D_en, lr_lambda = LambdaLR(opt.n_epochs,0,DECAY_EPOCH).step)
lr_scheduler_D_zh = lr_scheduler.LambdaLR(optimizer_D_zh, lr_lambda = LambdaLR(opt.n_epochs,0,DECAY_EPOCH).step)

def train():
    for epoch in range(opt.n_epochs):
        print('=== Starting epoch:', epoch, '===')
        # Note: since PyTorch 1.1, scheduler.step() should be called after the
        # optimizer updates, i.e. at the end of the epoch rather than here.
        lr_scheduler_G.step()
        lr_scheduler_D_en.step()
        lr_scheduler_D_zh.step()

        for index, data in enumerate(dataloader):
            real_data_en = data['en'].to(device)
            real_data_zh = data['zh'].to(device)

            ###################
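
A self-contained way to sanity-check the schedule these examples build, using the same Adam / lr_scheduler imports as the snippet above and the LambdaLR helper sketched after Example #3 (an assumption): step a throwaway optimizer once per epoch and print the resulting learning rate.

dummy = torch.nn.Parameter(torch.zeros(1))
optimizer = Adam([dummy], lr=2e-4, betas=(0.5, 0.999))
scheduler = lr_scheduler.LambdaLR(
    optimizer, lr_lambda=LambdaLR(10, 0, 5).step)  # 10 epochs, decay starts at epoch 5

for epoch in range(10):
    optimizer.step()   # optimizer first (PyTorch >= 1.1 ordering)
    scheduler.step()
    print(epoch, optimizer.param_groups[0]['lr'])
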
Example #11

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_l1 = torch.nn.L1Loss()
criterion_feat = torch.nn.MSELoss()
criterion_VGG = VGGLoss()

# Optimizers & LR schedulers
optimizer_encoder = torch.optim.Adam(encoder.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizer_decoder = torch.optim.Adam(decoder.parameters(), lr=opt.lr, betas=(0.5, 0.999))

optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(0.5, 0.999))
# optimizer_t = torch.optim.Adam(transformer.parameters(), lr=opt.lr, betas=(0.5, 0.999))

lr_scheduler_encoder = torch.optim.lr_scheduler.LambdaLR(optimizer_encoder, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_decoder = torch.optim.lr_scheduler.LambdaLR(optimizer_decoder, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(optimizer_D, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
# lr_scheduler_t = torch.optim.lr_scheduler.LambdaLR(optimizer_t, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
# input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
# input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
# target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False)
# target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False)

fake_B_buffer = ReplayBuffer()

# Dataset loader
transforms_ = [ transforms.Resize(int(opt.size*1.12), Image.BICUBIC), 
Example #12
def main():
    cuda = torch.cuda.is_available()

    input_shape = (opt.channels, opt.img_height, opt.img_width)

    # Initialize generator and discriminator
    G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
    G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
    D_A = Discriminator(input_shape)
    D_B = Discriminator(input_shape)

    if cuda:
        G_AB = G_AB.cuda()
        G_BA = G_BA.cuda()
        D_A = D_A.cuda()
        D_B = D_B.cuda()
        criterion_GAN.cuda()
        criterion_cycle.cuda()
        criterion_identity.cuda()

    if opt.epoch != 0:
        # Load pretrained models
        G_AB.load_state_dict(
            torch.load("saved_models/%s/G_AB_%d.pth" %
                       (opt.dataset_name, opt.epoch)))
        G_BA.load_state_dict(
            torch.load("saved_models/%s/G_BA_%d.pth" %
                       (opt.dataset_name, opt.epoch)))
        D_A.load_state_dict(
            torch.load("saved_models/%s/D_A_%d.pth" %
                       (opt.dataset_name, opt.epoch)))
        D_B.load_state_dict(
            torch.load("saved_models/%s/D_B_%d.pth" %
                       (opt.dataset_name, opt.epoch)))
    else:
        # Initialize weights
        G_AB.apply(weights_init_normal)
        G_BA.apply(weights_init_normal)
        D_A.apply(weights_init_normal)
        D_B.apply(weights_init_normal)

    # Optimizers
    optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(),
                                                   G_BA.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D_A = torch.optim.Adam(D_A.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.b1, opt.b2))
    optimizer_D_B = torch.optim.Adam(D_B.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.b1, opt.b2))

    # Learning rate update schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

    # Buffers of previously generated samples
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Image transformations
    transforms_ = [
        transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
        transforms.RandomCrop((opt.img_height, opt.img_width)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]

    # Training data loader
    dataloader = DataLoader(
        ImageDataset("../../data/%s" % opt.dataset_name,
                     transforms_=transforms_,
                     unaligned=True),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
    )
    # Test data loader
    val_dataloader = DataLoader(
        ImageDataset("../../data/%s" % opt.dataset_name,
                     transforms_=transforms_,
                     unaligned=True,
                     mode="test"),
        batch_size=5,
        shuffle=True,
        num_workers=1,
    )

    def sample_images(batches_done):
        """Saves a generated sample from the test set"""
        imgs = next(iter(val_dataloader))
        G_AB.eval()
        G_BA.eval()
        real_A = Variable(imgs["A"].type(Tensor))
        fake_B = G_AB(real_A)
        real_B = Variable(imgs["B"].type(Tensor))
        fake_A = G_BA(real_B)
        # Arrange images along x-axis
        real_A = make_grid(real_A, nrow=5, normalize=True)
        real_B = make_grid(real_B, nrow=5, normalize=True)
        fake_A = make_grid(fake_A, nrow=5, normalize=True)
        fake_B = make_grid(fake_B, nrow=5, normalize=True)
        # Arrange images along y-axis
        image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
        save_image(image_grid,
                   "images/%s/%s.png" % (opt.dataset_name, batches_done),
                   normalize=False)

    # ----------
    #  Training
    # ----------
    prev_time = time.time()
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            real_A = Variable(batch["A"].type(Tensor))
            real_B = Variable(batch["B"].type(Tensor))

            # Adversarial ground truths
            valid = Variable(Tensor(
                np.ones((real_A.size(0), *D_A.output_shape))),
                             requires_grad=False)
            fake = Variable(Tensor(
                np.zeros((real_A.size(0), *D_A.output_shape))),
                            requires_grad=False)

            # ------------------
            #  Train Generators
            # ------------------

            G_AB.train()
            G_BA.train()

            optimizer_G.zero_grad()

            # Identity loss
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)

            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss
            fake_B = G_AB(real_A)
            loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_BA(real_B)
            loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A)
            recov_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B)

            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Total loss
            loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity

            loss_G.backward()
            optimizer_G.step()

            # -----------------------
            #  Train Discriminator A
            # -----------------------

            optimizer_D_A.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_A(real_A), valid)
            # Fake loss (on batch of previously generated samples)
            fake_A_ = fake_A_buffer.push_and_pop(fake_A)
            loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2

            loss_D_A.backward()
            optimizer_D_A.step()

            # -----------------------
            #  Train Discriminator B
            # -----------------------

            optimizer_D_B.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_B(real_B), valid)
            # Fake loss (on batch of previously generated samples)
            fake_B_ = fake_B_buffer.push_and_pop(fake_B)
            loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
            # Total loss
            loss_D_B = (loss_real + loss_fake) / 2

            loss_D_B.backward()
            optimizer_D_B.step()

            loss_D = (loss_D_A + loss_D_B) / 2

            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + i
            batches_left = opt.n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
                % (
                    epoch,
                    opt.n_epochs,
                    i,
                    len(dataloader),
                    loss_D.item(),
                    loss_G.item(),
                    loss_GAN.item(),
                    loss_cycle.item(),
                    loss_identity.item(),
                    time_left,
                ))

            # If at sample interval save image
            if batches_done % opt.sample_interval == 0:
                sample_images(batches_done)

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(
                G_AB.state_dict(),
                "saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, epoch))
            torch.save(
                G_BA.state_dict(),
                "saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, epoch))
            torch.save(
                D_A.state_dict(),
                "saved_models/%s/D_A_%d.pth" % (opt.dataset_name, epoch))
            torch.save(
                D_B.state_dict(),
                "saved_models/%s/D_B_%d.pth" % (opt.dataset_name, epoch))
Example #13
        D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (dataset_name, epoch)))
    else:
        # Initialize weights
        init_weights_of_model(G_AB, init_type=init_type, init_gain=init_gain)
        init_weights_of_model(G_BA, init_type=init_type, init_gain=init_gain)
        init_weights_of_model(D_A, init_type=init_type, init_gain=init_gain)
        init_weights_of_model(D_B, init_type=init_type, init_gain=init_gain)

    # Optimizers
    optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lr, betas=(b1, b2))
    optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=(b1, b2))
    optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(b1, b2))

    # Learning rate update schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
    )
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
    )
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
    )

    Tensor = torch.cuda.FloatTensor if CUDA else torch.Tensor

    # Buffers of previously generated samples
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Image transformations
Example #14
    def __init__(self,
                 epoch=0,
                 n_epochs=1000,
                 batchSize=1,
                 lr=0.0002,
                 decay_epoch=100,
                 size=256,
                 input_nc=3,
                 output_nc=3,
                 cuda=True,
                 n_cpu=8,
                 load_from_ckpt=False):

        self.epoch = epoch
        self.n_epochs = n_epochs
        self.batchSize = batchSize
        self.lr = lr
        self.decay_epoch = decay_epoch
        self.size = size
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.cuda = cuda
        self.n_cpu = n_cpu

        rootA = "../dataset/monet_field_data"
        rootB = "../dataset/field_data"

        if torch.cuda.is_available() and not self.cuda:
            print(
                "WARNING: You have a CUDA device, so you should probably run with --cuda"
            )

        ###### Definition of variables ######
        # Networks
        self.netG_A2B = Generator(self.input_nc, self.output_nc)
        self.netG_B2A = Generator(self.output_nc, self.input_nc)
        self.netD_A = Discriminator(self.input_nc)
        self.netD_B = Discriminator(self.output_nc)

        if load_from_ckpt:
            print("loading from ckpt")
            self.netG_A2B.load_state_dict(torch.load('output/netG_A2B.pth'))
            self.netG_B2A.load_state_dict(torch.load('output/netG_B2A.pth'))
            self.netD_A.load_state_dict(torch.load('output/netD_A.pth'))
            self.netD_B.load_state_dict(torch.load('output/netD_B.pth'))
        else:
            self.netG_A2B.apply(weights_init_normal)
            self.netG_B2A.apply(weights_init_normal)
            self.netD_A.apply(weights_init_normal)
            self.netD_B.apply(weights_init_normal)

        if self.cuda:
            self.netG_A2B.cuda()
            self.netG_B2A.cuda()
            self.netD_A.cuda()
            self.netD_B.cuda()

        # Losses
        self.criterion_GAN = torch.nn.MSELoss()
        self.criterion_cycle = torch.nn.L1Loss()
        self.criterion_identity = torch.nn.L1Loss()

        # Optimizers & LR schedulers
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.netG_A2B.parameters(), self.netG_B2A.parameters()),
                                            lr=self.lr,
                                            betas=(0.5, 0.999))
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                              lr=self.lr,
                                              betas=(0.5, 0.999))
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                              lr=self.lr,
                                              betas=(0.5, 0.999))

        self.lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G,
            lr_lambda=LambdaLR(self.n_epochs, self.epoch,
                               self.decay_epoch).step)
        self.lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D_A,
            lr_lambda=LambdaLR(self.n_epochs, self.epoch,
                               self.decay_epoch).step)
        self.lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D_B,
            lr_lambda=LambdaLR(self.n_epochs, self.epoch,
                               self.decay_epoch).step)

        if load_from_ckpt:
            print('load states')
            checkpoint = torch.load('output/states.pth')
            '''
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D_A.load_state_dict(checkpoint['optimizer_D_A'])
            self.optimizer_D_B.load_state_dict(checkpoint['optimizer_D_B'])
            
            self.lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
            self.lr_scheduler_D_A.load_state_dict(checkpoint['lr_scheduler_D_A'])
            self.lr_scheduler_D_B.load_state_dict(checkpoint['lr_scheduler_D_B'])
            '''

            self.lr = checkpoint['lr']
            self.epoch = checkpoint['epoch'] + 1

        # Inputs & targets memory allocation
        Tensor = torch.cuda.FloatTensor if self.cuda else torch.Tensor
        self.input_A = Tensor(self.batchSize, self.input_nc, self.size,
                              self.size)
        self.input_B = Tensor(self.batchSize, self.output_nc, self.size,
                              self.size)
        self.target_real = Variable(Tensor(self.batchSize).fill_(1.0),
                                    requires_grad=False)
        self.target_fake = Variable(Tensor(self.batchSize).fill_(0.0),
                                    requires_grad=False)

        self.fake_A_buffer = ReplayBuffer()
        self.fake_B_buffer = ReplayBuffer()

        # Dataset loader
        transforms_ = [
            transforms.Resize((int(self.size * 1.12), int(self.size * 1.12)),
                              Image.BICUBIC),
            transforms.RandomCrop(self.size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
        self.dataloader = DataLoader(ImageDataset(rootA,
                                                  rootB,
                                                  transforms_=transforms_,
                                                  unaligned=True),
                                     batch_size=self.batchSize,
                                     shuffle=True,
                                     num_workers=self.n_cpu)
Example #15
def train():

    G_AB = Generator(input_nc, output_nc)
    G_BA = Generator(output_nc, input_nc)

    D_A = Discriminator(input_nc)
    D_B = Discriminator(output_nc)

    G_AB.cuda()
    G_BA.cuda()
    D_A.cuda()
    D_B.cuda()

    G_AB.apply(weights_init)
    G_BA.apply(weights_init)
    D_A.apply(weights_init)
    D_B.apply(weights_init)

    #Loss
    GD_loss = nn.MSELoss()
    L1_loss = nn.L1Loss()
    L1_loss_identity = nn.L1Loss()

    optim_G = optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()),
                         lr=lr_G,
                         betas=(0.5, 0.999))
    optim_D_A = optim.Adam(D_A.parameters(), lr=lr_D, betas=(0.5, 0.999))
    optim_D_B = optim.Adam(D_B.parameters(), lr=lr_D, betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optim_G, lr_lambda=LambdaLR(n_epochs, start_epoch, decay).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optim_D_A, lr_lambda=LambdaLR(n_epochs, start_epoch, decay).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optim_D_B, lr_lambda=LambdaLR(n_epochs, start_epoch, decay).step)

    Tensor = torch.cuda.FloatTensor
    input_A = Tensor(1, 3, 256, 256)
    input_B = Tensor(1, 3, 256, 256)

    fake_A_buffer = keep()
    fake_B_buffer = keep()

    if opt.opencv:
        print('OPENCV MODE')
        transforms_ = [
            T.Scale(286),
            T.RandomCrop(256),
            T.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    else:
        print('PIL MODE')
        transforms_ = [
            transforms.Resize(286, Image.BICUBIC),
            transforms.RandomCrop(256),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]

    dataloader = DataLoader(Loadimage(opt.dataroot,
                                      transforms_=transforms_,
                                      unaligned=True,
                                      mode_opencv=opt.opencv),
                            batch_size=1,
                            shuffle=True,
                            num_workers=8)

    for epoch in range(start_epoch, n_epochs):
        for i, batch in enumerate(dataloader):

            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ########################################
            #Train Generator
            #A to B
            optim_G.zero_grad()

            #identity loss
            same_B = G_AB(real_B)
            loss_identity_B = L1_loss_identity(same_B, real_B) * 5.0

            same_A = G_BA(real_A)
            loss_identity_A = L1_loss_identity(same_A, real_A) * 5.0

            fake_B = G_AB(real_A)
            pred_fake_B = D_B(fake_B)
            G_AB_Loss = GD_loss(
                pred_fake_B, Variable(torch.ones(pred_fake_B.size()).cuda()))

            #B to A
            fake_A = G_BA(real_B)
            pred_fake_A = D_A(fake_A)
            G_BA_Loss = GD_loss(
                pred_fake_A, Variable(torch.ones(pred_fake_A.size()).cuda()))

            #fake B to A
            similar_A = G_BA(fake_B)
            BA_cycle_loss = L1_loss(similar_A, real_A) * 10.0

            #fake A to B
            similar_B = G_AB(fake_A)
            AB_cycle_loss = L1_loss(similar_B, real_B) * 10.0

            #total loss G
            G_loss = G_AB_Loss + G_BA_Loss + BA_cycle_loss + AB_cycle_loss + loss_identity_A + loss_identity_B

            G_loss_identity = loss_identity_B + loss_identity_A
            G_loss_GAN = G_AB_Loss + G_BA_Loss
            G_loss_cycle = BA_cycle_loss + AB_cycle_loss

            #OptimizeG
            G_loss.backward()
            optim_G.step()

            loss_G_plot.append(G_loss.item())
            loss_G_identity_plot.append(G_loss_identity.item())
            loss_G_GAN_plot.append(G_loss_GAN.item())
            loss_G_cycle_plot.append(G_loss_cycle.item())

            #######################################
            #Train Discriminator
            #Discriminator D_AB
            optim_D_A.zero_grad()

            pred_real_A = D_A(real_A)
            D_real_loss = GD_loss(
                pred_real_A, Variable(torch.ones(pred_real_A.size()).cuda()))
            fake_A = fake_A_buffer.empty_fill_data(fake_A)
            pred_d_fake_A = D_A(fake_A)
            D_fake_loss = GD_loss(
                pred_d_fake_A,
                Variable(torch.zeros(pred_d_fake_A.size()).cuda()))

            D_A_loss_total = (D_real_loss + D_fake_loss) * 0.5

            D_A_loss_total.backward()
            optim_D_A.step()

            #Discriminator D_BA
            optim_D_B.zero_grad()

            pred_real_B = D_B(real_B)
            D_real_loss = GD_loss(
                pred_real_B, Variable(torch.ones(pred_real_B.size()).cuda()))
            fake_B = fake_B_buffer.empty_fill_data(fake_B)
            pred_d_fake_B = D_B(fake_B)
            D_fake_loss = GD_loss(
                pred_d_fake_B,
                Variable(torch.zeros(pred_d_fake_B.size()).cuda()))

            D_B_loss_total = (D_real_loss + D_fake_loss) * 0.5

            D_B_loss_total.backward()
            optim_D_B.step()

            D_Loss = D_A_loss_total + D_B_loss_total

            loss_D_plot.append(D_Loss.item())
            #####################################

            #Print All losses
            print('Epoch [%d/%d], Step [%d/%d], G_loss: %.4f, D_loss: %.4f' %
                  (epoch + 1, n_epochs, i + 1, len(dataloader), G_loss.item(),
                   D_Loss.item()))

        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        torch.save(G_AB.state_dict(), 'output/G_AB.pth')
        torch.save(G_BA.state_dict(), 'output/G_BA.pth')
        torch.save(D_A.state_dict(), 'output/D_AB.pth')
        torch.save(D_B.state_dict(), 'output/D_BA.pth')

    x = np.linspace(start_epoch, n_epochs, num=len(loss_G_plot))

    plt.figure(1)
    plt.plot(x, loss_G_plot)
    plt.xticks(np.arange(start_epoch, n_epochs + 1, 25))
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Loss_G')
    plt.savefig('Loss_G.png')

    plt.figure(2)
    plt.plot(x, loss_G_identity_plot)
    plt.xticks(np.arange(start_epoch, n_epochs + 1, 25))
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Loss_G_Identity')
    plt.savefig('Loss_G_Identity.png')

    plt.figure(3)
    plt.plot(x, loss_G_GAN_plot)
    plt.xticks(np.arange(start_epoch, n_epochs + 1, 25))
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Loss_G_GAN')
    plt.savefig('Loss_G_GAN.png')

    plt.figure(4)
    plt.plot(x, loss_G_cycle_plot)
    plt.xticks(np.arange(start_epoch, n_epochs + 1, 25))
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Loss_G_Cycle')
    plt.savefig('Loss_G_Cycle.png')

    plt.figure(5)
    plt.plot(x, loss_D_plot)
    plt.xticks(np.arange(start_epoch, n_epochs + 1, 25))
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Loss_D')
    plt.savefig('Loss_D.png')
Example #16
def main(args):
    writer = SummaryWriter(os.path.join(args.out_dir, 'logs'))
    current_time = datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
    os.makedirs(
        os.path.join(args.out_dir, 'models',
                     args.model_name + '_' + current_time))
    os.makedirs(
        os.path.join(args.out_dir, 'logs',
                     args.model_name + '_' + current_time))

    G_AB = Generator(args.in_channel, args.out_channel).to(args.device)
    G_BA = Generator(args.in_channel, args.out_channel).to(args.device)
    D_A = Discriminator(args.in_channel).to(args.device)
    D_B = Discriminator(args.out_channel).to(args.device)
    segmen_B = Unet(3, 34).to(args.device)

    if args.model_path is not None:
        AB_path = os.path.join(args.model_path, 'ab.pt')
        BA_path = os.path.join(args.model_path, 'ba.pt')
        DA_path = os.path.join(args.model_path, 'da.pt')
        DB_path = os.path.join(args.model_path, 'db.pt')
        segmen_path = os.path.join(args.model_path, 'semsg.pt')

        with open(AB_path, 'rb') as f:
            state_dict = torch.load(f)
            G_AB.load_state_dict(state_dict)

        with open(BA_path, 'rb') as f:
            state_dict = torch.load(f)
            G_BA.load_state_dict(state_dict)

        with open(DA_path, 'rb') as f:
            state_dict = torch.load(f)
            D_A.load_state_dict(state_dict)

        with open(DB_path, 'rb') as f:
            state_dict = torch.load(f)
            D_B.load_state_dict(state_dict)

        with open(segmen_path, 'rb') as f:
            state_dict = torch.load(f)
            segmen_B.load_state_dict(state_dict)

    else:
        G_AB.apply(weights_init_normal)
        G_BA.apply(weights_init_normal)
        D_A.apply(weights_init_normal)
        D_B.apply(weights_init_normal)

    G_AB = nn.DataParallel(G_AB)
    G_BA = nn.DataParallel(G_BA)
    D_A = nn.DataParallel(D_A)
    D_B = nn.DataParallel(D_B)
    segmen_B = nn.DataParallel(segmen_B)

    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()
    criterion_segmen = torch.nn.BCELoss()

    optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(),
                                                   G_BA.parameters()),
                                   lr=args.lr,
                                   betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(D_A.parameters(),
                                     lr=args.lr,
                                     betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(D_B.parameters(),
                                     lr=args.lr,
                                     betas=(0.5, 0.999))

    optimizer_segmen_B = torch.optim.Adam(segmen_B.parameters(),
                                          lr=args.lr,
                                          betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(args.n_epochs, args.epoch, args.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=LambdaLR(args.n_epochs, args.epoch, args.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=LambdaLR(args.n_epochs, args.epoch, args.decay_epoch).step)

    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    transforms_ = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    dataloader = DataLoader(ImgDataset(args.dataset_path,
                                       transforms_=transforms_,
                                       unaligned=True,
                                       device=args.device),
                            batch_size=args.batchSize,
                            shuffle=True,
                            num_workers=0)
    logger = Logger(args.n_epochs, len(dataloader))
    target_real = Variable(torch.Tensor(args.batchSize,
                                        1).fill_(1.)).to(args.device).detach()
    target_fake = Variable(torch.Tensor(args.batchSize,
                                        1).fill_(0.)).to(args.device).detach()

    G_AB.train()
    G_BA.train()
    D_A.train()
    D_B.train()
    segmen_B.train()

    for epoch in range(args.epoch, args.n_epochs):
        for i, batch in enumerate(dataloader):
            real_A = batch['A'].clone()
            real_B = batch['B'].clone()
            B_label = batch['B_label'].clone()

            fake_b = G_AB(real_A)
            fake_a = G_BA(real_B)
            same_b = G_AB(real_B)
            same_a = G_BA(real_A)
            recovered_A = G_BA(fake_b)
            recovered_B = G_AB(fake_a)
            pred_Blabel = segmen_B(real_B)
            pred_fakeAlabel = segmen_B(fake_a)

            optimizer_segmen_B.zero_grad()
            # Segmentation loss: supervise segmen_B on real B images and on the
            # detached B->A translations, so it also learns to segment images
            # after domain transfer.
            loss_segmen_B = criterion_segmen(
                pred_Blabel, B_label) + criterion_segmen(
                    segmen_B(fake_a.detach()), B_label)
            loss_segmen_B.backward()
            optimizer_segmen_B.step()

            optimizer_G.zero_grad()
            # GAN loss
            pred_fakeb = D_B(fake_b)
            loss_gan_AB = criterion_GAN(pred_fakeb, target_real)

            pred_fakea = D_A(fake_a)
            loss_gan_BA = criterion_GAN(pred_fakea, target_real)

            # Identity loss
            loss_identity_B = criterion_identity(same_b, real_B) * 5
            loss_identity_A = criterion_identity(same_a, real_A) * 5

            # Cycle consistency loss
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10

            # Cycle segmentation difference loss
            loss_segmen_diff = criterion_segmen(segmen_B(recovered_B),
                                                pred_Blabel.detach())

            loss_G = (loss_gan_AB + loss_gan_BA + loss_identity_B + loss_identity_A
                      + loss_cycle_ABA + loss_cycle_BAB + loss_segmen_diff)
            loss_G.backward()

            optimizer_G.step()

            # Discriminator A
            optimizer_D_A.zero_grad()

            pred_realA = D_A(real_A)
            loss_D_A_real = criterion_GAN(pred_realA, target_real)

            fake_A = fake_A_buffer.push_and_pop(fake_a)
            pred_fakeA = D_A(fake_A.detach())
            loss_D_A_fake = criterion_GAN(pred_fakeA, target_fake)

            loss_D_A = (loss_D_A_real + loss_D_A_fake) * 0.5
            loss_D_A.backward()

            optimizer_D_A.step()

            # Discriminator B
            optimizer_D_B.zero_grad()

            pred_realB = D_B(real_B)
            loss_D_B_real = criterion_GAN(pred_realB, target_real)

            fake_B = fake_B_buffer.push_and_pop(fake_b)
            pred_fakeB = D_B(fake_B.detach())
            loss_D_B_fake = criterion_GAN(pred_fakeB, target_fake)

            loss_D_B = (loss_D_B_real + loss_D_B_fake) * 0.5
            loss_D_B.backward()

            optimizer_D_B.step()

            logger.log(
                {
                    'loss_segmen_B': loss_segmen_B,
                    'loss_G': loss_G,
                    'loss_G_identity': (loss_identity_A + loss_identity_B),
                    'loss_G_GAN': (loss_gan_AB + loss_gan_BA),
                    'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB),
                    'loss_D': (loss_D_A + loss_D_B)
                },
                images={
                    'real_A': real_A,
                    'real_B': real_B,
                    'fake_A': fake_a,
                    'fake_B': fake_b,
                    'reconstructed_A': recovered_A,
                    'reconstructed_B': recovered_B
                },
                out_dir=os.path.join(
                    args.out_dir, 'logs',
                    args.model_name + '_' + current_time + '/' + str(epoch)),
                writer=writer)

        if (epoch + 1) % args.save_per_epochs == 0:
            os.makedirs(
                os.path.join(args.out_dir, 'models',
                             args.model_name + '_' + current_time, str(epoch)))
            torch.save(
                G_AB.module.state_dict(),
                os.path.join(args.out_dir,
                             'models', args.model_name + '_' + current_time,
                             str(epoch), 'ab.pt'))
            torch.save(
                G_BA.module.state_dict(),
                os.path.join(args.out_dir,
                             'models', args.model_name + '_' + current_time,
                             str(epoch), 'ba.pt'))
            torch.save(
                D_A.module.state_dict(),
                os.path.join(args.out_dir,
                             'models', args.model_name + '_' + current_time,
                             str(epoch), 'da.pt'))
            torch.save(
                D_B.module.state_dict(),
                os.path.join(args.out_dir,
                             'models', args.model_name + '_' + current_time,
                             str(epoch), 'db.pt'))
            torch.save(
                segmen_B.module.state_dict(),
                os.path.join(args.out_dir,
                             'models', args.model_name + '_' + current_time,
                             str(epoch), 'semsg.pt'))

        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()
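
Every example on this page drives torch.optim.lr_scheduler.LambdaLR with the step method of a small helper class that is also named LambdaLR and lives in each project's utils module, which is not reproduced here. A minimal sketch of the usual definition (a factor of 1.0 until decay_epoch, then a linear decay to 0.0 at n_epochs) might look like this; the exact implementation in each project may differ:

class LambdaLR:
    """Sketch of the linear-decay schedule assumed by the examples above."""
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "decay must start before training ends"
        self.n_epochs = n_epochs
        self.offset = offset                      # epoch to resume counting from
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        # Multiplicative factor applied to the base learning rate each epoch.
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (
            self.n_epochs - self.decay_start_epoch)

With n_epochs=200 and decay_epoch=100, for example, the factor stays at 1.0 for the first 100 epochs and then falls linearly to 0.0.
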
Exemplo n.º 17
0
def train_from_mask():
    # Load the best-fitness binary masks found by the genetic algorithm.
    mask_input_A2B = np.loadtxt("/cache/GA/txt/best_fitness_A2B.txt")
    mask_input_B2A = np.loadtxt("/cache/GA/txt/best_fitness_B2A.txt")

    cfg_mask_A2B = compute_layer_mask(mask_input_A2B, mask_chns)
    cfg_mask_B2A = compute_layer_mask(mask_input_B2A, mask_chns)

    netG_B2A = Generator(opt.output_nc, opt.input_nc)
    netG_A2B = Generator(opt.input_nc, opt.output_nc)
    model_A2B = Generator_Prune(cfg_mask_A2B)
    model_B2A = Generator_Prune(cfg_mask_B2A)
    netD_A = Discriminator(opt.input_nc)
    netD_B = Discriminator(opt.output_nc)

    netG_A2B.load_state_dict(torch.load('/cache/log/output/netG_A2B.pth'))
    netG_B2A.load_state_dict(torch.load('/cache/log/output/netG_B2A.pth'))

    netD_A.load_state_dict(torch.load('/cache/log/output/netD_A.pth'))
    netD_B.load_state_dict(torch.load('/cache/log/output/netD_B.pth'))

    # Losses
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    # Copy the surviving channels of netG_A2B into the pruned model_A2B.
    layer_id_in_cfg = 0
    start_mask = torch.ones(3)
    end_mask = cfg_mask_A2B[layer_id_in_cfg]

    for [m0, m1] in zip(netG_A2B.modules(), model_A2B.modules()):
  
        if isinstance(m0, nn.Conv2d):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask)))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask)))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
            
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            
            layer_id_in_cfg += 1
            start_mask = end_mask
            if layer_id_in_cfg < len(cfg_mask_A2B):  # stop once every masked layer has been copied
                end_mask = cfg_mask_A2B[layer_id_in_cfg]
                print(layer_id_in_cfg)
        elif isinstance(m0, nn.ConvTranspose2d):
            print('Into ConvTranspose...')
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask)))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask)))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        

            w1 = m0.weight.data[idx0.tolist(), :, :, :].clone()
            w1 = w1[:, idx1.tolist(), :, :].clone()
            m1.weight.data = w1.clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask
            if layer_id_in_cfg < len(cfg_mask_A2B):  
                end_mask = cfg_mask_A2B[layer_id_in_cfg] 

    # Repeat the channel copy for netG_B2A -> model_B2A.
    layer_id_in_cfg = 0
    start_mask = torch.ones(3)
    end_mask = cfg_mask_B2A[layer_id_in_cfg]
    
    for [m0, m1] in zip(netG_B2A.modules(), model_B2A.modules()):
  
        if isinstance(m0, nn.Conv2d):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask)))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask)))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
            
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            
            layer_id_in_cfg += 1
            start_mask = end_mask
            if layer_id_in_cfg < len(cfg_mask_B2A):  
                end_mask = cfg_mask_B2A[layer_id_in_cfg]
                print(layer_id_in_cfg)
        elif isinstance(m0, nn.ConvTranspose2d):
            print('Into ConvTranspose...')
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask)))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask)))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        
            w1 = m0.weight.data[idx0.tolist(), :, :, :].clone()
            w1 = w1[:, idx1.tolist(), :, :].clone()
            m1.weight.data = w1.clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask
            if layer_id_in_cfg < len(cfg_mask_B2A):
                end_mask = cfg_mask_B2A[layer_id_in_cfg]

    netD_A = torch.nn.DataParallel(netD_A).cuda()
    netD_B = torch.nn.DataParallel(netD_B).cuda()
    model_A2B = torch.nn.DataParallel(model_A2B).cuda()
    model_B2A = torch.nn.DataParallel(model_B2A).cuda()

    # Inputs, targets and replay buffers
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
    target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False)
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Loss weights
    lamda_loss_ID = 5.0
    lamda_loss_G = 1.0
    lamda_loss_cycle = 10.0

    # Optimizers & LR schedulers
    optimizer_G = torch.optim.Adam(itertools.chain(model_A2B.parameters(), model_B2A.parameters()),
                                   lr=opt.lr, betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    
    # Dataset loader
    transforms_ = [
        transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
        transforms.RandomCrop(opt.size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_,
                                         unaligned=True, mode='train'),
                            batch_size=opt.batchSize, shuffle=True, drop_last=True)

    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = model_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B)*lamda_loss_ID #initial 5.0
            # G_B2A(A) should equal A if real A is fed
            same_A = model_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A)*lamda_loss_ID #initial 5.0

            # GAN loss
            fake_B = model_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)*lamda_loss_G  #initial 1.0

            fake_A = model_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)*lamda_loss_G #initial 1.0

            # Cycle loss
            recovered_A = model_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A)*lamda_loss_cycle  #initial 10.0

            recovered_B = model_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B)*lamda_loss_cycle  #initial 10.0

            # Total loss
            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
            loss_G.backward()
            
            optimizer_G.step()
            
            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake)*0.5
            loss_D_A.backward()

            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)
        
            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake)*0.5
            loss_D_B.backward()

            optimizer_D_B.step()
            
        
        print("epoch:%d  Loss G:%.4f  LossID_A:%.4f  LossID_B:%.4f  Loss_G_A2B:%.4f  Loss_G_B2A:%.4f  Loss_Cycle_ABA:%.4f  Loss_Cycle_BAB:%.4f"
              % (epoch, loss_G, loss_identity_A, loss_identity_B, loss_GAN_A2B, loss_GAN_B2A, loss_cycle_ABA, loss_cycle_BAB))

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()
        
        if epoch % 20 == 0:
            # Save model checkpoints
            torch.save(model_A2B.module.state_dict(), '/cache/log/output/A2B_%d.pth' % epoch)
            torch.save(model_B2A.module.state_dict(), '/cache/log/output/B2A_%d.pth' % epoch)
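
The channel-copy loops above turn each binary layer mask into index lists with np.argwhere and np.squeeze before slicing the weight tensors. A small stand-alone illustration of that step, using a hypothetical five-channel mask rather than values taken from the GA output files:

import numpy as np
import torch

end_mask = torch.tensor([1., 0., 1., 1., 0.])           # hypothetical keep/prune mask
idx1 = np.squeeze(np.argwhere(np.asarray(end_mask)))    # indices of surviving channels
print(idx1.tolist())                                    # [0, 2, 3]
# m0.weight.data[idx1.tolist(), :, :, :] then copies only those output channels.
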
Exemplo n.º 18
0
# Losses
criterion_GAN = torch.nn.MSELoss()  # Adversarial Loss
criterion_cycle = torch.nn.L1Loss()  # Cyclic consistency loss


# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                               lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(
    netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(
    netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
target_real = Variable(Tensor(opt.batchSize).fill_(1.0),
                       requires_grad=False)  # real
target_fake = Variable(Tensor(opt.batchSize).fill_(0.0),
                       requires_grad=False)  # fake

fake_A_buffer = ReplayBuffer()
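
The ReplayBuffer used throughout these snippets implements the image-history trick from the CycleGAN paper: the discriminators are updated on a mix of freshly generated fakes and fakes stored from earlier iterations. A minimal sketch of the usual 50-image buffer, which may differ in detail from the utils module these examples import:

import random
import torch

class ReplayBuffer:
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, batch):
        out = []
        for element in batch.data:
            element = element.unsqueeze(0)
            if len(self.data) < self.max_size:
                # Buffer not full yet: store the new fake and return it unchanged.
                self.data.append(element)
                out.append(element)
            elif random.uniform(0, 1) > 0.5:
                # Half the time, return an old fake and replace it with the new one.
                i = random.randint(0, self.max_size - 1)
                out.append(self.data[i].clone())
                self.data[i] = element
            else:
                out.append(element)
        return torch.cat(out)
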
Exemplo n.º 19
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=0, help='starting epoch')
    parser.add_argument('--n_epochs',
                        type=int,
                        default=200,
                        help='number of epochs of training')
    parser.add_argument('--batchSize',
                        type=int,
                        default=1,
                        help='size of the batches')
    parser.add_argument('--dataroot',
                        type=str,
                        default='datasets/data/',
                        help='root directory of the dataset')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='initial learning rate')
    parser.add_argument(
        '--decay_epoch',
        type=int,
        default=100,
        help='epoch to start linearly decaying the learning rate to 0')
    parser.add_argument('--size',
                        type=int,
                        default=256,
                        help='size of the data crop (squared assumed)')
    parser.add_argument('--input_nc',
                        type=int,
                        default=3,
                        help='number of channels of input data')
    parser.add_argument('--output_nc',
                        type=int,
                        default=3,
                        help='number of channels of output data')
    parser.add_argument('--cuda',
                        action='store_true',
                        help='use GPU computation')
    parser.add_argument(
        '--n_cpu',
        type=int,
        default=8,
        help='number of cpu threads to use during batch generation')
    opt = parser.parse_args()
    print(opt)

    if torch.cuda.is_available() and not opt.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    ###### Definition of variables ######
    # Networks
    netG_A2B = Generator(opt.input_nc, opt.output_nc)
    netG_B2A = Generator(opt.output_nc, opt.input_nc)
    netD_A = Discriminator(opt.input_nc)
    netD_B = Discriminator(opt.output_nc)

    if opt.cuda:
        netG_A2B.cuda()
        netG_B2A.cuda()
        netD_A.cuda()
        netD_B.cuda()

    netG_A2B.apply(weights_init_normal)
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

    # Losses
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    # Optimizers & LR schedulers
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                                   netG_B2A.parameters()),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    # Inputs & targets memory allocation
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
    target_real = Variable(Tensor(opt.batchSize).fill_(1.0),
                           requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize).fill_(0.0),
                           requires_grad=False)

    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Dataset loader
    transforms_ = [
        transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
        transforms.RandomCrop(opt.size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    dataloader = DataLoader(ImageDataset(opt.dataroot,
                                         transforms_=transforms_,
                                         unaligned=True),
                            batch_size=opt.batchSize,
                            shuffle=True,
                            num_workers=opt.n_cpu)

    # Loss plot
    logger = Logger(opt.n_epochs, len(dataloader))
    ###################################

    ###### Training ######
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B) * 5.0
            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A) * 5.0

            # GAN loss
            fake_B = netG_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)

            fake_A = netG_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

            # Cycle loss
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0

            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0

            # Total loss
            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
            loss_G.backward()

            optimizer_G.step()
            ###################################

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()

            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()

            optimizer_D_B.step()
            ###################################

            # Progress report (http://localhost:8097)
            logger.log(
                {
                    'loss_G': loss_G,
                    'loss_G_identity': (loss_identity_A + loss_identity_B),
                    'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A),
                    'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB),
                    'loss_D': (loss_D_A + loss_D_B)
                },
                images={
                    'real_A': real_A,
                    'real_B': real_B,
                    'fake_A': fake_A,
                    'fake_B': fake_B
                })

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        # Save models checkpoints
        torch.save(netG_A2B.state_dict(), 'output/netG_A2B.pth')
        torch.save(netG_B2A.state_dict(), 'output/netG_B2A.pth')
        torch.save(netD_A.state_dict(), 'output/netD_A.pth')
        torch.save(netD_B.state_dict(), 'output/netD_B.pth')
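
weights_init_normal, which every network is passed through via .apply(...), is assumed to follow the usual DCGAN-style initialization: convolution weights drawn from N(0, 0.02) and BatchNorm affine parameters centered at 1.0. A sketch of such a function; the projects above may use a slightly different variant:

import torch.nn as nn

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)
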
Exemplo n.º 20
0
    def train(self):
        num_channels = self.config.NUM_CHANNELS
        use_cuda = self.config.USE_CUDA
        lr = self.config.LEARNING_RATE

        # Networks
        netG_A2B = Generator(num_channels)
        netG_B2A = Generator(num_channels)
        netD_A = Discriminator(num_channels)
        netD_B = Discriminator(num_channels)

        #netG_A2B = Generator_BN(num_channels)
        #netG_B2A = Generator_BN(num_channels)
        #netD_A = Discriminator_BN(num_channels)
        #netD_B = Discriminator_BN(num_channels)

        if use_cuda:
            netG_A2B.cuda()
            netG_B2A.cuda()
            netD_A.cuda()
            netD_B.cuda()

        netG_A2B.apply(weights_init_normal)
        netG_B2A.apply(weights_init_normal)
        netD_A.apply(weights_init_normal)
        netD_B.apply(weights_init_normal)

        criterion_GAN = torch.nn.BCELoss()
        criterion_cycle = torch.nn.L1Loss()
        criterion_identity = torch.nn.L1Loss()

        optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                                       lr=lr, betas=(0.5, 0.999))
        optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=lr, betas=(0.5, 0.999))
        optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=lr, betas=(0.5, 0.999))

        lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            optimizer_G, lr_lambda=LambdaLR(self.config.EPOCH, 0, self.config.EPOCH // 2).step)
        lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
            optimizer_D_A, lr_lambda=LambdaLR(self.config.EPOCH, 0, self.config.EPOCH // 2).step)
        lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
            optimizer_D_B, lr_lambda=LambdaLR(self.config.EPOCH, 0, self.config.EPOCH // 2).step)

        # Inputs & targets memory allocation
        #Tensor = LongTensor if use_cuda else torch.Tensor
        batch_size = self.config.BATCH_SIZE
        height, width, channels = self.config.INPUT_SHAPE

        input_A = FloatTensor(batch_size, channels, height, width)
        input_B = FloatTensor(batch_size, channels, height, width)
        target_real = Variable(FloatTensor(batch_size).fill_(1.0), requires_grad=False)
        target_fake = Variable(FloatTensor(batch_size).fill_(0.0), requires_grad=False)

        fake_A_buffer = ReplayBuffer()
        fake_B_buffer = ReplayBuffer()

        transforms_ = [transforms.RandomCrop((height, width)),
                       transforms.RandomHorizontalFlip(),
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]

        dataloader = DataLoader(ImageDataset(self.config.DATA_DIR, self.config.DATASET_A, self.config.DATASET_B,
                                             transforms_=transforms_, unaligned=True),
                                batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)
        # Loss plot
        logger = Logger(self.config.EPOCH, len(dataloader))

        now = datetime.datetime.now()
        datetime_sequence = "{0}{1:02d}{2:02d}_{3:02d}{4:02d}".format(str(now.year)[-2:], now.month, now.day,
                                                                      now.hour, now.minute)

        output_name_1 = self.config.DATASET_A + "2" + self.config.DATASET_B
        output_name_2 = self.config.DATASET_B + "2" + self.config.DATASET_A

        experiment_dir = os.path.join(self.config.RESULT_DIR, datetime_sequence)

        sample_output_dir_1 = os.path.join(experiment_dir, "sample", output_name_1)
        sample_output_dir_2 = os.path.join(experiment_dir, "sample", output_name_2)
        weights_output_dir_1 = os.path.join(experiment_dir, "weights", output_name_1)
        weights_output_dir_2 = os.path.join(experiment_dir, "weights", output_name_2)
        weights_output_dir_resume = os.path.join(experiment_dir, "weights", "resume")

        os.makedirs(sample_output_dir_1, exist_ok=True)
        os.makedirs(sample_output_dir_2, exist_ok=True)
        os.makedirs(weights_output_dir_1, exist_ok=True)
        os.makedirs(weights_output_dir_2, exist_ok=True)
        os.makedirs(weights_output_dir_resume, exist_ok=True)

        counter = 0

        for epoch in range(self.config.EPOCH):
            """
            logger.loss_df.to_csv(os.path.join(experiment_dir,
                                 self.config.DATASET_A + "_"
                                 + self.config.DATASET_B + ".csv"),
                    index=False)
            """
            if epoch % 100 == 0:
                torch.save(netG_A2B.state_dict(), os.path.join(weights_output_dir_1, str(epoch).zfill(4) + 'netG_A2B.pth'))
                torch.save(netG_B2A.state_dict(), os.path.join(weights_output_dir_2, str(epoch).zfill(4) + 'netG_B2A.pth'))
                torch.save(netD_A.state_dict(), os.path.join(weights_output_dir_1, str(epoch).zfill(4) + 'netD_A.pth'))
                torch.save(netD_B.state_dict(), os.path.join(weights_output_dir_2, str(epoch).zfill(4) + 'netD_B.pth'))

            for i, batch in enumerate(dataloader):
                # Set model input
                real_A = Variable(input_A.copy_(batch['A']))
                real_B = Variable(input_B.copy_(batch['B']))

                ###### Generators A2B and B2A ######
                optimizer_G.zero_grad()

                # GAN loss
                fake_B = netG_A2B(real_A)
                pred_fake_B = netD_B(fake_B)
                loss_GAN_A2B = criterion_GAN(pred_fake_B, target_real)

                fake_A = netG_B2A(real_B)
                pred_fake_A = netD_A(fake_A)
                loss_GAN_B2A = criterion_GAN(pred_fake_A, target_real)

                # Cycle loss
                recovered_A = netG_B2A(fake_B)
                loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0

                recovered_B = netG_A2B(fake_A)
                loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0

                # Total loss
                loss_G = loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
                loss_G.backward()

                optimizer_G.step()
                ###################################

                ###### Discriminator A ######
                optimizer_D_A.zero_grad()

                # Real loss
                pred_A = netD_A(real_A)
                loss_D_real = criterion_GAN(pred_A, target_real)

                # Fake loss
                fake_A_ = fake_A_buffer.push_and_pop(fake_A)
                pred_fake = netD_A(fake_A_.detach())
                loss_D_fake = criterion_GAN(pred_fake, target_fake)

                # Total loss
                loss_D_A = (loss_D_real + loss_D_fake) * 0.5
                loss_D_A.backward()

                optimizer_D_A.step()
                ###################################

                ###### Discriminator B ######
                optimizer_D_B.zero_grad()

                # Real loss
                pred_B = netD_B(real_B)
                loss_D_real = criterion_GAN(pred_B, target_real)

                # Fake loss
                fake_B_ = fake_B_buffer.push_and_pop(fake_B)
                pred_fake = netD_B(fake_B_.detach())
                loss_D_fake = criterion_GAN(pred_fake, target_fake)

                # Total loss
                loss_D_B = (loss_D_real + loss_D_fake) * 0.5
                loss_D_B.backward()

                optimizer_D_B.step()

                # Progress report (http://localhost:8097)
                logger.log({'loss_G': loss_G,
                            'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A),
                            'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB), 'loss_D': (loss_D_A + loss_D_B)},
                           images={'real_A': real_A, 'real_B': real_B, 'fake_A': fake_A, 'fake_B': fake_B})

                if counter % 500 == 0:
                    real_A_sample = real_A.cpu().detach().numpy()[0]
                    pred_A_sample = fake_A.cpu().detach().numpy()[0]
                    real_B_sample = real_B.cpu().detach().numpy()[0]
                    pred_B_sample = fake_B.cpu().detach().numpy()[0]
                    combine_sample_1 = np.concatenate([real_A_sample, pred_B_sample], axis=2)
                    combine_sample_2 = np.concatenate([real_B_sample, pred_A_sample], axis=2)

                    file_1 = "{0}_{1}.jpg".format(epoch, counter)
                    output_sample_image(os.path.join(sample_output_dir_1, file_1), combine_sample_1)
                    file_2 = "{0}_{1}.jpg".format(epoch, counter)
                    output_sample_image(os.path.join(sample_output_dir_2, file_2), combine_sample_2)

                counter += 1


            # Update learning rates
            lr_scheduler_G.step()
            lr_scheduler_D_A.step()
            lr_scheduler_D_B.step()

        torch.save(netG_A2B.state_dict(), os.path.join(weights_output_dir_1, str(self.config.EPOCH).zfill(4) + 'netG_A2B.pth'))
        torch.save(netG_B2A.state_dict(), os.path.join(weights_output_dir_2, str(self.config.EPOCH).zfill(4) + 'netG_B2A.pth'))
        torch.save(netD_A.state_dict(), os.path.join(weights_output_dir_1, str(self.config.EPOCH).zfill(4) + 'netD_A.pth'))
        torch.save(netD_B.state_dict(), os.path.join(weights_output_dir_2, str(self.config.EPOCH).zfill(4) + 'netD_B.pth'))
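
output_sample_image is defined elsewhere; given how it is called above, it presumably takes a (C, H, W) float array normalized to [-1, 1] and writes it to disk. One plausible, purely illustrative implementation:

import numpy as np
from PIL import Image

def output_sample_image(path, chw_array):
    # Undo Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) and convert to 8-bit HWC.
    img = (chw_array * 0.5 + 0.5) * 255.0
    img = np.clip(img, 0, 255).astype(np.uint8)
    img = np.transpose(img, (1, 2, 0))
    Image.fromarray(img).save(path)
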
Exemplo n.º 21
0
from utils import ReplayBuffer, LambdaLR, sample_images

#load the args
args = TrainOptions().parse()
# Calculate output of size discriminator (PatchGAN)
patch = (1, args.img_height//(2**args.n_D_layers) - 2, args.img_width//(2**args.n_D_layers) - 2)

# Initialize generator and discriminator
G__AB, D__B, G__BA, D__A = Create_nets(args)

# Loss functions
criterion_GAN, criterion_cycle, criterion_identity = Get_loss_func(args)
# Optimizers
optimizer_G, optimizer_D_B, optimizer_D_A = Get_optimizers(args, G__AB, G__BA, D__B, D__A )
# Learning rate update schedulers
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(args.epoch_num, args.epoch_start, args.decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(args.epoch_num, args.epoch_start, args.decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(args.epoch_num, args.epoch_start, args.decay_epoch).step)

# Configure dataloaders
train_dataloader,test_dataloader,_ = Get_dataloader(args)

# Buffers of previously generated samples
fake_Y_A_buffer = ReplayBuffer()
fake_X_B_buffer = ReplayBuffer()


# ----------
#  Training
# ----------
Exemplo n.º 22
0
netG_B2A.apply(weights_init_normal)
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                                lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))

lambda_LR = LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch)
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lambda_LR.step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=lambda_LR.step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=lambda_LR.step)

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False)

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Dataset loader
Exemplo n.º 23
0
def train_seg_model(args):
    # model
    model = None
    if args.model_name == "UNet":
        model = UNet(n_channels=args.in_channels, n_classes=args.class_num)
    elif args.model_name == "PSP":
        model = pspnet.PSPNet(n_classes=19, input_size=(512, 512))
        model.load_pretrained_model(
            model_path="./segnet/pspnet/pspnet101_cityscapes.caffemodel")
        model.classification = nn.Conv2d(512, args.class_num, kernel_size=1)
    else:
        raise AssertionError("Unknown model: {}".format(args.model_name))
    model = nn.DataParallel(model)
    model.cuda()
    # optimizer
    optimizer = None
    if args.optim_name == "Adam":
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=1.0e-3)
    elif args.optim_name == "SGD":
        optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                     model.parameters()),
                              lr=args.init_lr,
                              momentum=0.9,
                              weight_decay=0.0005)
    else:
        raise AssertionError("Unknown optimizer: {}".format(args.optim_name))
    scheduler = lr_scheduler.LambdaLR(optimizer,
                                      lr_lambda=LambdaLR(args.maxepoch, 0,
                                                         0).step)
    # dataloader
    train_data_dir = os.path.join(args.data_dir, args.tumor_type, "train")
    train_dloader = gen_dloader(train_data_dir,
                                args.batch_size,
                                mode="train",
                                normalize=args.normalize,
                                tumor_type=args.tumor_type)
    test_data_dir = os.path.join(args.data_dir, args.tumor_type, "val")
    val_dloader = gen_dloader(test_data_dir,
                              args.batch_size,
                              mode="val",
                              normalize=args.normalize,
                              tumor_type=args.tumor_type)

    # training
    save_model_dir = os.path.join(args.model_dir, args.tumor_type,
                                  args.session)
    if not os.path.exists(save_model_dir):
        os.makedirs(save_model_dir)
    best_dice = 0.0
    for epoch in np.arange(0, args.maxepoch):
        print('Epoch {}/{}'.format(epoch + 1, args.maxepoch))
        print('-' * 10)
        since = time.time()
        for phase in ['train', 'val']:
            if phase == 'train':
                dloader = train_dloader
                scheduler.step()
                for param_group in optimizer.param_groups:
                    print("Current LR: {:.8f}".format(param_group['lr']))
                model.train()  # Set model to training mode
            else:
                dloader = val_dloader
                model.eval()  # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0
            for batch_ind, (imgs, masks) in enumerate(dloader):
                inputs = Variable(imgs.cuda())
                masks = Variable(masks.cuda())
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs,
                                     masks,
                                     metrics,
                                     bce_weight=args.bce_weight)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # statistics
                epoch_samples += inputs.size(0)
            print_metrics(metrics, epoch_samples, phase)
            epoch_dice = metrics['dice'] / epoch_samples

            # deep copy the model
            if phase == 'val' and (epoch_dice > best_dice
                                   or epoch > args.maxepoch - 5):
                best_dice = epoch_dice
                best_model = copy.deepcopy(model.state_dict())
                best_model_name = "-".join([
                    args.model_name,
                    "{:03d}-{:.3f}.pth".format(epoch, best_dice)
                ])
                torch.save(best_model,
                           os.path.join(save_model_dir, best_model_name))
        time_elapsed = time.time() - since
        print('Epoch {:2d} takes {:.0f}m {:.0f}s'.format(
            epoch, time_elapsed // 60, time_elapsed % 60))
    print(
        "================================================================================"
    )
    print("Training finished...")
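
calc_loss above mixes a binary cross-entropy term with a Dice term through bce_weight and accumulates per-batch totals into the metrics dict. A hedged sketch of such a helper, storing the Dice coefficient under metrics['dice'] since the training loop treats a larger value as better; the project's real version may differ:

import torch
import torch.nn.functional as F

def dice_loss(pred, target, smooth=1.0):
    intersection = (pred * target).sum(dim=(2, 3))
    denom = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
    return (1 - (2. * intersection + smooth) / (denom + smooth)).mean()

def calc_loss(pred, target, metrics, bce_weight=0.5):
    bce = F.binary_cross_entropy_with_logits(pred, target)
    probs = torch.sigmoid(pred)
    d_loss = dice_loss(probs, target)
    loss = bce * bce_weight + d_loss * (1 - bce_weight)
    # Accumulate batch-weighted sums so the caller can divide by epoch_samples.
    metrics['bce'] += bce.item() * target.size(0)
    metrics['dice'] += (1 - d_loss.item()) * target.size(0)
    metrics['loss'] += loss.item() * target.size(0)
    return loss
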
Exemplo n.º 24
0
    netD_A = Discriminator(3).to(device)
    netD_B = Discriminator(3).to(device)

    netG_A2B.apply(weights_init_normal)
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

    # optimizers and learning rate schedulers
    optimizer_G = Adam(itertools.chain(netG_A2B.parameters(),
                                       netG_B2A.parameters()),
                       lr=opt.lr,
                       betas=BETAS)
    optimizer_D_en = Adam(netD_A.parameters(), lr=opt.lr, betas=BETAS)
    optimizer_D_zh = Adam(netD_B.parameters(), lr=opt.lr, betas=BETAS)

    lr_scheduler_G = lr_scheduler.LambdaLR(optimizer_G,
                                           lr_lambda=LambdaLR(
                                               opt.n_epochs, 0,
                                               DECAY_EPOCH).step)
    lr_scheduler_D_en = lr_scheduler.LambdaLR(optimizer_D_en,
                                              lr_lambda=LambdaLR(
                                                  opt.n_epochs, 0,
                                                  DECAY_EPOCH).step)
    lr_scheduler_D_zh = lr_scheduler.LambdaLR(optimizer_D_zh,
                                              lr_lambda=LambdaLR(
                                                  opt.n_epochs, 0,
                                                  DECAY_EPOCH).step)

    train()