Example 1
        noise.data.normal_(0, 1)
        d = [label_c_hot_in[l] for l in lable_keys_cam_view_info[0]]
        d.append(noise)
        input_d = torch.cat(d, dim=1)
        input_d.data.resize_(batch_size, nz+sum(n_classes), 1, 1)

        with torch.no_grad():
            # sampled= netG.sampler(netG.encoder(input))
            # d= [label_c_hot_in[l] for l in lable_keys_cam_view_info[0]]
            # d.append(sampled.view(batch_size,nz))
            # input_d = torch.cat(d, dim=1)
            # input_d.data.resize_(batch_size, nz+sum(n_classes), 1, 1)
            # encode the other view
            gen = netG.decoder(input_d)

        gen = fake_buffer.push_and_pop(gen)
        # train real
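        # Perturbing the discriminator's real inputs with scaled Gaussian
        # noise (instance noise) is a common trick to keep the discriminator
        # from overfitting; that is what the next line does.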
        input_white_noise = input + torch.randn(input.data.size()).cuda()*(0.5 * d_real_input_noise)
        output_f, output_c = netD(input_white_noise)
        errD_real = criterion(output_f, label.view(batch_size, 1))
        loss_lables_real = 0
        for key_l, out in zip(lable_keys_cam_view_info[0], output_c):
            l_c = criterion_c(out, label_c[key_l])
            loss_lables_real += l_c
            errD_real += l_c
        errD_real.backward()
        D_x = output_f.data.mean()

        if i % opt.showimg == 0:
            if vis is not None:
                gen_win = vis.image(gen.data[0].cpu()*0.5+0.5, win=gen_win,
Example 2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=0, help='starting epoch')
    parser.add_argument('--n_epochs',
                        type=int,
                        default=200,
                        help='number of epochs of training')
    parser.add_argument('--batchSize',
                        type=int,
                        default=1,
                        help='size of the batches')
    parser.add_argument('--dataroot',
                        type=str,
                        default='datasets/data/',
                        help='root directory of the dataset')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='initial learning rate')
    parser.add_argument(
        '--decay_epoch',
        type=int,
        default=100,
        help='epoch to start linearly decaying the learning rate to 0')
    parser.add_argument('--size',
                        type=int,
                        default=256,
                        help='size of the data crop (squared assumed)')
    parser.add_argument('--input_nc',
                        type=int,
                        default=3,
                        help='number of channels of input data')
    parser.add_argument('--output_nc',
                        type=int,
                        default=3,
                        help='number of channels of output data')
    parser.add_argument('--cuda',
                        action='store_true',
                        help='use GPU computation')
    parser.add_argument(
        '--n_cpu',
        type=int,
        default=8,
        help='number of cpu threads to use during batch generation')
    opt = parser.parse_args()
    print(opt)

    if torch.cuda.is_available() and not opt.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    ###### Definition of variables ######
    # Networks
    netG_A2B = Generator(opt.input_nc, opt.output_nc)
    netG_B2A = Generator(opt.output_nc, opt.input_nc)
    netD_A = Discriminator(opt.input_nc)
    netD_B = Discriminator(opt.output_nc)

    if opt.cuda:
        netG_A2B.cuda()
        netG_B2A.cuda()
        netD_A.cuda()
        netD_B.cuda()

    netG_A2B.apply(weights_init_normal)
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

    # Losses
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    # Optimizers & LR schedulers
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                                   netG_B2A.parameters()),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    # Inputs & targets memory allocation
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
    target_real = Variable(Tensor(opt.batchSize).fill_(1.0),
                           requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize).fill_(0.0),
                           requires_grad=False)

    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Dataset loader
    transforms_ = [
        transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
        transforms.RandomCrop(opt.size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    dataloader = DataLoader(ImageDataset(opt.dataroot,
                                         transforms_=transforms_,
                                         unaligned=True),
                            batch_size=opt.batchSize,
                            shuffle=True,
                            num_workers=opt.n_cpu)

    # Loss plot
    logger = Logger(opt.n_epochs, len(dataloader))
    ###################################

    ###### Training ######
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B) * 5.0
            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A) * 5.0

            # GAN loss
            fake_B = netG_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)

            fake_A = netG_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

            # Cycle loss
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0

            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0

            # Total loss
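            # Weighted sum: identity terms x5.0, adversarial x1.0, cycle-consistency x10.0 (weights applied above)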
            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
            loss_G.backward()

            optimizer_G.step()
            ###################################

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()

            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()

            optimizer_D_B.step()
            ###################################

            # Progress report (http://localhost:8097)
            logger.log(
                {
                    'loss_G': loss_G,
                    'loss_G_identity': (loss_identity_A + loss_identity_B),
                    'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A),
                    'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB),
                    'loss_D': (loss_D_A + loss_D_B)
                },
                images={
                    'real_A': real_A,
                    'real_B': real_B,
                    'fake_A': fake_A,
                    'fake_B': fake_B
                })

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        # Save models checkpoints
        torch.save(netG_A2B.state_dict(), 'output/netG_A2B.pth')
        torch.save(netG_B2A.state_dict(), 'output/netG_B2A.pth')
        torch.save(netD_A.state_dict(), 'output/netD_A.pth')
        torch.save(netD_B.state_dict(), 'output/netD_B.pth')
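
These scripts rely on two helpers that are not shown here: ReplayBuffer and LambdaLR. A minimal sketch, assuming the usual CycleGAN reference implementations (a 50-image pool that returns a previously generated image with probability 0.5, and a learning rate held constant until decay_epoch and then decayed linearly to 0):

import random
import torch

class ReplayBuffer:
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        # For each new fake image: store it, and with probability 0.5 hand
        # the discriminator an older image from the pool instead.
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                to_return.append(element)
        return torch.cat(to_return)

class LambdaLR:
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "decay must start before training ends"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        # Multiplier stays at 1.0 until decay_start_epoch, then decays linearly to 0.
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)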
Example 3
def caculate_fitness(mask_input_A2B, mask_input_B2A, gpu_id, fitness_id,
                     A2B_or_B2A):

    torch.cuda.set_device(gpu_id)
    #print("GPU_ID is%d\n"%(gpu_id))

    model_A2B = Generator(opt.input_nc, opt.output_nc)
    model_B2A = Generator(opt.input_nc, opt.output_nc)

    netD_A = Discriminator(opt.input_nc)
    netD_B = Discriminator(opt.output_nc)

    netD_A.cuda(gpu_id)
    netD_B.cuda(gpu_id)
    model_A2B.cuda(gpu_id)
    model_B2A.cuda(gpu_id)

    model_A2B.load_state_dict(torch.load('/cache/models/netG_A2B.pth'))
    model_B2A.load_state_dict(torch.load('/cache/models/netG_B2A.pth'))
    netD_A.load_state_dict(torch.load('/cache/models/netD_A.pth'))
    netD_B.load_state_dict(torch.load('/cache/models/netD_B.pth'))

    # Losses
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    fitness = 0
    cfg_mask_A2B = compute_layer_mask(mask_input_A2B, mask_chns)
    cfg_mask_B2A = compute_layer_mask(mask_input_B2A, mask_chns)
    cfg_full_mask_A2B = [y for x in cfg_mask_A2B for y in x]
    cfg_full_mask_A2B = np.array(cfg_full_mask_A2B)
    cfg_full_mask_B2A = [y for x in cfg_mask_B2A for y in x]
    cfg_full_mask_B2A = np.array(cfg_full_mask_B2A)
    cfg_id = 0
    start_mask = np.ones(3)
    end_mask = cfg_mask_A2B[cfg_id]
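
    # Walk the generator layer by layer: zero out the weight slices belonging
    # to pruned input channels (start_mask) and pruned output channels
    # (end_mask), so the dense network behaves like its pruned counterpart.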

    for m in model_A2B.modules():
        if isinstance(m, nn.Conv2d):

            #print("conv2d")
            #print(m.weight.data.shape)
            #out_channels = m.weight.data.shape[0]
            mask = np.ones(m.weight.data.shape)

            mask_bias = np.ones(m.bias.data.shape)

            cfg_mask_start = np.ones(start_mask.shape) - start_mask
            cfg_mask_end = np.ones(end_mask.shape) - end_mask
            idx0 = np.squeeze(np.argwhere(np.asarray(cfg_mask_start)))
            idx1 = np.squeeze(np.argwhere(np.asarray(cfg_mask_end)))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1, ))

            mask[:, idx0.tolist(), :, :] = 0
            mask[idx1.tolist(), :, :, :] = 0
            mask_bias[idx1.tolist()] = 0

            m.weight.data = m.weight.data * torch.FloatTensor(mask).cuda(
                gpu_id)

            m.bias.data = m.bias.data * torch.FloatTensor(mask_bias).cuda(
                gpu_id)

            idx_mask = np.argwhere(np.asarray(np.ones(mask.shape) - mask))

            m.weight.data[:, idx0.tolist(), :, :].requires_grad = False
            m.weight.data[idx1.tolist(), :, :, :].requires_grad = False
            m.bias.data[idx1.tolist()].requires_grad = False

            cfg_id += 1
            start_mask = end_mask
            if cfg_id < len(cfg_mask_A2B):
                end_mask = cfg_mask_A2B[cfg_id]
            continue
        elif isinstance(m, nn.ConvTranspose2d):

            mask = np.ones(m.weight.data.shape)
            mask_bias = np.ones(m.bias.data.shape)

            cfg_mask_start = np.ones(start_mask.shape) - start_mask
            cfg_mask_end = np.ones(end_mask.shape) - end_mask

            idx0 = np.squeeze(np.argwhere(np.asarray(cfg_mask_start)))
            idx1 = np.squeeze(np.argwhere(np.asarray(cfg_mask_end)))

            mask[idx0.tolist(), :, :, :] = 0

            mask[:, idx1.tolist(), :, :] = 0

            mask_bias[idx1.tolist()] = 0

            m.weight.data = m.weight.data * torch.FloatTensor(mask).cuda(
                gpu_id)
            m.bias.data = m.bias.data * torch.FloatTensor(mask_bias).cuda(
                gpu_id)

            m.weight.data[idx0.tolist(), :, :, :].requires_grad = False
            m.weight.data[:, idx1.tolist(), :, :].requires_grad = False
            m.bias.data[idx1.tolist()].requires_grad = False

            cfg_id += 1
            start_mask = end_mask
            end_mask = cfg_mask_A2B[cfg_id]
            continue

    cfg_id = 0
    start_mask = np.ones(3)
    end_mask = cfg_mask_B2A[cfg_id]

    for m in model_B2A.modules():
        if isinstance(m, nn.Conv2d):

            #print("conv2d")
            #print(m.weight.data.shape)
            #out_channels = m.weight.data.shape[0]
            mask = np.ones(m.weight.data.shape)

            mask_bias = np.ones(m.bias.data.shape)

            cfg_mask_start = np.ones(start_mask.shape) - start_mask
            cfg_mask_end = np.ones(end_mask.shape) - end_mask
            idx0 = np.squeeze(np.argwhere(np.asarray(cfg_mask_start)))
            idx1 = np.squeeze(np.argwhere(np.asarray(cfg_mask_end)))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1, ))

            mask[:, idx0.tolist(), :, :] = 0
            mask[idx1.tolist(), :, :, :] = 0
            mask_bias[idx1.tolist()] = 0

            m.weight.data = m.weight.data * torch.FloatTensor(mask).cuda(
                gpu_id)

            m.bias.data = m.bias.data * torch.FloatTensor(mask_bias).cuda(
                gpu_id)

            idx_mask = np.argwhere(np.asarray(np.ones(mask.shape) - mask))

            m.weight.data[:, idx0.tolist(), :, :].requires_grad = False
            m.weight.data[idx1.tolist(), :, :, :].requires_grad = False
            m.bias.data[idx1.tolist()].requires_grad = False

            cfg_id += 1
            start_mask = end_mask
            if cfg_id < len(cfg_mask_B2A):
                end_mask = cfg_mask_B2A[cfg_id]
            continue
        elif isinstance(m, nn.ConvTranspose2d):

            mask = np.ones(m.weight.data.shape)
            mask_bias = np.ones(m.bias.data.shape)

            cfg_mask_start = np.ones(start_mask.shape) - start_mask
            cfg_mask_end = np.ones(end_mask.shape) - end_mask

            idx0 = np.squeeze(np.argwhere(np.asarray(cfg_mask_start)))
            idx1 = np.squeeze(np.argwhere(np.asarray(cfg_mask_end)))

            mask[idx0.tolist(), :, :, :] = 0

            mask[:, idx1.tolist(), :, :] = 0

            mask_bias[idx1.tolist()] = 0

            m.weight.data = m.weight.data * torch.FloatTensor(mask).cuda(
                gpu_id)
            m.bias.data = m.bias.data * torch.FloatTensor(mask_bias).cuda(
                gpu_id)

            m.weight.data[idx0.tolist(), :, :, :].requires_grad = False
            m.weight.data[:, idx1.tolist(), :, :].requires_grad = False
            m.bias.data[idx1.tolist()].requires_grad = False

            cfg_id += 1
            start_mask = end_mask
            end_mask = cfg_mask_B2A[cfg_id]
            continue

    # Dataset loader
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
    target_real = Variable(Tensor(opt.batchSize).fill_(1.0),
                           requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize).fill_(0.0),
                           requires_grad=False)
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    lamda_loss_ID = 5.0
    lamda_loss_G = 1.0
    lamda_loss_cycle = 10.0
    optimizer_G = torch.optim.Adam(itertools.chain(
        filter(lambda p: p.requires_grad, model_A2B.parameters()),
        filter(lambda p: p.requires_grad, model_B2A.parameters())),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))
    transforms_ = [
        transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
        transforms.RandomCrop(opt.size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]

    dataloader = DataLoader(ImageDataset(opt.dataroot,
                                         transforms_=transforms_,
                                         unaligned=True,
                                         mode='train'),
                            batch_size=opt.batchSize,
                            shuffle=True,
                            drop_last=True)

    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = model_A2B(real_B)
            loss_identity_B = criterion_identity(
                same_B, real_B) * lamda_loss_ID  #initial 5.0
            # G_B2A(A) should equal A if real A is fed
            same_A = model_B2A(real_A)
            loss_identity_A = criterion_identity(
                same_A, real_A) * lamda_loss_ID  #initial 5.0

            # GAN loss
            fake_B = model_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(
                pred_fake, target_real) * lamda_loss_G  #initial 1.0

            fake_A = model_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(
                pred_fake, target_real) * lamda_loss_G  #initial 1.0

            # Cycle loss
            recovered_A = model_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(
                recovered_A, real_A) * lamda_loss_cycle  #initial 10.0

            recovered_B = model_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(
                recovered_B, real_B) * lamda_loss_cycle  #initial 10.0

            # Total loss
            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
            loss_G.backward()

            optimizer_G.step()

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()

            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()

            optimizer_D_B.step()

    with torch.no_grad():

        transforms_ = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]

        dataloader = DataLoader(ImageDataset(opt.dataroot,
                                             transforms_=transforms_,
                                             mode='val'),
                                batch_size=opt.batchSize,
                                shuffle=False,
                                drop_last=True)

        Loss_resemble_G = 0
        if A2B_or_B2A == 'A2B':
            netG_A2B = Generator(opt.input_nc, opt.output_nc)
            netD_B = Discriminator(opt.output_nc)

            netG_A2B.cuda(gpu_id)
            netD_B.cuda(gpu_id)

            model_A2B.eval()
            netD_B.eval()
            netG_A2B.eval()

            netD_B.load_state_dict(torch.load('/cache/models/netD_B.pth'))
            netG_A2B.load_state_dict(torch.load('/cache/models/netG_A2B.pth'))

            for i, batch in enumerate(dataloader):

                real_A = Variable(input_A.copy_(batch['A']))

                fake_B = model_A2B(real_A)
                fake_B_full_model = netG_A2B(real_A)
                recovered_A = model_B2A(fake_B)

                pred_fake = netD_B(fake_B.detach())

                pred_fake_full = netD_B(fake_B_full_model.detach())

                loss_D_fake = criterion_GAN(pred_fake.detach(),
                                            pred_fake_full.detach())
                cycle_loss = criterion_cycle(recovered_A,
                                             real_A) * lamda_loss_cycle
                Loss_resemble_G = Loss_resemble_G + loss_D_fake + cycle_loss

                lambda_prune = 0.001

            fitness = 500 / Loss_resemble_G.detach() + sum(
                np.ones(cfg_full_mask_A2B.shape) -
                cfg_full_mask_A2B) * lambda_prune

            print('A2B')
            print("GPU_ID is %d" % (gpu_id))
            print("channel num is: %d" % (sum(cfg_full_mask_A2B)))
            print("Loss_resemble_G is %f prune_loss is %f " %
                  (500 / Loss_resemble_G,
                   sum(np.ones(cfg_full_mask_A2B.shape) - cfg_full_mask_A2B)))
            print("fitness is %f \n" % (fitness))

            current_fitness_A2B[fitness_id] = fitness.item()

        if A2B_or_B2A == 'B2A':
            netG_B2A = Generator(opt.output_nc, opt.input_nc)
            netD_A = Discriminator(opt.input_nc)

            netG_B2A.cuda(gpu_id)
            netD_A.cuda(gpu_id)

            model_B2A.eval()
            netD_A.eval()
            netG_B2A.eval()

            netD_A.load_state_dict(torch.load('/cache/models/netD_A.pth'))
            netG_B2A.load_state_dict(torch.load('/cache/models/netG_B2A.pth'))

            for i, batch in enumerate(dataloader):

                real_B = Variable(input_B.copy_(batch['B']))

                fake_A = model_B2A(real_B)
                fake_A_full_model = netG_B2A(real_B)
                recovered_B = model_A2B(fake_A)

                pred_fake = netD_A(fake_A.detach())

                pred_fake_full = netD_A(fake_A_full_model.detach())

                loss_D_fake = criterion_GAN(pred_fake.detach(),
                                            pred_fake_full.detach())
                cycle_loss = criterion_cycle(recovered_B,
                                             real_B) * lamda_loss_cycle
                Loss_resemble_G = Loss_resemble_G + loss_D_fake + cycle_loss

                lambda_prune = 0.001

            fitness = 500 / Loss_resemble_G.detach() + sum(
                np.ones(cfg_full_mask_B2A.shape) -
                cfg_full_mask_B2A) * lambda_prune

            print('B2A')
            print("GPU_ID is %d" % (gpu_id))
            print("channel num is: %d" % (sum(cfg_full_mask_B2A)))
            print("Loss_resemble_G is %f prune_loss is %f " %
                  (500 / Loss_resemble_G,
                   sum(np.ones(cfg_full_mask_B2A.shape) - cfg_full_mask_B2A)))
            print("fitness is %f \n" % (fitness))

            current_fitness_B2A[fitness_id] = fitness.item()
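
The fitness used by both branches is the same expression; restated as a standalone helper for clarity (the function name and signature are ours, for illustration only):

import numpy as np

def fitness_from_loss(loss_resemble_g, full_mask, lambda_prune=0.001):
    # Quality term: inverse of the accumulated resemblance loss, so a pruned
    # generator that closely mimics the full model scores higher.
    quality = 500.0 / float(loss_resemble_g)
    # Sparsity term: one unit of lambda_prune per pruned channel
    # (a zero in the binary channel mask).
    pruned_channels = float(np.sum(np.ones(full_mask.shape) - full_mask))
    return quality + lambda_prune * pruned_channels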
Example 4
def train_from_mask():
    
    # Load the best-fitness binary masks
    mask_input_A2B = np.loadtxt("/cache/GA/txt/best_fitness_A2B.txt")
    mask_input_B2A = np.loadtxt("/cache/GA/txt/best_fitness_B2A.txt")

    cfg_mask_A2B = compute_layer_mask(mask_input_A2B, mask_chns)
    cfg_mask_B2A = compute_layer_mask(mask_input_B2A, mask_chns)

    netG_B2A = Generator(opt.output_nc, opt.input_nc)
    netG_A2B = Generator(opt.input_nc, opt.output_nc)
    model_A2B = Generator_Prune(cfg_mask_A2B)
    model_B2A = Generator_Prune(cfg_mask_B2A)
    netD_A = Discriminator(opt.input_nc)
    netD_B = Discriminator(opt.output_nc)

    netG_A2B.load_state_dict(torch.load('/cache/log/output/netG_A2B.pth'))
    netG_B2A.load_state_dict(torch.load('/cache/log/output/netG_B2A.pth'))
    netD_A.load_state_dict(torch.load('/cache/log/output/netD_A.pth'))
    netD_B.load_state_dict(torch.load('/cache/log/output/netD_B.pth'))

    # Losses
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    layer_id_in_cfg = 0
    start_mask = torch.ones(3)
    end_mask = cfg_mask_A2B[layer_id_in_cfg]
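
    # Copy the surviving weights from the dense netG_A2B into the slim
    # model_A2B: keep only the input channels selected by start_mask and the
    # output channels selected by end_mask, layer by layer.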
    
    for m0, m1 in zip(netG_A2B.modules(), model_A2B.modules()):
  
        if isinstance(m0, nn.Conv2d):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask)))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask)))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
            
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            
            layer_id_in_cfg += 1
            start_mask = end_mask
            if layer_id_in_cfg < len(cfg_mask_A2B):  # guard: do not run past the last mask
                end_mask = cfg_mask_A2B[layer_id_in_cfg]
                print(layer_id_in_cfg)
        elif isinstance(m0, nn.ConvTranspose2d):
            print('Into ConvTranspose...')
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask)))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask)))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        

            w1 = m0.weight.data[idx0.tolist(), :, :, :].clone()
            w1 = w1[:, idx1.tolist(), :, :].clone()
            m1.weight.data = w1.clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask
            if layer_id_in_cfg < len(cfg_mask_A2B):  
                end_mask = cfg_mask_A2B[layer_id_in_cfg] 

    layer_id_in_cfg = 0
    start_mask = torch.ones(3)
    end_mask = cfg_mask_B2A[layer_id_in_cfg]

    for m0, m1 in zip(netG_B2A.modules(), model_B2A.modules()):
  
        if isinstance(m0, nn.Conv2d):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask)))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask)))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
            
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            
            layer_id_in_cfg += 1
            start_mask = end_mask
            if layer_id_in_cfg < len(cfg_mask_B2A):  
                end_mask = cfg_mask_B2A[layer_id_in_cfg]
                print(layer_id_in_cfg)
        elif isinstance(m0, nn.ConvTranspose2d):
            print('Into ConvTranspose...')
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask)))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask)))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        
            w1 = m0.weight.data[idx0.tolist(), :, :, :].clone()
            w1 = w1[:, idx1.tolist(), :, :].clone()
            m1.weight.data = w1.clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask
            if layer_id_in_cfg < len(cfg_mask_B2A):
                end_mask = cfg_mask_B2A[layer_id_in_cfg]

    netD_A = torch.nn.DataParallel(netD_A).cuda()
    netD_B = torch.nn.DataParallel(netD_B).cuda()
    model_A2B = torch.nn.DataParallel(model_A2B).cuda()
    model_B2A = torch.nn.DataParallel(model_B2A).cuda()

    
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
    target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False)
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()        
        
    lamda_loss_ID = 5.0
    lamda_loss_G = 1.0
    lamda_loss_cycle = 10.0
    optimizer_G = torch.optim.Adam(itertools.chain(model_A2B.parameters(), model_B2A.parameters()),
                                lr=opt.lr, betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    
    transforms_ = [
        transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
        transforms.RandomCrop(opt.size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]

    # Dataset loader
    dataloader = DataLoader(ImageDataset(opt.dataroot,
                                         transforms_=transforms_,
                                         unaligned=True,
                                         mode='train'),
                            batch_size=opt.batchSize,
                            shuffle=True,
                            drop_last=True)
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = model_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B)*lamda_loss_ID #initial 5.0
            # G_B2A(A) should equal A if real A is fed
            same_A = model_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A)*lamda_loss_ID #initial 5.0

            # GAN loss
            fake_B = model_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)*lamda_loss_G  #initial 1.0

            fake_A = model_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)*lamda_loss_G #initial 1.0

            # Cycle loss
            recovered_A = model_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A)*lamda_loss_cycle  #initial 10.0

            recovered_B = model_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B)*lamda_loss_cycle  #initial 10.0

            # Total loss
            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
            loss_G.backward()
            
            optimizer_G.step()
            
            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake)*0.5
            loss_D_A.backward()

            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)
        
            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake)*0.5
            loss_D_B.backward()

            optimizer_D_B.step()
            
        
        print("epoch:%d  Loss_G:%.4f  LossID_A:%.4f  LossID_B:%.4f  Loss_G_A2B:%.4f  Loss_G_B2A:%.4f  Loss_Cycle_ABA:%.4f  Loss_Cycle_BAB:%.4f" %
              (epoch, loss_G, loss_identity_A, loss_identity_B, loss_GAN_A2B, loss_GAN_B2A, loss_cycle_ABA, loss_cycle_BAB))

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()
        
        if epoch % 20 == 0:
            # Save model checkpoints
            torch.save(model_A2B.module.state_dict(), '/cache/log/output/A2B_%d.pth' % (epoch))
            torch.save(model_B2A.module.state_dict(), '/cache/log/output/B2A_%d.pth' % (epoch))
Example 5
    loss_D_aux = 0
    ###### Discriminator S ######
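    # D_s returns an adversarial score plus an auxiliary class prediction
    # (ACGAN-style); model types 'D' and 'E' skip the auxiliary CE loss.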
    optimizer_D_s.zero_grad()

    # Real loss
    pred_f_s, aux_f_s = D_s(f_s.detach())
    loss_D_real = criterion_GAN(
        pred_f_s, real_label[0:int(len(real_label) / 2)].view(-1, 1))

    if MODELTYPE == 'D' or MODELTYPE == 'E':
        pass
    else:
        loss_D_aux = criterion_CE(aux_f_s, y_s)

    # Fake loss
    f_ts = f_ts_buffer.push_and_pop(f_ts)
    pred_f_ts, aux_f_ts = D_s(f_ts.detach())
    loss_D_fake = criterion_GAN(
        pred_f_ts, fake_label[0:int(len(fake_label) / 2)].view(-1, 1))

    # Total loss
    if MODELTYPE == 'D' or MODELTYPE == 'E':
        loss_D_s = loss_D_real + loss_D_fake
    else:
        loss_D_s = loss_D_real + loss_D_fake + loss_D_aux
    loss_D_s.backward()

    optimizer_D_s.step()
    ###################################

    ###### Discriminator T ######
Example 6
        loss_P = content_loss + style_loss  # + tv_loss
        ########## loss C Perceptual ##########

        ########## total loss ##########
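        # cond0/cond1/cond2 appear to act as per-sample gates, so each loss
        # term contributes only for samples whose condition flag is set.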
        loss_G = torch.mean(cond0 * loss_RC + cond1 * loss_GANB +
                            cond2 * loss_P)
        # loss_G = loss_P
        loss_G.backward()
        optimizer_G.step()

        ##########  Discriminator B ##########
        optimizer_D_B.zero_grad()

        # Fake loss
        fake_B, cond_r = fake_A_buffer.push_and_pop((out_im, cond))
        pred_fake = netD_B(fake_B.detach())
        loss_D_fake = criterionGAN(pred_fake, False)
        # loss_D_fake = criterion_MSE(pred_fake.squeeze(), target_fake)

        # Real loss
        pred_real = netD_B(real_B)
        # loss_D_real = criterion_MSE(pred_real.squeeze(), target_real)
        loss_D_real = criterionGAN(pred_real, True)

        # Total loss
        loss_D_B = torch.mean(cond_r[:, 1] * (loss_D_real + loss_D_fake)) * 0.5
        loss_D_B.backward()

        optimizer_D_B.step()
Example 7
        # Real loss
        pred_real = netD.forward(real_B)

        loss_D_real = GANloss(pred_real, True)
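        # GANloss(pred, True/False) presumably builds the matching target
        # tensor internally, replacing the manual target construction shown
        # commented out below.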

        # print(pred_real.size()[0])
        # target_real = Variable(Tensor(pred_real.size()[0]).fill_(1.0), requires_grad=False)
        # print(target_real.shape)
        # target_fake = Variable(Tensor(pred_real.size()[0]).fill_(0.0), requires_grad=False)

        # loss_D_real = criterion_GAN(pred_real, target_real)
 
        # Fake loss
        fake_B1 = fake_B
        fake_B = fake_B_buffer.push_and_pop(fake_B)
        pred_fake = netD.forward(fake_B.detach())
        loss_D_fake = GANloss(pred_fake, False)        

        # Total loss
        loss_D = (loss_D_real + loss_D_fake)*0.5
        loss_D.backward()

        optimizer_D.step()
        ###################################

        # Progress report (http://localhost:8097)
        logger.log({'loss': loss,  'loss_G': loss_G, 'loss_feature': loss_feature, 'loss_idt': loss_idt,
                    'loss_VGG': loss_VGG,
                    'ssim_A': ssim_A,                    
                    'loss_D': loss_D}, 
Example 8
        # Total loss
        loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
        loss_G.backward()
        
        optimizer_G.step()
        ###################################

        ###### Discriminator A ######
        optimizer_D_A.zero_grad()

        # Real loss
        pred_real = netD_A(real_A)
        loss_D_real = criterion_GAN(pred_real, target_real)

        # Fake loss
        fake_A = fake_A_buffer.push_and_pop(fake_A)
        pred_fake = netD_A(fake_A.detach())
        loss_D_fake = criterion_GAN(pred_fake, target_fake)

        # Total loss
        loss_D_A = (loss_D_real + loss_D_fake)*0.5
        loss_D_A.backward()

        optimizer_D_A.step()
        ###################################

        ###### Discriminator B ######
        optimizer_D_B.zero_grad()

        # Real loss
        pred_real = netD_B(real_B)
Example 9
def main():
    # Command-line parser
    parser = argparse.ArgumentParser(
        description=
        "This is a PyTorch implementation of CycleGAN. Please refer to the following arguments."
    )
    parser.add_argument('--batch_size',
                        default=1,
                        type=int,
                        help='Size of a mini-batch. Default: 1')
    parser.add_argument('--cuda',
                        action="store_true",
                        help="Turn on the cuda option.")
    parser.add_argument(
        '--data_root',
        type=str,
        default='./train',
        help='Root directory to the input dataset. Default: ./train')
    parser.add_argument('--dataset',
                        type=str,
                        default="fruit2rotten",
                        help='Name of the dataset. Default: fruit2rotten')
    parser.add_argument(
        '--decay_epochs',
        type=int,
        default=80,
        help=
        "epoch to start linearly decaying the learning rate to 0. Default: 80")
    parser.add_argument('--epochs',
                        default=100,
                        type=int,
                        help="Number of epochs. Default: 100")
    parser.add_argument('--image_size',
                        type=int,
                        default=256,
                        help='Size of the image. Default: 256')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.0002,
                        help='Learning rate. Default: 0.0002')
    args = parser.parse_args()
    print('****Preparing training with the following options****')
    time.sleep(0.2)

    # Cuda option
    if torch.cuda.is_available() and not args.cuda:
        print("Cuda device found. Turning on cuda...")
        args.cuda = True
        time.sleep(0.2)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    # Random seed to initialize the random state
    seed = random.randint(1, 10000)
    torch.manual_seed(seed)
    print(f'Random Seed: {seed}')

    print(f'Batch size: {args.batch_size}')
    print(f'Cuda: {args.cuda}')
    print(f'Data root: {args.data_root}')
    print(f'Dataset: {args.dataset}')
    print(f'Decay epochs: {args.decay_epochs}')
    print(f'Epochs: {args.epochs}')
    print(f'Image size: {args.image_size}')
    print(f'Learning rate: {args.learning_rate}')
    time.sleep(0.2)

    # Create directory
    try:
        os.makedirs(args.data_root)
    except OSError:
        pass

    # Weights
    try:
        os.makedirs(os.path.join("weights", args.dataset))
    except OSError:
        pass

    # Load dataset
    dataset = LoadDataset(data_root=os.path.join(args.data_root, args.dataset),
                          img_size=args.image_size)

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             pin_memory=True)

    # Create models
    G_A2B = Generator().to(device)
    G_B2A = Generator().to(device)
    D_A = Discriminator().to(device)
    D_B = Discriminator().to(device)

    G_A2B.apply(weights_init)
    G_B2A.apply(weights_init)
    D_A.apply(weights_init)
    D_B.apply(weights_init)

    # Loss function
    gan_loss = torch.nn.MSELoss().to(device)
    cycle_loss = torch.nn.L1Loss().to(device)
    identity_loss = torch.nn.L1Loss().to(device)

    discriminator_losses = []
    generator_losses = []
    cycle_losses = []
    gan_losses = []
    identity_losses = []

    # Optimizers
    optimizer_G = torch.optim.Adam(itertools.chain(G_A2B.parameters(),
                                                   G_B2A.parameters()),
                                   lr=args.learning_rate)
    optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=args.learning_rate)
    optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=args.learning_rate)

    # Learning rate schedulers that decay the learning rate quadratically
    quadraticLR = QuadraticLR(args.epochs, 0, args.decay_epochs).step
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G,
                                                       lr_lambda=quadraticLR)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A,
                                                         lr_lambda=quadraticLR)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B,
                                                         lr_lambda=quadraticLR)

    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    for epoch in range(0, args.epochs):
        progress = tqdm(enumerate(dataloader), total=len(dataloader))
        for i, data in progress:
            # get batch size data
            real_image_A = data["A"].to(device)
            real_image_B = data["B"].to(device)
            batch_size = real_image_A.size(0)

            # Fake: 0, Real: 1
            fake_label = torch.full((batch_size, 1),
                                    0,
                                    device=device,
                                    dtype=torch.float32)
            real_label = torch.full((batch_size, 1),
                                    1,
                                    device=device,
                                    dtype=torch.float32)

            #***********************
            # 1. UPDATE GENERATORS *
            #***********************
            optimizer_G.zero_grad()

            # Identity loss
            # A = G_B2A(A)
            identity_image_A = G_B2A(real_image_A)
            loss_identity_A = identity_loss(identity_image_A,
                                            real_image_A) * 5.0
            # B = G_A2B(B)
            identity_image_B = G_A2B(real_image_B)
            loss_identity_B = identity_loss(identity_image_B,
                                            real_image_B) * 5.0

            # GAN loss
            # D_A(G_A(A))
            fake_image_A = G_B2A(real_image_B)
            fake_output_A = D_A(fake_image_A)
            loss_GAN_B2A = gan_loss(fake_output_A, real_label)
            # D_B(G_B(B))
            fake_image_B = G_A2B(real_image_A)
            fake_output_B = D_B(fake_image_B)
            loss_GAN_A2B = gan_loss(fake_output_B, real_label)

            # Cycle loss
            recovered_image_A = G_B2A(fake_image_B)
            loss_cycle_ABA = cycle_loss(recovered_image_A, real_image_A) * 10.0

            recovered_image_B = G_A2B(fake_image_A)
            loss_cycle_BAB = cycle_loss(recovered_image_B, real_image_B) * 10.0

            # Net loss
            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB

            # Update Generator
            loss_G.backward()
            optimizer_G.step()

            #***************************
            # 2. UPDATE DISCRIMINATORS *
            #***************************
            optimizer_D_A.zero_grad()
            optimizer_D_B.zero_grad()

            # Real image loss
            real_output_A = D_A(real_image_A)
            loss_D_real_A = gan_loss(real_output_A, real_label)

            real_output_B = D_B(real_image_B)
            loss_D_real_B = gan_loss(real_output_B, real_label)

            # Fake image loss
            fake_image_A = fake_A_buffer.push_and_pop(fake_image_A)
            fake_output_A = D_A(fake_image_A.detach())
            loss_D_fake_A = gan_loss(fake_output_A, fake_label)

            fake_image_B = fake_B_buffer.push_and_pop(fake_image_B)
            fake_output_B = D_B(fake_image_B.detach())
            loss_D_fake_B = gan_loss(fake_output_B, fake_label)

            # Net loss
            loss_D_A = (loss_D_real_A + loss_D_fake_A) / 2
            loss_D_B = (loss_D_real_B + loss_D_fake_B) / 2

            # Update Discriminator
            loss_D_A.backward()
            optimizer_D_A.step()

            loss_D_B.backward()
            optimizer_D_B.step()

            #*************
            # 3. Verbose *
            #*************
            progress.set_description(
                f"[{epoch}/{args.epochs - 1}][{i}/{len(dataloader) - 1}] "
                f"Loss_D: {(loss_D_A + loss_D_B).item():.4f} "
                f"Loss_G: {loss_G.item():.4f} "
                f"Loss_G_identity: {(loss_identity_A + loss_identity_B).item():.4f} "
                f"Loss_G_GAN: {(loss_GAN_A2B + loss_GAN_B2A).item():.4f} "
                f"Loss_G_cycle: {(loss_cycle_ABA + loss_cycle_BAB).item():.4f}"
            )

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

    # Save the last checkpoint
    torch.save(G_A2B.state_dict(), f"weights/{args.dataset}/G_A2B.pth")
    torch.save(G_B2A.state_dict(), f"weights/{args.dataset}/G_B2A.pth")
    torch.save(D_A.state_dict(), f"weights/{args.dataset}/D_A.pth")
    torch.save(D_B.state_dict(), f"weights/{args.dataset}/D_B.pth")
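
QuadraticLR is not defined in this snippet; a plausible sketch consistent with the comment above (learning rate constant until decay_epochs, then a quadratic decay to 0), mirroring the .step interface of LambdaLR. This is a hypothetical reconstruction; the real class may differ:

class QuadraticLR:
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        t = max(0, epoch + self.offset - self.decay_start_epoch)
        frac = t / (self.n_epochs - self.decay_start_epoch)
        # Multiplier decreases quadratically from 1 to 0.
        return (1.0 - frac) ** 2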
Example 10
def main():
    cuda = torch.cuda.is_available()

    input_shape = (opt.channels, opt.img_height, opt.img_width)

    # Initialize generator and discriminator
    G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
    G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
    D_A = Discriminator(input_shape)
    D_B = Discriminator(input_shape)

    if cuda:
        G_AB = G_AB.cuda()
        G_BA = G_BA.cuda()
        D_A = D_A.cuda()
        D_B = D_B.cuda()
        criterion_GAN.cuda()
        criterion_cycle.cuda()
        criterion_identity.cuda()

    if opt.epoch != 0:
        # Load pretrained models
        G_AB.load_state_dict(
            torch.load("saved_models/%s/G_AB_%d.pth" %
                       (opt.dataset_name, opt.epoch)))
        G_BA.load_state_dict(
            torch.load("saved_models/%s/G_BA_%d.pth" %
                       (opt.dataset_name, opt.epoch)))
        D_A.load_state_dict(
            torch.load("saved_models/%s/D_A_%d.pth" %
                       (opt.dataset_name, opt.epoch)))
        D_B.load_state_dict(
            torch.load("saved_models/%s/D_B_%d.pth" %
                       (opt.dataset_name, opt.epoch)))
    else:
        # Initialize weights
        G_AB.apply(weights_init_normal)
        G_BA.apply(weights_init_normal)
        D_A.apply(weights_init_normal)
        D_B.apply(weights_init_normal)

    # Optimizers
    optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(),
                                                   G_BA.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D_A = torch.optim.Adam(D_A.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.b1, opt.b2))
    optimizer_D_B = torch.optim.Adam(D_B.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.b1, opt.b2))

    # Learning rate update schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

    # Buffers of previously generated samples
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Image transformations
    transforms_ = [
        transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
        transforms.RandomCrop((opt.img_height, opt.img_width)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]

    # Training data loader
    dataloader = DataLoader(
        ImageDataset("../../data/%s" % opt.dataset_name,
                     transforms_=transforms_,
                     unaligned=True),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
    )
    # Test data loader
    val_dataloader = DataLoader(
        ImageDataset("../../data/%s" % opt.dataset_name,
                     transforms_=transforms_,
                     unaligned=True,
                     mode="test"),
        batch_size=5,
        shuffle=True,
        num_workers=1,
    )

    def sample_images(batches_done):
        """Saves a generated sample from the test set"""
        imgs = next(iter(val_dataloader))
        G_AB.eval()
        G_BA.eval()
        real_A = Variable(imgs["A"].type(Tensor))
        fake_B = G_AB(real_A)
        real_B = Variable(imgs["B"].type(Tensor))
        fake_A = G_BA(real_B)
        # Arrange images along x-axis
        real_A = make_grid(real_A, nrow=5, normalize=True)
        real_B = make_grid(real_B, nrow=5, normalize=True)
        fake_A = make_grid(fake_A, nrow=5, normalize=True)
        fake_B = make_grid(fake_B, nrow=5, normalize=True)
        # Arrange images along y-axis
        image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
        save_image(image_grid,
                   "images/%s/%s.png" % (opt.dataset_name, batches_done),
                   normalize=False)

    # ----------
    #  Training
    # ----------
    prev_time = time.time()
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            real_A = Variable(batch["A"].type(Tensor))
            real_B = Variable(batch["B"].type(Tensor))

            # Adversarial ground truths
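            # PatchGAN-style targets: one label per discriminator output
            # patch (hence the D_A.output_shape shape) rather than a single
            # scalar per image.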
            valid = Variable(Tensor(
                np.ones((real_A.size(0), *D_A.output_shape))),
                             requires_grad=False)
            fake = Variable(Tensor(
                np.zeros((real_A.size(0), *D_A.output_shape))),
                            requires_grad=False)

            # ------------------
            #  Train Generators
            # ------------------

            G_AB.train()
            G_BA.train()

            optimizer_G.zero_grad()

            # Identity loss
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)

            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss
            fake_B = G_AB(real_A)
            loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_BA(real_B)
            loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A)
            recov_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B)

            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Total loss
            loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity

            loss_G.backward()
            optimizer_G.step()

            # -----------------------
            #  Train Discriminator A
            # -----------------------

            optimizer_D_A.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_A(real_A), valid)
            # Fake loss (on batch of previously generated samples)
            fake_A_ = fake_A_buffer.push_and_pop(fake_A)
            loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2

            loss_D_A.backward()
            optimizer_D_A.step()

            # -----------------------
            #  Train Discriminator B
            # -----------------------

            optimizer_D_B.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_B(real_B), valid)
            # Fake loss (on batch of previously generated samples)
            fake_B_ = fake_B_buffer.push_and_pop(fake_B)
            loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
            # Total loss
            loss_D_B = (loss_real + loss_fake) / 2

            loss_D_B.backward()
            optimizer_D_B.step()

            loss_D = (loss_D_A + loss_D_B) / 2

            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + i
            batches_left = opt.n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
                % (
                    epoch,
                    opt.n_epochs,
                    i,
                    len(dataloader),
                    loss_D.item(),
                    loss_G.item(),
                    loss_GAN.item(),
                    loss_cycle.item(),
                    loss_identity.item(),
                    time_left,
                ))

            # If at sample interval save image
            if batches_done % opt.sample_interval == 0:
                sample_images(batches_done)

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(
                G_AB.state_dict(),
                "saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, epoch))
            torch.save(
                G_BA.state_dict(),
                "saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, epoch))
            torch.save(
                D_A.state_dict(),
                "saved_models/%s/D_A_%d.pth" % (opt.dataset_name, epoch))
            torch.save(
                D_B.state_dict(),
                "saved_models/%s/D_B_%d.pth" % (opt.dataset_name, epoch))
Esempio n. 11
0
        BAB = Gnet_AB(BA)
        L_cyc_BAB = L_cyc(BAB, real_B) * opt.lambda_

        #Total loss
        L_G = L_GAN_AB + L_GAN_BA + L_cyc_ABA + L_cyc_BAB + L_identity_AA + L_identity_BB

        L_G.backward()

        optim_G.step()

        #discriminator A
        A_realA = Dnet_A(real_A)
        loss_D_real = L_GAN(A_realA, real_label.expand_as(A_realA))

        BA_from_Buffer = Buffer_A.push_and_pop(BA)
        pred_fake = Dnet_A(BA_from_Buffer.detach())
        loss_D_fake = L_GAN(pred_fake, fake_label.expand_as(pred_fake))

        loss_D_A = (loss_D_real + loss_D_fake) * 0.5

        #discriminator B
        B_realB = Dnet_B(real_B)
        loss_D_real = L_GAN(B_realB, real_label.expand_as(B_realB))

        AB_from_Buffer = Buffer_B.push_and_pop(AB)
        pred_fake = Dnet_B(AB_from_Buffer.detach())
        loss_D_fake = L_GAN(pred_fake, fake_label.expand_as(pred_fake))

        loss_D_B = (loss_D_real + loss_D_fake) * 0.5
class experiment():
    def __init__(self,
                 epoch=0,
                 n_epochs=1000,
                 batchSize=1,
                 lr=0.0002,
                 decay_epoch=100,
                 size=256,
                 input_nc=3,
                 output_nc=3,
                 cuda=True,
                 n_cpu=8,
                 load_from_ckpt=False):

        self.epoch = epoch
        self.n_epochs = n_epochs
        self.batchSize = batchSize
        self.lr = lr
        self.decay_epoch = decay_epoch
        self.size = size
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.cuda = cuda
        self.n_cpu = n_cpu

        rootA = "../dataset/monet_field_data"
        rootB = "../dataset/field_data"

        if torch.cuda.is_available() and not self.cuda:
            print(
                "WARNING: You have a CUDA device, so you should probably set cuda=True"
            )

        ###### Definition of variables ######
        # Networks
        self.netG_A2B = Generator(self.input_nc, self.output_nc)
        self.netG_B2A = Generator(self.output_nc, self.input_nc)
        self.netD_A = Discriminator(self.input_nc)
        self.netD_B = Discriminator(self.output_nc)

        if load_from_ckpt:
            print("loading from ckpt")
            self.netG_A2B.load_state_dict(torch.load('output/netG_A2B.pth'))
            self.netG_B2A.load_state_dict(torch.load('output/netG_B2A.pth'))
            self.netD_A.load_state_dict(torch.load('output/netD_A.pth'))
            self.netD_B.load_state_dict(torch.load('output/netD_B.pth'))
        else:
            self.netG_A2B.apply(weights_init_normal)
            self.netG_B2A.apply(weights_init_normal)
            self.netD_A.apply(weights_init_normal)
            self.netD_B.apply(weights_init_normal)

        if self.cuda:
            self.netG_A2B.cuda()
            self.netG_B2A.cuda()
            self.netD_A.cuda()
            self.netD_B.cuda()

        # Losses
        self.criterion_GAN = torch.nn.MSELoss()
        self.criterion_cycle = torch.nn.L1Loss()
        self.criterion_identity = torch.nn.L1Loss()

        # Optimizers & LR schedulers
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.netG_A2B.parameters(), self.netG_B2A.parameters()),
                                            lr=self.lr,
                                            betas=(0.5, 0.999))
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                              lr=self.lr,
                                              betas=(0.5, 0.999))
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                              lr=self.lr,
                                              betas=(0.5, 0.999))

        self.lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G,
            lr_lambda=LambdaLR(self.n_epochs, self.epoch,
                               self.decay_epoch).step)
        self.lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D_A,
            lr_lambda=LambdaLR(self.n_epochs, self.epoch,
                               self.decay_epoch).step)
        self.lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D_B,
            lr_lambda=LambdaLR(self.n_epochs, self.epoch,
                               self.decay_epoch).step)

        if load_from_ckpt:
            print('load states')
            checkpoint = torch.load('output/states.pth')
            '''
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D_A.load_state_dict(checkpoint['optimizer_D_A'])
            self.optimizer_D_B.load_state_dict(checkpoint['optimizer_D_B'])
            
            self.lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
            self.lr_scheduler_D_A.load_state_dict(checkpoint['lr_scheduler_D_A'])
            self.lr_scheduler_D_B.load_state_dict(checkpoint['lr_scheduler_D_B'])
            '''

            self.lr = checkpoint['lr']
            self.epoch = checkpoint['epoch'] + 1

        # Inputs & targets memory allocation
        Tensor = torch.cuda.FloatTensor if self.cuda else torch.Tensor
        self.input_A = Tensor(self.batchSize, self.input_nc, self.size,
                              self.size)
        self.input_B = Tensor(self.batchSize, self.output_nc, self.size,
                              self.size)
        self.target_real = Variable(Tensor(self.batchSize).fill_(1.0),
                                    requires_grad=False)
        self.target_fake = Variable(Tensor(self.batchSize).fill_(0.0),
                                    requires_grad=False)

        self.fake_A_buffer = ReplayBuffer()
        self.fake_B_buffer = ReplayBuffer()

        # Dataset loader
        transforms_ = [
            transforms.Resize((int(self.size * 1.12), int(self.size * 1.12)),
                              Image.BICUBIC),
            transforms.RandomCrop(self.size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
        self.dataloader = DataLoader(ImageDataset(rootA,
                                                  rootB,
                                                  transforms_=transforms_,
                                                  unaligned=True),
                                     batch_size=self.batchSize,
                                     shuffle=True,
                                     num_workers=self.n_cpu)

        # Loss plot
        #logger = Logger(self.n_epochs, len(dataloader))
        ###################################

    def train(self):
        ###### Training ######
        for epoch in range(self.epoch, self.n_epochs):
            for i, batch in enumerate(self.dataloader):

                # Set model input
                real_A = Variable(self.input_A.copy_(batch['A']))
                real_B = Variable(self.input_B.copy_(batch['B']))

                ###### Generators A2B and B2A ######
                self.optimizer_G.zero_grad()

                # Identity loss
                # G_A2B(B) should equal B if real B is fed
                same_B = self.netG_A2B(real_B)
                loss_identity_B = self.criterion_identity(same_B, real_B) * 5.0
                # G_B2A(A) should equal A if real A is fed
                same_A = self.netG_B2A(real_A)
                loss_identity_A = self.criterion_identity(same_A, real_A) * 5.0

                # GAN loss
                fake_B = self.netG_A2B(real_A)
                pred_fake = self.netD_B(fake_B)
                loss_GAN_A2B = self.criterion_GAN(pred_fake, self.target_real)

                fake_A = self.netG_B2A(real_B)
                pred_fake = self.netD_A(fake_A)
                loss_GAN_B2A = self.criterion_GAN(pred_fake, self.target_real)

                # Cycle loss
                recovered_A = self.netG_B2A(fake_B)
                loss_cycle_ABA = self.criterion_cycle(recovered_A,
                                                      real_A) * 10.0

                recovered_B = self.netG_A2B(fake_A)
                loss_cycle_BAB = self.criterion_cycle(recovered_B,
                                                      real_B) * 10.0

                # Total loss
                loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
                loss_G.backward()

                self.optimizer_G.step()
                ###################################

                ###### Discriminator A ######
                self.optimizer_D_A.zero_grad()

                # Real loss
                pred_real = self.netD_A(real_A)
                loss_D_real = self.criterion_GAN(pred_real, self.target_real)

                # Fake loss
                fake_A = self.fake_A_buffer.push_and_pop(fake_A)
                pred_fake = self.netD_A(fake_A.detach())
                loss_D_fake = self.criterion_GAN(pred_fake, self.target_fake)

                # Total loss
                loss_D_A = (loss_D_real + loss_D_fake) * 0.5
                loss_D_A.backward()

                self.optimizer_D_A.step()
                ###################################

                ###### Discriminator B ######
                self.optimizer_D_B.zero_grad()

                # Real loss
                pred_real = self.netD_B(real_B)
                loss_D_real = self.criterion_GAN(pred_real, self.target_real)

                # Fake loss
                fake_B = self.fake_B_buffer.push_and_pop(fake_B)
                pred_fake = self.netD_B(fake_B.detach())
                loss_D_fake = self.criterion_GAN(pred_fake, self.target_fake)

                # Total loss
                loss_D_B = (loss_D_real + loss_D_fake) * 0.5
                loss_D_B.backward()

                self.optimizer_D_B.step()

                ###################################
                if i % 100 == 0:
                    text = [
                        strftime("%Y-%m-%d %H:%M:%S", gmtime()),
                        "epoch:{} batch_id:{}".format(epoch, i),
                        "loss_g: {:.4f}".format(loss_G),
                        "loss_DA: {:.4f} loss_DB: {:.4f}".format(
                            loss_D_A, loss_D_B)
                    ]
                    with open('logs.txt', 'a') as f:
                        for t in text:
                            print(t)
                            f.write(t + '\n')

                # Progress report (http://localhost:8097)
                '''
                logger.log({'loss_G': loss_G, 'loss_G_identity': (loss_identity_A + loss_identity_B), 'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A),
                            'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB), 'loss_D': (loss_D_A + loss_D_B)}, 
                            images={'real_A': real_A, 'real_B': real_B, 'fake_A': fake_A, 'fake_B': fake_B})
                '''

            # test
            # self.test(self.netG_A2B, self.netG_B2A, epoch)

            # Update learning rates
            self.lr_scheduler_G.step()
            self.lr_scheduler_D_A.step()
            self.lr_scheduler_D_B.step()

            # Save models checkpoints
            torch.save(self.netG_A2B.state_dict(), 'output/netG_A2B.pth')
            torch.save(self.netG_B2A.state_dict(), 'output/netG_B2A.pth')
            torch.save(self.netD_A.state_dict(), 'output/netD_A.pth')
            torch.save(self.netD_B.state_dict(), 'output/netD_B.pth')
            torch.save({'epoch': epoch, 'lr': self.lr}, 'output/states.pth')

        ###################################

    def test(self, rootA, rootB, netG_A2B_path, netG_B2A_path, target_A,
             target_B):
        '''
        rootA = "../dataset/monet_field_data"
        rootB = "../dataset/landscape_test"
        '''

        ###### Definition of variables ######
        # Networks
        netG_A2B = Generator(self.input_nc, self.output_nc)
        netG_B2A = Generator(self.output_nc, self.input_nc)

        if self.cuda:
            netG_A2B.cuda()
            netG_B2A.cuda()

        # Load state dicts
        netG_A2B.load_state_dict(torch.load(netG_A2B_path))
        netG_B2A.load_state_dict(torch.load(netG_B2A_path))

        # Set model's test mode
        netG_A2B.eval()
        netG_B2A.eval()

        # Inputs & targets memory allocation
        Tensor = torch.cuda.FloatTensor if self.cuda else torch.Tensor
        input_A = Tensor(1, self.input_nc, self.size, self.size)
        input_B = Tensor(1, self.output_nc, self.size, self.size)

        # Dataset loader
        transforms_ = [
            transforms.Resize([self.size, self.size], Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
        dataloader = DataLoader(ImageDataset(rootA,
                                             rootB,
                                             transforms_=transforms_,
                                             unaligned=False),
                                batch_size=1,
                                shuffle=False,
                                num_workers=self.n_cpu)

        ###### Testing ######

        # Create output dirs if they don't exist
        if not os.path.exists(target_A):
            os.makedirs(target_A)
        if not os.path.exists(target_B):
            os.makedirs(target_B)

        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            # Save image files
            save_image(0.5 * (real_A + 1.0),
                       target_A + '/real{}.png'.format(i))
            save_image(0.5 * (real_B + 1.0),
                       target_B + '/real{}.png'.format(i))

            # Generate output
            fake_B = 0.5 * (netG_A2B(real_A).data + 1.0)
            fake_A = 0.5 * (netG_B2A(real_B).data + 1.0)

            # Save image files
            save_image(fake_A, target_A + '/fake{}.png'.format(i))
            save_image(fake_B, target_B + '/fake{}.png'.format(i))

            sys.stdout.write('\rGenerated images %04d of %04d' %
                             (i + 1, len(dataloader)))

        sys.stdout.write('\n')
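`ImageDataset` is imported from elsewhere in each of these repositories. A minimal sketch consistent with the `(rootA, rootB, transforms_=..., unaligned=...)` call sites above, assuming each root is a flat folder of images and that `unaligned=True` pairs every A image with a random B image:

import glob
import os
import random

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class ImageDataset(Dataset):
    """Unpaired two-domain dataset: one image folder per domain."""

    def __init__(self, rootA, rootB, transforms_=None, unaligned=False):
        self.transform = transforms.Compose(transforms_)
        self.unaligned = unaligned
        self.files_A = sorted(glob.glob(os.path.join(rootA, '*')))
        self.files_B = sorted(glob.glob(os.path.join(rootB, '*')))

    def __getitem__(self, index):
        item_A = self.transform(
            Image.open(self.files_A[index % len(self.files_A)]).convert('RGB'))
        if self.unaligned:
            # unpaired training: draw a random image from the other domain
            item_B = self.transform(
                Image.open(random.choice(self.files_B)).convert('RGB'))
        else:
            item_B = self.transform(
                Image.open(self.files_B[index % len(self.files_B)]).convert('RGB'))
        return {'A': item_A, 'B': item_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))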
Esempio n. 13
0
def trainer(options):
	startEpoch = options["epoch"]
	nEpochs = options["nEpochs"]
	decayEpoch = options["decayEpoch"]
	assert nEpochs > decayEpoch, "decayEpoch must be smaller than nEpochs, otherwise the learning rate never decays"

	(GEN_AtoB, DIS_B), (GEN_BtoA, DIS_A) = CycleGANmapper(inC, outC, options)

	# LOSSES 
	GANLoss = torch.nn.MSELoss() # ImageA to ImageB distance
	CycleLoss = torch.nn.L1Loss() # ImageA->ImageB->ImageA' distance
	IdentityLoss  = torch.nn.L1Loss() # Absolute loss

	optim_GAN = torch.optim.Adam(list(GEN_AtoB.parameters())+list(GEN_BtoA.parameters()), lr=options["learningrate"], betas=[0.5, 0.999])
	optim_DIS_A = torch.optim.Adam(DIS_A.parameters(), lr=options["learningrate"], betas=[0.5, 0.999])
	optim_DIS_B = torch.optim.Adam(DIS_B.parameters(), lr=options["learningrate"], betas=[0.5, 0.999])

	# Linear LR decay: factor 1.0 until decayEpoch, then ramp down to 0 at nEpochs
	lr_scheduler = lambda ep: 1.0 - max(0, ep + startEpoch - decayEpoch)/(nEpochs - decayEpoch)

	lr_scheduler_GAN = torch.optim.lr_scheduler.LambdaLR(optim_GAN, lr_lambda=lr_scheduler)
	lr_scheduler_Dis_A = torch.optim.lr_scheduler.LambdaLR(optim_DIS_A, lr_lambda=lr_scheduler)
	lr_scheduler_Dis_B = torch.optim.lr_scheduler.LambdaLR(optim_DIS_B, lr_lambda=lr_scheduler)

	# Tensors Memory Allocation
	batchsize = options["batchsize"]
	if options["cuda"]:
		Tensor = torch.cuda.FloatTensor
	else:
		Tensor = torch.Tensor
	imageA = Tensor(batchsize, inC, inH, inW)
	imageB = Tensor(batchsize, outC, outH, outW)
	targetReal = Variable(Tensor(batchsize).fill_(1.0), requires_grad=False)
	targetFake = Variable(Tensor(batchsize).fill_(0.0), requires_grad=False)

	FakeAHolder = ReplayBuffer() # replay buffer of previously generated fakes (see the sketch above)
	FakeBHolder = ReplayBuffer()

	dataloader = LoadData(options["datapath"])
	logger = Logger(nEpochs, len(dataloader))

	#Actual Training
	for epoch in range(startEpoch, nEpochs):
		for batch_id, batch_data in enumerate(dataloader):
			realA = Variable(imageA.copy_(batch_data['A']))
			realB = Variable(imageB.copy_(batch_data['B']))

			# Generator GEN_AtoB mapping from A to B and GEN_BtoA mapping from B to A
			optim_GAN.zero_grad()

			#Identity Loss: GEN_AtoB(realB) should generate B
			synthB = GEN_AtoB(realB)
			lossIden_BtoB = IdentityLoss(synthB, realB)*5.0

			#Identity Loss: GEN_BtoA(realA) should generate A
			synthA = GEN_BtoA(realA)
			lossIden_AtoA = IdentityLoss(synthA, realA)*5.0

			#GAN Loss: DIS_B(GEN_AtoB(realA)) should be closest to real target.
			fakeB = GEN_AtoB(realA)
			classB = DIS_B(fakeB)
			lossGEN_A2B = GANLoss(classB, targetReal)

			#GAN Loss: DIS_A(GEN_BtoA(realB)) should be closest to real target.
			fakeA = GEN_BtoA(realB)
			classA = DIS_A(fakeA)
			lossGEN_B2A = GANLoss(classA, targetReal)

			#Cycle Reconstruction: GEN_BtoA(GEN_AtoB(realA)) should give back realA
			reconA = GEN_BtoA(fakeB)
			lossCycle_ABA = CycleLoss(reconA, realA)*10.0

			#Cycle Reconstruction: GEN_AtoB(GEN_BtoA(realB)) should give back realB
			reconB = GEN_AtoB(fakeA)
			lossCycle_BAB = CycleLoss(reconB, realB)*10.0

			# Total generator loss: GAN, cycle-consistency and identity terms
			lossTotal = lossCycle_BAB + lossCycle_ABA + lossGEN_B2A + lossGEN_A2B + lossIden_AtoA + lossIden_BtoB
			lossTotal.backward()

			optim_GAN.step()

			# Discriminator A update
			optim_DIS_A.zero_grad()

			#Real Loss: when realA is fed in, the discriminator should predict 1.0
			classAreal = DIS_A(realA)
			lossDreal = GANLoss(classAreal, targetReal)

			#Fake Loss: when a fakeA is fed in, the discriminator should predict 0.0
			fakeA = FakeAHolder.push_and_pop(fakeA) # sample from the replay buffer of past fakes
			classAfake = DIS_A(fakeA.detach())
			lossDfake = GANLoss(classAfake, targetFake)

			# The discriminator should perform equally well on both tasks
			lossDA = (lossDreal + lossDfake)/2

			lossDA.backward()
			optim_DIS_A.step()

			# Discriminator B update
			optim_DIS_B.zero_grad()

			#Real Loss: when realB is fed in, the discriminator should predict 1.0
			classBreal = DIS_B(realB)
			lossDreal = GANLoss(classBreal, targetReal)

			#Fake Loss: when a fakeB is fed in, the discriminator should predict 0.0
			fakeB = FakeBHolder.push_and_pop(fakeB) # sample from the replay buffer of past fakes
			classBfake = DIS_B(fakeB.detach())
			lossDfake = GANLoss(classBfake, targetFake)

			# The discriminator should perform equally well on both tasks
			lossDB = (lossDreal + lossDfake)/2

			lossDB.backward()
			optim_DIS_B.step()
			#Progress
			logger.log({'lossTotal': lossTotal, 'lossIdentity': (lossIden_AtoA + lossIden_BtoB), 'lossGAN': (lossGEN_A2B + lossGEN_B2A),
					'lossCycle': (lossCycle_ABA + lossCycle_BAB), 'lossD': (lossDA + lossDB)}, 
					images={'realA': realA, 'realB': realB, 'fakeA': fakeA, 'fakeB': fakeB})

		#Update learning rates
		lr_scheduler_GAN.step()
		lr_scheduler_Dis_A.step()
		lr_scheduler_Dis_B.step()

		# Save Models checkpoints
		SaveModels(GEN_AtoB, DIS_B, GEN_BtoA, DIS_A)
		with open("EpochVerify.txt",'w') as ff:
			ff.write("\n"+str(epoch))
Esempio n. 14
0
def train(config):
    ## set pre-process
    prep_dict = {}
    prep_config = config["prep"]
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    if prep_config["test_10crop"]:
        prep_dict["test"] = prep.image_test_10crop(**config["prep"]['params'])
    else:
        prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    dsets["source"] = ImageList(open(data_config["source"]["list_path"]).readlines(), \
                                transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=True)
    dsets["target"] = ImageList(open(data_config["target"]["list_path"]).readlines(), \
                                transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, \
                                        shuffle=True, num_workers=0, drop_last=True)

    if prep_config["test_10crop"]:
        for i in range(10):
            dsets["test"] = [ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                       transform=prep_dict["test"][i]) for i in range(10)]
            dset_loaders["test"] = [DataLoader(dset, batch_size=test_bs, \
                                               shuffle=False, num_workers=0) for dset in dsets['test']]
    else:
        dsets["test"] = ImageList(open(data_config["test"]["list_path"]).readlines(), \
                                  transform=prep_dict["test"])
        dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs, \
                                          shuffle=False, num_workers=0)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    # base_network = base_network.cuda()

    ## Add discriminators D_s, D_t and generators G_s2t, G_t2s

    z_dimension = 256
    D_s = network.models["Discriminator"]()
    # D_s = D_s.cuda()
    G_s2t = network.models["Generator"](z_dimension, 1024)
    # G_s2t = G_s2t.cuda()

    D_t = network.models["Discriminator"]()
    # D_t = D_t.cuda()
    G_t2s = network.models["Generator"](z_dimension, 1024)
    # G_t2s = G_t2s.cuda()

    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()
    criterion_Sem = torch.nn.L1Loss()

    optimizer_G = torch.optim.Adam(itertools.chain(G_s2t.parameters(), G_t2s.parameters()), lr=0.0003)
    optimizer_D_s = torch.optim.Adam(D_s.parameters(), lr=0.0003)
    optimizer_D_t = torch.optim.Adam(D_t.parameters(), lr=0.0003)

    fake_S_buffer = ReplayBuffer()
    fake_T_buffer = ReplayBuffer()

    classifier_optimizer = torch.optim.Adam(base_network.parameters(), lr=0.0003)
    ## Add the auxiliary classifier
    classifier1 = net.Net(256,class_num)
    # classifier1 = classifier1.cuda()
    classifier1_optim = optim.Adam(classifier1.parameters(), lr=0.0003)

    ## add additional network for some methods
    if config["loss"]["random"]:
        random_layer = network.RandomLayer([base_network.output_num(), class_num], config["loss"]["random_dim"])
        ad_net = network.AdversarialNetwork(config["loss"]["random_dim"], 1024)
    else:
        random_layer = None
        ad_net = network.AdversarialNetwork(base_network.output_num() * class_num, 1024)
    if config["loss"]["random"]:
        random_layer.cuda()
    # ad_net = ad_net.cuda()
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list, \
                                         **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
        base_network = nn.DataParallel(base_network, device_ids=[int(i) for i in gpus])

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    for i in range(config["num_iterations"]):
        if i % config["test_interval"] == config["test_interval"] - 1:
            base_network.train(False)
            temp_acc = image_classification_test(dset_loaders, \
                                                 base_network, test_10crop=prep_config["test_10crop"])
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = temp_model

                now = datetime.datetime.now()
                d = str(now.month) + '-' + str(now.day) + ' ' + str(now.hour) + ':' + str(now.minute) + ":" + str(
                    now.second)
                torch.save(best_model, osp.join(config["output_path"],
                                                "{}_to_{}_best_model_acc-{}_{}.pth.tar".format(args.source, args.target,
                                                                                               best_acc, d)))
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()

            print(log_str)
        if i % config["snapshot_interval"] == 0:
            torch.save(nn.Sequential(base_network), osp.join(config["output_path"], \
                                                             "{}_to_{}_iter_{:05d}_model_{}.pth.tar".format(args.source,
                                                                                                            args.target,
                                                                                                            i, str(
                                                                     datetime.datetime.utcnow()))))
        print("it_train: {:05d} / {:05d} start".format(i, config["num_iterations"]))
        loss_params = config["loss"]
        ## train one iter
        classifier1.train(True)
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()


        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        # inputs_source, inputs_target, labels_source = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()

        # Extract features with the shared base network
        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)

        outputs_source1 = classifier1(features_source.detach())
        outputs_target1 = classifier1(features_target.detach())
        outputs1 = torch.cat((outputs_source1,outputs_target1),dim=0)
        softmax_out1 = nn.Softmax(dim=1)(outputs1)

        softmax_out = (1-args.cla_plus_weight)*softmax_out + args.cla_plus_weight*softmax_out1

        if config['method'] == 'CDAN+E':
            entropy = loss.Entropy(softmax_out)
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy, network.calc_coeff(i), random_layer)
        elif config['method'] == 'CDAN':
            transfer_loss = loss.CDAN([features, softmax_out], ad_net, None, None, random_layer)
        elif config['method'] == 'DANN':
            transfer_loss = loss.DANN(features, ad_net)
        else:
            raise ValueError('Method cannot be recognized.')
        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)

        # Cycle
        num_feature = features_source.size(0)
        # =================train discriminator T
        real_label = Variable(torch.ones(num_feature))
        # real_label = Variable(torch.ones(num_feature)).cuda()
        fake_label = Variable(torch.zeros(num_feature))
        # fake_label = Variable(torch.zeros(num_feature)).cuda()

        # Train the generators
        optimizer_G.zero_grad()

        # Identity loss
        same_t = G_s2t(features_target.detach())
        loss_identity_t = criterion_identity(same_t, features_target)

        same_s = G_t2s(features_source.detach())
        loss_identity_s = criterion_identity(same_s, features_source)

        # GAN loss: each generator tries to make its discriminator predict "real"
        fake_t = G_s2t(features_source.detach())
        pred_fake = D_t(fake_t)
        loss_G_s2t = criterion_GAN(pred_fake, real_label)

        fake_s = G_t2s(features_target.detach())
        pred_fake = D_s(fake_s)
        loss_G_t2s = criterion_GAN(pred_fake, real_label)

        # cycle loss
        recovered_s = G_t2s(fake_t)
        loss_cycle_sts = criterion_cycle(recovered_s, features_source)

        recovered_t = G_s2t(fake_s)
        loss_cycle_tst = criterion_cycle(recovered_t, features_target)

        # sem loss
        pred_recovered_s = base_network.fc(recovered_s)
        pred_fake_t = base_network.fc(fake_t)
        loss_sem_t2s = criterion_Sem(pred_recovered_s, pred_fake_t)

        pred_recovered_t = base_network.fc(recovered_t)
        pred_fake_s = base_network.fc(fake_s)
        loss_sem_s2t = criterion_Sem(pred_recovered_t, pred_fake_s)

        loss_cycle = loss_cycle_tst + loss_cycle_sts
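        # weight_in_lossG is parsed below as a comma-separated string of four
        # floats ordered (identity, GAN, cycle, semantic), e.g. "5,1,10,1";
        # the example values are illustrative, only the format follows from split(',')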
        weights = args.weight_in_lossG.split(',')
        loss_G = float(weights[0]) * (loss_identity_s + loss_identity_t) + \
                 float(weights[1]) * (loss_G_s2t + loss_G_t2s) + \
                 float(weights[2]) * loss_cycle + \
                 float(weights[3]) * (loss_sem_s2t + loss_sem_t2s)



        # Train the softmax classifier on generated target features
        outputs_fake = classifier1(fake_t.detach())
        # Classifier optimization step
        classifier_loss1 = nn.CrossEntropyLoss()(outputs_fake, labels_source)
        classifier1_optim.zero_grad()
        classifier_loss1.backward()
        classifier1_optim.step()

        total_loss = loss_params["trade_off"] * transfer_loss + classifier_loss + args.cyc_loss_weight*loss_G
        total_loss.backward()
        optimizer.step()
        optimizer_G.step()

        ###### Discriminator S ######
        optimizer_D_s.zero_grad()

        # Real loss
        pred_real = D_s(features_source.detach())
        loss_D_real = criterion_GAN(pred_real, real_label)

        # Fake loss
        fake_s = fake_S_buffer.push_and_pop(fake_s)
        pred_fake = D_s(fake_s.detach())
        loss_D_fake = criterion_GAN(pred_fake, fake_label)

        # Total loss
        loss_D_s = loss_D_real + loss_D_fake
        loss_D_s.backward()

        optimizer_D_s.step()
        ###################################

        ###### Discriminator t ######
        optimizer_D_t.zero_grad()

        # Real loss
        pred_real = D_t(features_target.detach())
        loss_D_real = criterion_GAN(pred_real, real_label)

        # Fake loss
        fake_t = fake_T_buffer.push_and_pop(fake_t)
        pred_fake = D_t(fake_t.detach())
        loss_D_fake = criterion_GAN(pred_fake, fake_label)

        # Total loss
        loss_D_t = loss_D_real + loss_D_fake
        loss_D_t.backward()
        optimizer_D_t.step()
        print("it_train: {:05d} / {:05d} over".format(i, config["num_iterations"]))
    now = datetime.datetime.now()
    d = str(now.month)+'-'+str(now.day)+' '+str(now.hour)+':'+str(now.minute)+":"+str(now.second)
    torch.save(best_model, osp.join(config["output_path"],
                                    "{}_to_{}_best_model_acc-{}_{}.pth.tar".format(args.source, args.target,
                                                                            best_acc,d)))
    return best_acc
def main(args):
    writer = SummaryWriter(os.path.join(args.out_dir, 'logs'))
    current_time = datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
    os.makedirs(
        os.path.join(args.out_dir, 'models',
                     args.model_name + '_' + current_time))
    os.makedirs(
        os.path.join(args.out_dir, 'logs',
                     args.model_name + '_' + current_time))

    G_AB = Generator(args.in_channel, args.out_channel).to(args.device)
    G_BA = Generator(args.in_channel, args.out_channel).to(args.device)
    D_A = Discriminator(args.in_channel).to(args.device)
    D_B = Discriminator(args.out_channel).to(args.device)
    segmen_B = Unet(3, 34).to(args.device)

    if args.model_path is not None:
        AB_path = os.path.join(args.model_path, 'ab.pt')
        BA_path = os.path.join(args.model_path, 'ba.pt')
        DA_path = os.path.join(args.model_path, 'da.pt')
        DB_path = os.path.join(args.model_path, 'db.pt')
        segmen_path = os.path.join(args.model_path, 'semsg.pt')

        with open(AB_path, 'rb') as f:
            state_dict = torch.load(f)
            G_AB.load_state_dict(state_dict)

        with open(BA_path, 'rb') as f:
            state_dict = torch.load(f)
            G_BA.load_state_dict(state_dict)

        with open(DA_path, 'rb') as f:
            state_dict = torch.load(f)
            D_A.load_state_dict(state_dict)

        with open(DB_path, 'rb') as f:
            state_dict = torch.load(f)
            D_B.load_state_dict(state_dict)

        with open(segmen_path, 'rb') as f:
            state_dict = torch.load(f)
            segmen_B.load_state_dict(state_dict)

    else:
        G_AB.apply(weights_init_normal)
        G_BA.apply(weights_init_normal)
        D_A.apply(weights_init_normal)
        D_B.apply(weights_init_normal)

    G_AB = nn.DataParallel(G_AB)
    G_BA = nn.DataParallel(G_BA)
    D_A = nn.DataParallel(D_A)
    D_B = nn.DataParallel(D_B)
    segmen_B = nn.DataParallel(segmen_B)

    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()
    criterion_segmen = torch.nn.BCELoss()

    optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(),
                                                   G_BA.parameters()),
                                   lr=args.lr,
                                   betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(D_A.parameters(),
                                     lr=args.lr,
                                     betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(D_B.parameters(),
                                     lr=args.lr,
                                     betas=(0.5, 0.999))

    optimizer_segmen_B = torch.optim.Adam(segmen_B.parameters(),
                                          lr=args.lr,
                                          betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(args.n_epochs, args.epoch, args.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=LambdaLR(args.n_epochs, args.epoch, args.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=LambdaLR(args.n_epochs, args.epoch, args.decay_epoch).step)

    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    transforms_ = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    dataloader = DataLoader(ImgDataset(args.dataset_path,
                                       transforms_=transforms_,
                                       unaligned=True,
                                       device=args.device),
                            batch_size=args.batchSize,
                            shuffle=True,
                            num_workers=0)
    logger = Logger(args.n_epochs, len(dataloader))
    target_real = Variable(torch.Tensor(args.batchSize,
                                        1).fill_(1.)).to(args.device).detach()
    target_fake = Variable(torch.Tensor(args.batchSize,
                                        1).fill_(0.)).to(args.device).detach()

    G_AB.train()
    G_BA.train()
    D_A.train()
    D_B.train()
    segmen_B.train()

    for epoch in range(args.epoch, args.n_epochs):
        for i, batch in enumerate(dataloader):
            real_A = batch['A'].clone()
            real_B = batch['B'].clone()
            B_label = batch['B_label'].clone()

            fake_b = G_AB(real_A)
            fake_a = G_BA(real_B)
            same_b = G_AB(real_B)
            same_a = G_BA(real_A)
            recovered_A = G_BA(fake_b)
            recovered_B = G_AB(fake_a)
            pred_Blabel = segmen_B(real_B)
            pred_fakeAlabel = segmen_B(fake_a)

            optimizer_segmen_B.zero_grad()
            #segmentation loss: supervise segmen_B on real B and on fake A, assuming the segmenter should also work after domain transfer
            loss_segmen_B = criterion_segmen(
                pred_Blabel, B_label) + criterion_segmen(
                    segmen_B(fake_a.detach()), B_label)
            loss_segmen_B.backward()
            optimizer_segmen_B.step()

            optimizer_G.zero_grad()
            #gan loss
            pred_fakeb = D_B(fake_b)
            loss_gan_AB = criterion_GAN(pred_fakeb, target_real)

            pred_fakea = D_A(fake_a)
            loss_gan_BA = criterion_GAN(pred_fakea, target_real)

            #identity loss
            loss_identity_B = criterion_identity(same_b, real_B) * 5
            loss_identity_A = criterion_identity(same_a, real_A) * 5

            #cycle consistency loss
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10

            #cycle segmen diff loss
            loss_segmen_diff = criterion_segmen(segmen_B(recovered_B),
                                                pred_Blabel.detach())

            loss_G = loss_gan_AB + loss_gan_BA + loss_identity_B + loss_identity_A + loss_cycle_ABA + loss_cycle_BAB + loss_segmen_diff
            loss_G.backward()

            optimizer_G.step()

            ##discriminator a
            optimizer_D_A.zero_grad()

            pred_realA = D_A(real_A)
            loss_D_A_real = criterion_GAN(pred_realA, target_real)

            fake_A = fake_A_buffer.push_and_pop(fake_a)
            pred_fakeA = D_A(fake_A.detach())
            loss_D_A_fake = criterion_GAN(pred_fakeA, target_fake)

            loss_D_A = (loss_D_A_real + loss_D_A_fake) * 0.5
            loss_D_A.backward()

            optimizer_D_A.step()

            #discriminator b
            optimizer_D_B.zero_grad()

            pred_realB = D_B(real_B)
            loss_D_B_real = criterion_GAN(pred_realB, target_real)

            fake_B = fake_B_buffer.push_and_pop(fake_b)
            pred_fakeB = D_B(fake_B.detach())
            loss_D_B_fake = criterion_GAN(pred_fakeB, target_fake)

            loss_D_B = (loss_D_B_real + loss_D_B_fake) * 0.5
            loss_D_B.backward()

            optimizer_D_B.step()

            logger.log(
                {
                    'loss_segmen_B': loss_segmen_B,
                    'loss_G': loss_G,
                    'loss_G_identity': (loss_identity_A + loss_identity_B),
                    'loss_G_GAN': (loss_gan_AB + loss_gan_BA),
                    'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB),
                    'loss_D': (loss_D_A + loss_D_B)
                },
                images={
                    'real_A': real_A,
                    'real_B': real_B,
                    'fake_A': fake_a,
                    'fake_B': fake_b,
                    'reconstructed_A': recovered_A,
                    'reconstructed_B': recovered_B
                },
                out_dir=os.path.join(
                    args.out_dir, 'logs',
                    args.model_name + '_' + current_time + '/' + str(epoch)),
                writer=writer)

        if (epoch + 1) % args.save_per_epochs == 0:
            os.makedirs(
                os.path.join(args.out_dir, 'models',
                             args.model_name + '_' + current_time, str(epoch)))
            torch.save(
                G_AB.module.state_dict(),
                os.path.join(args.out_dir,
                             'models', args.model_name + '_' + current_time,
                             str(epoch), 'ab.pt'))
            torch.save(
                G_BA.module.state_dict(),
                os.path.join(args.out_dir,
                             'models', args.model_name + '_' + current_time,
                             str(epoch), 'ba.pt'))
            torch.save(
                D_A.module.state_dict(),
                os.path.join(args.out_dir,
                             'models', args.model_name + '_' + current_time,
                             str(epoch), 'da.pt'))
            torch.save(
                D_B.module.state_dict(),
                os.path.join(args.out_dir,
                             'models', args.model_name + '_' + current_time,
                             str(epoch), 'db.pt'))
            torch.save(
                segmen_B.module.state_dict(),
                os.path.join(args.out_dir,
                             'models', args.model_name + '_' + current_time,
                             str(epoch), 'semsg.pt'))

        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()
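`weights_init_normal` is applied to freshly built networks in most of these examples but never defined on this page. The usual DCGAN-style initialiser they presumably share looks like this (a sketch, not taken from any of the repositories above):

import torch.nn as nn

def weights_init_normal(m):
    """Conv weights ~ N(0, 0.02); norm-layer gains ~ N(1, 0.02), biases 0."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)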
Esempio n. 16
0
                    args.lambda_cyc * loss_cycle + \
                    args.lambda_id * loss_identity

        loss_G.backward()
        optimizer_G.step()

        # -----------------------
        #  Train Discriminator A
        # -----------------------
        optimizer_D_A.zero_grad()

        # Real loss
        pred_real = D__A(real_X_A)
        loss_real = criterion_GAN(pred_real, valid)
        # Fake loss (on batch of previously generated samples)
        fake_Y_A_ = fake_Y_A_buffer.push_and_pop(fake_Y_A)
        pred_fake = D__A(fake_Y_A_.detach())
        loss_fake = criterion_GAN(pred_fake, fake)
        # Total loss
        loss_D_A = (loss_real + loss_fake) / 2

        loss_D_A.backward()
        optimizer_D_A.step()

        # -----------------------
        #  Train Discriminator B
        # -----------------------
        optimizer_D_B.zero_grad()

        # Real loss
        pred_real = D__B(real_Y_B)
Esempio n. 17
0
def main(args):
    torch.manual_seed(0)
    if args.mb_D:
        assert args.batch_size > 1, 'batch size needs to be larger than 1 if mb_D'
        raise NotImplementedError('mb_D not implemented')

    if args.img_norm != 'znorm':
        raise NotImplementedError('{} not implemented'.format(args.img_norm))

    assert args.act in ['relu', 'mish'], 'args.act = {}'.format(args.act)

    modelarch = 'C_{0}_{1}_{2}_{3}_{4}{5}{6}{7}{8}{9}{10}{11}{12}{13}{14}{15}{16}{17}{18}{19}{20}{21}{22}'.format(
        args.size, args.batch_size, args.lr,  args.n_epochs, args.decay_epoch, # 0, 1, 2, 3, 4
        '_G' if args.G_extra else '',  # 5
        '_D' if args.D_extra else '',  # 6
        '_U' if args.upsample else '',  # 7
        '_S' if args.slow_D else '',  # 8
        '_RL{}-{}'.format(args.start_recon_loss_val, args.end_recon_loss_val),  # 9
        '_GL{}-{}'.format(args.start_gan_loss_val, args.end_gan_loss_val),  # 10
        '_prop' if args.keep_prop else '',  # 11
        '_' + args.img_norm,  # 12
        '_WL' if args.wasserstein else '',  # 13
        '_MBD' if args.mb_D else '',  # 14
        '_FM' if args.fm_loss else '',  # 15
        '_BF{}'.format(args.buffer_size) if args.buffer_size != 50 else '',  # 16
        '_N' if args.add_noise else '',  # 17
        '_L{}'.format(args.load_iter) if args.load_iter > 0 else '',  # 18
        '_res{}'.format(args.n_resnet_blocks),  # 19
        '_n{}'.format(args.data_subset) if args.data_subset is not None else '',  # 20
        '_{}'.format(args.optim),  # 21
        '_{}'.format(args.act))  # 22

    samples_path = os.path.join(args.output_dir, modelarch, 'samples')
    safe_mkdirs(samples_path)
    model_path = os.path.join(args.output_dir, modelarch, 'models')
    safe_mkdirs(model_path)
    test_path = os.path.join(args.output_dir, modelarch, 'test')
    safe_mkdirs(test_path)

    # Definition of variables ######
    # Networks
    netG_A2B = Generator(args.input_nc, args.output_nc, img_size=args.size,
                         extra_layer=args.G_extra, upsample=args.upsample,
                         keep_weights_proportional=args.keep_prop,
                         n_residual_blocks=args.n_resnet_blocks,
                         act=args.act)
    netG_B2A = Generator(args.output_nc, args.input_nc, img_size=args.size,
                         extra_layer=args.G_extra, upsample=args.upsample,
                         keep_weights_proportional=args.keep_prop,
                         n_residual_blocks=args.n_resnet_blocks,
                         act=args.act)
    netD_A = Discriminator(args.input_nc, extra_layer=args.D_extra, mb_D=args.mb_D, x_size=args.size)
    netD_B = Discriminator(args.output_nc, extra_layer=args.D_extra, mb_D=args.mb_D, x_size=args.size)

    if args.cuda:
        netG_A2B.cuda()
        netG_B2A.cuda()
        netD_A.cuda()
        netD_B.cuda()

    if args.load_iter == 0:
        netG_A2B.apply(weights_init_normal)
        netG_B2A.apply(weights_init_normal)
        netD_A.apply(weights_init_normal)
        netD_B.apply(weights_init_normal)
    else:
        netG_A2B.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'G_A2B_{}.pth'.format(args.load_iter))))
        netG_B2A.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'G_B2A_{}.pth'.format(args.load_iter))))
        netD_A.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'D_A_{}.pth'.format(args.load_iter))))
        netD_B.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'D_B_{}.pth'.format(args.load_iter))))

        netG_A2B.train()
        netG_B2A.train()
        netD_A.train()
        netD_B.train()

    # Losses
    criterion_GAN = wasserstein_loss if args.wasserstein else torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()
    feat_criterion = torch.nn.HingeEmbeddingLoss()

    # I could also update D only if iters % 2 == 0
    lr_G = args.lr
    lr_D = args.lr / 2 if args.slow_D else args.lr

    # Optimizers & LR schedulers
    if args.optim == 'adam':
        optim = torch.optim.Adam
    elif args.optim == 'radam':
        optim = RAdam
    elif args.optim == 'ranger':
        optim = Ranger
    elif args.optim == 'rangerlars':
        optim = RangerLars
    else:
        raise NotImplementedError('args.optim = {} not implemented'.format(args.optim))

    optimizer_G = optim(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                        lr=lr_G, betas=(0.5, 0.999))
    optimizer_D_A = optim(netD_A.parameters(), lr=lr_D, betas=(0.5, 0.999))
    optimizer_D_B = optim(netD_B.parameters(), lr=lr_D, betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(args.n_epochs, args.load_iter, args.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(args.n_epochs, args.load_iter, args.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(args.n_epochs, args.load_iter, args.decay_epoch).step)

    # Inputs & targets memory allocation
    Tensor = torch.cuda.FloatTensor if args.cuda else torch.Tensor
    input_A = Tensor(args.batch_size, args.input_nc, args.size, args.size)
    input_B = Tensor(args.batch_size, args.output_nc, args.size, args.size)
    target_real = Variable(Tensor(args.batch_size).fill_(1.0), requires_grad=False)
    target_fake = Variable(Tensor(args.batch_size).fill_(0.0), requires_grad=False)

    fake_A_buffer = ReplayBuffer(args.buffer_size)
    fake_B_buffer = ReplayBuffer(args.buffer_size)

    # Transforms and dataloader for training set
    transforms_ = []
    if args.resize_crop:
        transforms_ += [transforms.Resize(int(args.size*1.12), Image.BICUBIC),
                        transforms.RandomCrop(args.size)]
    else:
        transforms_ += [transforms.Resize(args.size, Image.BICUBIC)]

    if args.horizontal_flip:
        transforms_ += [transforms.RandomHorizontalFlip()]

    transforms_ += [transforms.ToTensor()]

    if args.add_noise:
        transforms_ += [transforms.Lambda(lambda x: x + torch.randn_like(x))]

    transforms_norm = []
    if args.img_norm == 'znorm':
        transforms_norm += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    elif 'scale01' in args.img_norm:
        transforms_norm += [transforms.Lambda(lambda x: x.mul(1/255))]  # .mul is element-wise, so the tensor shape is preserved
        if 'flip' in args.img_norm:
            transforms_norm += [transforms.Lambda(lambda x: (x - 1).abs())]  # element-wise intensity flip within [0, 1]
    else:
        raise ValueError('wrong --img_norm. only znorm|scale01|scale01flip')

    transforms_ += transforms_norm

    dataloader = DataLoader(ImageDataset(args.dataroot, transforms_=transforms_, unaligned=True, n=args.data_subset),
                            batch_size=args.batch_size, shuffle=True, num_workers=args.n_cpu)

    # Transforms and dataloader for test set
    transforms_test_ = [transforms.Resize(args.size, Image.BICUBIC),
                        transforms.ToTensor()]
    transforms_test_ += transforms_norm

    dataloader_test = DataLoader(ImageDataset(args.dataroot, transforms_=transforms_test_, mode='test'),
                                 batch_size=args.batch_size, shuffle=False, num_workers=args.n_cpu)
    # Training ######
    if args.load_iter == 0 and args.load_epoch != 0:
        print('****** NOTE: args.load_iter == 0 and args.load_epoch != 0 ******')

    iter = args.load_iter
    prev_time = time.time()
    n_test = 10e10 if args.n_test is None else args.n_test
    n_sample = 10e10 if args.n_sample is None else args.n_sample

    rl_delta_x = args.n_epochs - args.recon_loss_epoch
    rl_delta_y = args.end_recon_loss_val - args.start_recon_loss_val

    gan_delta_x = args.n_epochs - args.gan_loss_epoch
    gan_delta_y = args.end_gan_loss_val - args.start_gan_loss_val

    for epoch in range(args.load_epoch, args.n_epochs):

        rl_effective_epoch = max(epoch - args.recon_loss_epoch, 0)
        recon_loss_rate = args.start_recon_loss_val + rl_effective_epoch * (rl_delta_y / rl_delta_x)

        gan_effective_epoch = max(epoch - args.gan_loss_epoch, 0)
        gan_loss_rate = args.start_gan_loss_val + gan_effective_epoch * (gan_delta_y / gan_delta_x)
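        # Linear ramp: each weight keeps its start value until the matching
        # *_loss_epoch, then moves linearly to its end value at n_epochs.
        # E.g. start=10, end=1, recon_loss_epoch=0, n_epochs=200 gives a
        # cycle weight of 5.5 at epoch 100 (illustrative values only).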

        id_loss_rate = 5.0

        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            # Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B)
            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A)

            # GAN loss
            fake_B = netG_A2B(real_A)
            pred_fake, _ = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)

            fake_A = netG_B2A(real_B)
            pred_fake, _ = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

            # Cycle loss
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A)

            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B)

            # Total loss
            loss_G = (loss_identity_A + loss_identity_B) * id_loss_rate
            loss_G += (loss_GAN_A2B + loss_GAN_B2A) * gan_loss_rate
            loss_G += (loss_cycle_ABA + loss_cycle_BAB) * recon_loss_rate

            loss_G.backward()

            optimizer_G.step()

            # Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real, _ = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake, _ = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            loss_D_A = (loss_D_real + loss_D_fake) * 0.5

            if args.fm_loss:
                pred_real, feats_real = netD_A(real_A)
                pred_fake, feats_fake = netD_A(fake_A.detach())

                fm_loss_A = get_fm_loss(feats_real, feats_fake, feat_criterion, args.cuda)

                loss_D_A = loss_D_A * 0.1 + fm_loss_A * 0.9

            loss_D_A.backward()

            optimizer_D_A.step()

            # Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real, _ = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake, _ = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            loss_D_B = (loss_D_real + loss_D_fake)*0.5

            if args.fm_loss:
                pred_real, feats_real = netD_B(real_B)
                pred_fake, feats_fake = netD_B(fake_B.detach())

                fm_loss_B = get_fm_loss(feats_real, feats_fake, feat_criterion, args.cuda)

                loss_D_B = loss_D_B * 0.1 + fm_loss_B * 0.9

            loss_D_B.backward()

            optimizer_D_B.step()

            if iter % args.log_interval == 0:

                print('---------------------')
                print('GAN loss:', as_np(loss_GAN_A2B), as_np(loss_GAN_B2A))
                print('Identity loss:', as_np(loss_identity_A), as_np(loss_identity_B))
                print('Cycle loss:', as_np(loss_cycle_ABA), as_np(loss_cycle_BAB))
                print('D loss:', as_np(loss_D_A), as_np(loss_D_B))
                if args.fm_loss:
                    print('fm loss:', as_np(fm_loss_A), as_np(fm_loss_B))
                print('recon loss rate:', recon_loss_rate)
                print('time:', time.time() - prev_time)
                prev_time = time.time()

            if iter % args.plot_interval == 0:
                pass

            if iter % args.image_save_interval == 0:
                samples_path_ = os.path.join(samples_path, str(iter // args.image_save_interval))
                safe_mkdirs(samples_path_)

                # New savedir
                test_pth_AB = os.path.join(test_path, str(iter // args.image_save_interval), 'AB')
                test_pth_BA = os.path.join(test_path, str(iter // args.image_save_interval), 'BA')

                safe_mkdirs(test_pth_AB)
                safe_mkdirs(test_pth_BA)

                for j, batch_ in enumerate(dataloader_test):

                    real_A_test = Variable(input_A.copy_(batch_['A']))
                    real_B_test = Variable(input_B.copy_(batch_['B']))

                    fake_AB_test = netG_A2B(real_A_test)
                    fake_BA_test = netG_B2A(real_B_test)

                    if j < n_sample:
                        recovered_ABA_test = netG_B2A(fake_AB_test)
                        recovered_BAB_test = netG_A2B(fake_BA_test)

                        fn = os.path.join(samples_path_, str(j))
                        imageio.imwrite(fn + '.A.jpg', tensor2image(real_A_test[0], args.img_norm))
                        imageio.imwrite(fn + '.B.jpg', tensor2image(real_B_test[0], args.img_norm))
                        imageio.imwrite(fn + '.BA.jpg', tensor2image(fake_BA_test[0], args.img_norm))
                        imageio.imwrite(fn + '.AB.jpg', tensor2image(fake_AB_test[0], args.img_norm))
                        imageio.imwrite(fn + '.ABA.jpg', tensor2image(recovered_ABA_test[0], args.img_norm))
                        imageio.imwrite(fn + '.BAB.jpg', tensor2image(recovered_BAB_test[0], args.img_norm))

                    if j < n_test:
                        fn_A = os.path.basename(batch_['img_A'][0])
                        imageio.imwrite(os.path.join(test_pth_AB, fn_A), tensor2image(fake_AB_test[0], args.img_norm))

                        fn_B = os.path.basename(batch_['img_B'][0])
                        imageio.imwrite(os.path.join(test_pth_BA, fn_B), tensor2image(fake_BA_test[0], args.img_norm))

            if iter % args.model_save_interval == 0:
                # Save models checkpoints
                torch.save(netG_A2B.state_dict(), os.path.join(model_path, 'G_A2B_{}.pth'.format(iter)))
                torch.save(netG_B2A.state_dict(), os.path.join(model_path, 'G_B2A_{}.pth'.format(iter)))
                torch.save(netD_A.state_dict(), os.path.join(model_path, 'D_A_{}.pth'.format(iter)))
                torch.save(netD_B.state_dict(), os.path.join(model_path, 'D_B_{}.pth'.format(iter)))

            iter += 1

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()
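`tensor2image` converts a normalised tensor back into something `imageio.imwrite` accepts; it is not shown on this page. A sketch for the `znorm` case used above (the signature matches the calls, the body is an assumption):

import numpy as np

def tensor2image(tensor, img_norm='znorm'):
    """Map a CHW tensor back to an HWC uint8 array for saving."""
    img = tensor.detach().cpu().float().numpy()
    if img_norm == 'znorm':      # inputs were normalised to [-1, 1]
        img = (img * 0.5 + 0.5) * 255.0
    else:                        # 'scale01' variants live in [0, 1]
        img = img * 255.0
    return img.clip(0, 255).astype(np.uint8).transpose(1, 2, 0)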
Esempio n. 18
0
    optimizerG.step()

    ############################
    # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
    ###########################
    optimizerD.zero_grad()

    # Real loss
    d_out, labels = netD(torch.cat([real_view_0, real_view_1]))
    loss_real = criterion_GAN(d_out, torch.cat([label_valid, label_valid]))
    # real label loss
    label_loss = 0
    # for l in lab TODO2
    # Fake loss (on batch of previously generated samples)
    fake_img = torch.cat([decoded_fake_view_0, decoded_fake_view_1])
    fake_img = fake_buffer.push_and_pop(fake_img)
    d_out_fake = netD(fake_img.detach())[0]
    loss_fake = criterion_GAN(d_out_fake, torch.cat([label_fake, label_fake]))

    # Total loss
    loss_D = (loss_real + loss_fake) / 2

    loss_D.backward()
    optimizerD.step()

    log.info(
        '[%d/%d] Loss_D: %.4f, Loss_id: %.4f, D(fake): %.1f, cyclegan: %.4f' %
        (step, opt.niter, loss_D.item(), id_loss.item(),
         d_out_fake.mean().item(), gan_loss.item()))

Esempio n. 19
0
        # Total loss
        loss_G = loss_identity_S + loss_identity_T + loss_GAN_S2T + loss_GAN_T2S + loss_cycle_STS + loss_cycle_TST
        loss_G.backward()

        optimizer_G.step()
        ###################################

        ###### Discriminator S ######
        optimizer_D_S.zero_grad()

        # Real loss
        pred_real = netD_S(real_S)
        loss_D_real = criterion_GAN(pred_real, target_real.view(1, 1))

        # Fake loss
        fake_S = fake_S_buffer.push_and_pop(fake_S)
        pred_fake = netD_S(fake_S.detach())
        loss_D_fake = criterion_GAN(pred_fake, target_fake.view(1, 1))

        # Total loss
        loss_D_S = (loss_D_real + loss_D_fake) * 0.5
        loss_D_S.backward()

        optimizer_D_S.step()
        ###################################

        ###### Discriminator T ######
        optimizer_D_T.zero_grad()

        # Real loss
        pred_real = netD_T(real_T)