Ejemplo n.º 1
0
# Optimizers
# NOTE(review): `generator`, `discriminator`, `opt`, `cfg_net` and `cuda`
# are defined earlier in this script, outside this excerpt — confirm.
optimizer_G = torch.optim.Adam(generator.parameters(),
                               lr=opt.lr,
                               betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),
                               lr=opt.lr,
                               betas=(opt.b1, opt.b2))

# Resize to the configured size, then map each RGB channel from [0, 1]
# to [-1, 1] (Normalize with mean = std = 0.5).
transform = transforms.Compose([
    transforms.Resize((opt.img_height, opt.img_width), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training loader: shuffled, opt.n_cpu worker processes.
dataloader = DataLoader(
    ImageDataset(cfg_net['data_path'], transform=transform),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)

# Validation loader: fixed batch of 10 shuffled images, single worker.
val_dataloader = DataLoader(
    ImageDataset(cfg_net['valid_path'], transform=transform),
    batch_size=10,
    shuffle=True,
    num_workers=1,
)

# Tensor type: CUDA tensors when a GPU is in use, CPU tensors otherwise.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
Ejemplo n.º 2
0
    T.Resize([384, 128]),
    T.RandomHorizontalFlip(p=0.5),
    T.Pad(10),
    T.RandomCrop([384, 128]),
    T.ToTensor(), normalize_transform,
    RandomErasing(probability=0.5, mean=[0.485, 0.456, 0.406])
])

# val_transforms = T.Compose([
#     T.Resize([384, 128]),
#     T.ToTensor(),
#     normalize_transform
# ])

# Market-1501 re-identification training data.  Batches are drawn by
# RandomIdentitySampler(dataset.train, opt.batchsize, 4) — presumably 4
# instances per identity per batch; verify against the sampler's signature.
# drop_last keeps every batch the same size.
dataset = init_dataset('market1501', root='../')
train_set = ImageDataset(dataset.train, train_transforms)
dataloaders['train'] = DataLoader(train_set,
                                  batch_size=opt.batchsize,
                                  drop_last=True,
                                  sampler=RandomIdentitySampler(
                                      dataset.train, opt.batchsize, 4),
                                  num_workers=8)

# val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)
# dataloaders['val'] = DataLoader(
#     val_set, batch_size=opt.batchsize, drop_last=True, shuffle=False, num_workers=8)

######################################################################
# Training the model
# --------
#
# Adversarial targets: all-ones for "real", all-zeros for "fake", one entry
# per batch element; no gradients flow into them.
target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False)

# Pools of previously generated fake images — consumed by the discriminator
# update code outside this excerpt (TODO confirm usage there).
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Dataset loader
# Upscale by ~12%, random-crop back to opt.size, random flip, then map each
# RGB channel from [0, 1] to [-1, 1].
transforms_ = [
    transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
    transforms.RandomCrop(opt.size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]
dataloader = DataLoader(ImageDataset(opt.dataroot,
                                     transforms_=transforms_,
                                     unaligned=True),
                        batch_size=opt.batchSize,
                        shuffle=True,
                        num_workers=opt.n_cpu)

# Loss plot
logger = Logger(opt.n_epochs, len(dataloader))
###################################

###### Training ######
# NOTE(review): `input_A`/`input_B` are pre-allocated tensors defined
# outside this excerpt; copy_ reuses their storage for every batch.
for epoch in range(opt.epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):
        # Set model input
        real_A = Variable(input_A.copy_(batch['A']))
        real_B = Variable(input_B.copy_(batch['B']))
def main():
    """Parse CLI options, build the data pipeline, network and optimizer,
    then hand everything to the project-level ``train`` routine."""
    parser = argparse.ArgumentParser(description='model training')
    parser.add_argument('--save_dir', type=str, default='logs/tmp',
                        help='save model directory')
    # dataset
    parser.add_argument('--dataset_dir', type=str, default='datasets',
                        help='datasets path')
    parser.add_argument('--valid_pect', type=float, default=0.2,
                        help='validation percent split from train')
    parser.add_argument('--train_bs', type=int, default=64,
                        help='train images per batch')
    parser.add_argument('--test_bs', type=int, default=128,
                        help='test images per batch')
    # training
    parser.add_argument('--no_gpu', action='store_true',
                        help='whether use gpu')
    parser.add_argument('--gpus', type=str, default='0',
                        help='gpus to use in training')
    parser.add_argument('--log_interval', type=int, default=20,
                        help='intermediate printing')
    parser.add_argument('--save_step', type=int, default=20,
                        help='save model every save_step')
    args = parser.parse_args()

    # Persist the run configuration next to the training log.
    mkdir_if_missing(args.save_dir)
    log_path = os.path.join(args.save_dir, 'log.txt')
    with open(log_path, 'w') as f:
        f.write('{}'.format(args))

    device = "cuda:{}".format(args.gpus) if not args.no_gpu else "cpu"
    if not args.no_gpu:
        cudnn.benchmark = True

    # Transform pipelines: train gets augmentation, test is deterministic;
    # both share the same tensor conversion and channel normalization.
    to_tensor = T.ToTensor()
    normalize = T.Normalize(mean=[0.491, 0.482, 0.446],
                            std=[0.202, 0.199, 0.201])
    train_tfms = T.Compose([
        T.RandomResizedCrop((224, 224)),
        T.RandomHorizontalFlip(),
        to_tensor,
        normalize,
    ])
    test_tfms = T.Compose([
        T.Resize((224, 224)),
        to_tensor,
        normalize,
    ])

    # Split the dataset, then wrap each part in a DataLoader.
    train_list, valid_list, label2name = split_dataset(args.dataset_dir,
                                                       args.valid_pect)
    train_loader = DataLoader(ImageDataset(train_list, train_tfms),
                              batch_size=args.train_bs,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True)
    valid_loader = DataLoader(ImageDataset(valid_list, test_tfms),
                              batch_size=args.test_bs,
                              shuffle=False,
                              num_workers=8,
                              pin_memory=True)

    # ResNet-50 classifier sized to the discovered label set.
    net = get_resnet50(len(label2name), pretrain=True)

    # Cross-entropy loss + Adam over all network parameters.
    ce_loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=5e-4)

    train(
        args=args,
        network=net,
        train_data=train_loader,
        valid_data=valid_loader,
        optimizer=optimizer,
        criterion=ce_loss,
        device=device,
        log_path=log_path,
        label2name=label2name,
    )
Ejemplo n.º 5
0
def caculate_fitness_for_first_time(mask_input, gpu_id, fitness_id,
                                    A2B_or_B2A):
    """Score one channel-pruning mask candidate (initial population).

    Loads the full pretrained CycleGAN generator and discriminator for the
    requested direction from /cache/models, zeroes out the channels selected
    by ``mask_input`` in a copy of the generator, then (on the 'val' split,
    no gradients) measures how closely the discriminator's response to the
    pruned generator matches its response to the full one.  The fitness,
    ``500 / resemblance_loss + 0.001 * (#pruned channels)``, is written into
    the shared ``current_fitness_A2B``/``current_fitness_B2A`` array at
    index ``fitness_id``.

    NOTE(review): relies on module-level globals (`opt`, `mask_chns`,
    `compute_layer_mask`, the fitness arrays) — confirm they are available
    in whatever process calls this.
    """

    ###### Definition of variables ######
    torch.cuda.set_device(gpu_id)
    #print("GPU_ID is%d\n"%(gpu_id))
    if A2B_or_B2A == 'A2B':
        # Reference generator/discriminator, plus `model`: the copy of the
        # generator that gets pruned in place below.
        netG_A2B = Generator(opt.input_nc, opt.output_nc)
        netD_B = Discriminator(opt.output_nc)
        netG_A2B.cuda(gpu_id)
        netD_B.cuda(gpu_id)
        model = Generator(opt.input_nc, opt.output_nc)
        model.cuda(gpu_id)
        netG_A2B.load_state_dict(torch.load('/cache/models/netG_A2B.pth'))
        netD_B.load_state_dict(torch.load('/cache/models/netD_B.pth'))
        model.load_state_dict(torch.load('/cache/models/netG_A2B.pth'))
        model.eval()
        netD_B.eval()
        netG_A2B.eval()

    elif A2B_or_B2A == 'B2A':
        netG_B2A = Generator(opt.output_nc, opt.input_nc)
        netD_A = Discriminator(opt.input_nc)
        netG_B2A.cuda(gpu_id)
        netD_A.cuda(gpu_id)
        # NOTE(review): constructed with (input_nc, output_nc) while
        # netG_B2A above uses (output_nc, input_nc); loading netG_B2A's
        # weights into it can only work if input_nc == output_nc — confirm.
        model = Generator(opt.input_nc, opt.output_nc)
        model.cuda(gpu_id)
        netG_B2A.load_state_dict(torch.load('/cache/models/netG_B2A.pth'))
        netD_A.load_state_dict(torch.load('/cache/models/netD_A.pth'))
        model.load_state_dict(torch.load('/cache/models/netG_B2A.pth'))
        model.eval()
        netD_A.eval()
        netG_B2A.eval()

    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()  # unused in this function
    criterion_identity = torch.nn.L1Loss()  # unused in this function
    fitness = 0
    # Expand the flat candidate vector into per-layer channel masks
    # (1 = keep channel, 0 = prune) and a flat copy used for counting.
    cfg_mask = compute_layer_mask(mask_input, mask_chns)
    cfg_full_mask = [y for x in cfg_mask for y in x]
    cfg_full_mask = np.array(cfg_full_mask)
    cfg_id = 0
    # The first conv consumes the 3 image input channels (never pruned).
    start_mask = np.ones(3)
    end_mask = cfg_mask[cfg_id]

    # Walk the generator, zeroing weights/biases of pruned channels.
    # start_mask covers a layer's input channels, end_mask its outputs.
    for m in model.modules():
        if isinstance(m, nn.Conv2d):

            mask = np.ones(m.weight.data.shape)

            mask_bias = np.ones(m.bias.data.shape)

            # Invert the keep-masks: a 1 below marks a channel to prune.
            cfg_mask_start = np.ones(start_mask.shape) - start_mask
            cfg_mask_end = np.ones(end_mask.shape) - end_mask
            idx0 = np.squeeze(np.argwhere(np.asarray(cfg_mask_start)))
            idx1 = np.squeeze(np.argwhere(np.asarray(cfg_mask_end)))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1, ))

            # Conv2d weight layout is (out, in, kH, kW).
            mask[:, idx0.tolist(), :, :] = 0
            mask[idx1.tolist(), :, :, :] = 0
            mask_bias[idx1.tolist()] = 0

            m.weight.data = m.weight.data * torch.FloatTensor(mask).cuda(
                gpu_id)

            m.bias.data = m.bias.data * torch.FloatTensor(mask_bias).cuda(
                gpu_id)

            idx_mask = np.argwhere(np.asarray(np.ones(mask.shape) - mask))  # unused

            # NOTE(review): setting requires_grad on an indexed view of
            # `.data` operates on a temporary — these three lines look like
            # no-ops; confirm intended effect.
            m.weight.data[:, idx0.tolist(), :, :].requires_grad = False
            m.weight.data[idx1.tolist(), :, :, :].requires_grad = False
            m.bias.data[idx1.tolist()].requires_grad = False

            cfg_id += 1
            start_mask = end_mask
            if cfg_id < len(cfg_mask):
                end_mask = cfg_mask[cfg_id]
            continue
        elif isinstance(m, nn.ConvTranspose2d):

            mask = np.ones(m.weight.data.shape)
            mask_bias = np.ones(m.bias.data.shape)

            cfg_mask_start = np.ones(start_mask.shape) - start_mask
            cfg_mask_end = np.ones(end_mask.shape) - end_mask

            idx0 = np.squeeze(np.argwhere(np.asarray(cfg_mask_start)))
            idx1 = np.squeeze(np.argwhere(np.asarray(cfg_mask_end)))

            # ConvTranspose2d weight layout is (in, out, kH, kW) — the two
            # channel axes are swapped relative to Conv2d, hence the
            # mirrored indexing below.
            mask[idx0.tolist(), :, :, :] = 0

            mask[:, idx1.tolist(), :, :] = 0

            mask_bias[idx1.tolist()] = 0

            m.weight.data = m.weight.data * torch.FloatTensor(mask).cuda(
                gpu_id)
            m.bias.data = m.bias.data * torch.FloatTensor(mask_bias).cuda(
                gpu_id)

            m.weight.data[idx0.tolist(), :, :, :].requires_grad = False
            m.weight.data[:, idx1.tolist(), :, :].requires_grad = False
            m.bias.data[idx1.tolist()].requires_grad = False

            cfg_id += 1
            start_mask = end_mask
            # NOTE(review): unlike the Conv2d branch, this advance is not
            # bounds-checked — IndexError if a ConvTranspose2d is the last
            # masked layer; confirm the architecture guarantees otherwise.
            end_mask = cfg_mask[cfg_id]
            continue

    # Dataset loader
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    # Pre-allocated input buffers, refilled in place for every batch.
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
    # Adversarial targets — not referenced in the loop below.
    target_real = Variable(Tensor(opt.batchSize).fill_(1.0),
                           requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize).fill_(0.0),
                           requires_grad=False)
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Loss weights — not used in the resemblance computation below.
    lamda_loss_ID = 5.0
    lamda_loss_G = 1.0
    lamda_loss_cycle = 10.0

    with torch.no_grad():

        transforms_ = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]

        dataloader = DataLoader(ImageDataset(opt.dataroot,
                                             transforms_=transforms_,
                                             mode='val'),
                                batch_size=opt.batchSize,
                                shuffle=False,
                                drop_last=True)

        Loss_resemble_G = 0
        if A2B_or_B2A == 'A2B':
            for i, batch in enumerate(dataloader):
                # Set model input
                real_A = Variable(input_A.copy_(batch['A']))

                # GAN loss
                fake_B = model(real_A)
                fake_B_full_model = netG_A2B(real_A)

                # Fake loss
                pred_fake = netD_B(fake_B.detach())

                pred_fake_full = netD_B(fake_B_full_model.detach())

                # MSE between D's responses to pruned vs. full generator.
                loss_D_fake = criterion_GAN(pred_fake.detach(),
                                            pred_fake_full.detach())
                Loss_resemble_G = Loss_resemble_G + loss_D_fake

                lambda_prune = 0.001

            # Higher fitness = closer to full model AND more channels pruned.
            fitness = 500 / Loss_resemble_G.detach() + sum(
                np.ones(cfg_full_mask.shape) - cfg_full_mask) * lambda_prune
            print("A2B first generation")
            print("GPU_ID is %d" % (gpu_id))
            print("channel num is: %d" % (sum(cfg_full_mask)))
            print("Loss_resemble_G is %f prune_loss is %f " %
                  (500 / Loss_resemble_G,
                   sum(np.ones(cfg_full_mask.shape) - cfg_full_mask)))
            print("fitness is %f \n" % (fitness))

            # Publish the result to the shared fitness array.
            current_fitness_A2B[fitness_id] = fitness.item()

        elif A2B_or_B2A == 'B2A':
            for i, batch in enumerate(dataloader):

                real_B = Variable(input_B.copy_(batch['B']))

                fake_A = model(real_B)
                fake_A_full_model = netG_B2A(real_B)

                pred_fake = netD_A(fake_A.detach())

                pred_fake_full = netD_A(fake_A_full_model.detach())

                loss_D_fake = criterion_GAN(pred_fake.detach(),
                                            pred_fake_full.detach())
                Loss_resemble_G = Loss_resemble_G + loss_D_fake

                lambda_prune = 0.001

            fitness = 500 / Loss_resemble_G.detach() + sum(
                np.ones(cfg_full_mask.shape) - cfg_full_mask) * lambda_prune
            print("B2A first generation")
            print("GPU_ID is %d" % (gpu_id))
            print("channel num is: %d" % (sum(cfg_full_mask)))
            print("Loss_resemble_G is %f prune_loss is %f " %
                  (500 / Loss_resemble_G,
                   sum(np.ones(cfg_full_mask.shape) - cfg_full_mask)))
            print("fitness is %f \n" % (fitness))

            current_fitness_B2A[fitness_id] = fitness.item()
Ejemplo n.º 6
0
# noinspection PyArgumentList
# All-zero "fake" labels, one per batch element, moved to the target device.
target_fake = torch.Tensor(opt.batchSize).fill_(0.0).to(device)

# Pools of previously generated images — consumed by the discriminator
# update code outside this excerpt.
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Dataset loader
# ~12% upscale then random crop back to opt.size + flip; Normalize maps
# each RGB channel from [0, 1] to [-1, 1].
transforms_ = [
    transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
    transforms.RandomCrop(opt.size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]

data_set = ImageDataset(opt.dataroot, transforms_=transforms_, unaligned=True)
data_loader = DataLoader(data_set,
                         batch_size=opt.batchSize,
                         shuffle=True,
                         num_workers=opt.n_cpu)

# Running averages for each loss term, keyed by meter name.
loss_meters: Dict[str, AverageValueMeter] = {
    'loss_G_meter': AverageValueMeter(),
    'loss_G_identity_meter': AverageValueMeter(),
    'loss_G_GAN_meter': AverageValueMeter(),
    'loss_G_cycle_meter': AverageValueMeter(),
    'loss_D_meter': AverageValueMeter()
}
# Loss plot
# logger = Logger(opt.n_epochs, len(data_loader))
loss_logger = VisdomPlotLogger('line', opts={'title': 'Loss'})
Ejemplo n.º 7
0
def main():
    """Build a searched NAS-GAN generator, distill it from a fixed pretrained
    RRDBNet teacher, and periodically evaluate PSNR on the validation split.

    All settings come from the module-level ``config``; checkpoints and logs
    are written under ``ckpt/<config.save>``.
    """
    config.save = 'ckpt/{}'.format(config.save)
    create_exp_dir(config.save, scripts_to_save=glob.glob('*.py')+glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    # Log both to stdout and to <save>/log.txt with the same format.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    logging.info("args = %s", str(config))
    # preparation: cudnn autotuning + seeding #####
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Model: rebuild the searched architecture from saved arch parameters.
    state = torch.load(os.path.join(config.load_path, 'arch.pt'))
    model = NAS_GAN_Infer(state['alpha'], state['beta'], state['ratio'], num_cell=config.num_cell, op_per_cell=config.op_per_cell, width_mult_list=config.width_mult_list,
                          loss_weight=config.loss_weight, loss_func=config.loss_func, before_act=config.before_act, quantize=config.quantize)

    # thop's `profile` supplies the param count; the model's own FLOPs
    # accounting overrides the thop FLOPs estimate.
    flops, params = profile(model, inputs=(torch.randn(1, 3, 510, 350),), custom_ops=custom_ops)
    flops = model.forward_flops(size=(3, 510, 350))
    logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)

    model = torch.nn.DataParallel(model).cuda()

    # Optional warm start from a pretrained checkpoint path.
    if isinstance(config.pretrain, str):
        state_dict = torch.load(config.pretrain)
        model.load_state_dict(state_dict)

    # Teacher: pretrained RRDBNet, frozen in eval mode.
    teacher_model = RRDBNet(3, 3, 64, 23, gc=32)
    teacher_model.load_state_dict(torch.load(config.generator_A2B), strict=True)
    teacher_model = torch.nn.DataParallel(teacher_model).cuda()
    teacher_model.eval()

    for param in teacher_model.parameters():
        # BUG FIX: was `param.require_grads = False`, which only created a
        # dead attribute; `requires_grad = False` actually freezes the
        # teacher so no gradients are computed for it.
        param.requires_grad = False

    # Optimizer over all trainable sub-modules of the student #####
    base_lr = config.lr
    parameters = []
    parameters += list(model.module.cells.parameters())
    parameters += list(model.module.conv_first.parameters())
    parameters += list(model.module.trunk_conv.parameters())
    parameters += list(model.module.upconv1.parameters())
    parameters += list(model.module.upconv2.parameters())
    parameters += list(model.module.HRconv.parameters())
    parameters += list(model.module.conv_last.parameters())

    if config.opt == 'Adam':
        optimizer = torch.optim.Adam(
            parameters,
            lr=base_lr,
            betas=config.betas)
    elif config.opt == 'Sgd':
        optimizer = torch.optim.SGD(
            parameters,
            lr=base_lr,
            momentum=config.momentum,
            weight_decay=config.weight_decay)
    else:
        logging.info("Wrong Optimizer Type.")
        sys.exit()

    # lr policy ##############################
    total_iteration = config.nepochs * config.niters_per_epoch

    if config.lr_schedule == 'linear':
        lr_policy = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=LambdaLR(config.nepochs, 0, config.decay_epoch).step)
    elif config.lr_schedule == 'exponential':
        lr_policy = torch.optim.lr_scheduler.ExponentialLR(optimizer, config.lr_decay)
    elif config.lr_schedule == 'multistep':
        lr_policy = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=config.milestones, gamma=config.gamma)
    else:
        logging.info("Wrong Learning Rate Schedule Type.")
        sys.exit()

    # data loaders: random crops/flips for training, raw tensors for val ####
    transforms_ = [ transforms.RandomCrop(config.image_height),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor()]
    train_loader_model = DataLoader(ImageDataset(config.dataset_path, transforms_=transforms_, unaligned=True),
                        batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)

    transforms_ = [ transforms.ToTensor()]
    test_loader = DataLoader(ImageDataset(config.dataset_path, transforms_=transforms_, mode='val'),
                        batch_size=1, shuffle=False, num_workers=config.num_workers)

    # Evaluation-only mode short-circuits training entirely.
    if config.eval_only:
        logging.info('Eval: psnr = %f', infer(0, model, test_loader, logger))
        sys.exit(0)

    tbar = tqdm(range(config.nepochs), ncols=80)
    for epoch in tbar:
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))

        # training
        tbar.set_description("[Epoch %d/%d][train...]" % (epoch + 1, config.nepochs))
        train(train_loader_model, model, teacher_model, optimizer, lr_policy, logger, epoch)
        torch.cuda.empty_cache()
        lr_policy.step()

        # validation every config.eval_epoch epochs (skips epoch 0)
        if epoch and not (epoch+1) % config.eval_epoch:
            tbar.set_description("[Epoch %d/%d][validation...]" % (epoch + 1, config.nepochs))

            with torch.no_grad():
                model.prun_mode = None

                valid_psnr = infer(epoch, model, test_loader, logger)

                logger.add_scalar('psnr/val', valid_psnr, epoch)
                logging.info("Epoch %d: valid_psnr %.3f"%(epoch, valid_psnr))

                logger.add_scalar('flops/val', flops, epoch)
                logging.info("Epoch %d: flops %.3f"%(epoch, flops))

            # Periodic checkpoint tagged with the epoch number.
            save(model, os.path.join(config.save, 'weights_%d.pt'%epoch))

    save(model, os.path.join(config.save, 'weights.pt'))
Ejemplo n.º 8
0
# Adversarial targets: ones = real, zeros = fake; no gradients.
target_real = Variable(Tensor(opt.batchSize).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(opt.batchSize).fill_(0.0), requires_grad=False)

# Pools of previously generated fakes, consumed by the discriminator
# updates in the training loop outside this excerpt.
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Dataset loader
# ~12% upscale, random crop back to opt.size, flip; Normalize maps each
# RGB channel from [0, 1] to [-1, 1].  drop_last keeps batches full-size.
transforms_ = [
    transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
    transforms.RandomCrop(opt.size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]
dataloader = DataLoader(ImageDataset(dataset_dir,
                                     transforms_=transforms_,
                                     unaligned=True),
                        batch_size=opt.batchSize,
                        shuffle=True,
                        num_workers=opt.n_cpu,
                        drop_last=True)

# Loss plot
# logger = Logger(opt.n_epochs, len(dataloader))
###################################

###### Training ######
N = len(dataloader)
print('N:', N)  # 1334
# Per-epoch loss histories, appended to by the training loop below.
loss_G_lst, loss_D_lst, loss_G_GAN_lst, loss_G_cycle_lst, loss_G_identity_lst = [], [], [], [], []
for epoch in range(opt.epoch, opt.n_epochs):
Ejemplo n.º 9
0
        'query': 'query/*.jpg',
    }

    # DataFrames describing each Market-1501 split, built by a helper
    # defined elsewhere in this file.
    df_train = create_market_df('train')
    dfs_test = {x: create_market_df(x) for x in ['test', 'query']}

    # Deterministic eval transform: resize, center-crop to 224, then
    # per-channel normalization (the usual ImageNet statistics).
    data_transform_test = transforms.Compose([
        transforms.Resize([256, 256]),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Gallery ('test') and 'query' datasets in eval mode.
    datasets_test = {
        x: ImageDataset(dfs_test[x]['path'],
                        transform=data_transform_test,
                        is_train=False)
        for x in ['test', 'query']
    }

    # No shuffling: evaluation needs a stable sample order.
    dataloaders_test = {
        x: DataLoader(datasets_test[x],
                      batch_size=opt['batch_size_test'],
                      shuffle=False,
                      num_workers=opt['nworkers'])
        for x in datasets_test.keys()
    }

    evaluation = Evaluation(dfs_test['test'], dfs_test['query'],
                            dataloaders_test['test'],
                            dataloaders_test['query'], opt['cuda'])
Ejemplo n.º 10
0
                    default=9,
                    help="number of residual blocks")
parser.add_argument("--lambda_cycle",
                    type=float,
                    default=10.0,
                    help="cycle loss weight")
args = parser.parse_args()
print(args)

#############################
# Dataloaders; fake buffers #
#############################
# TODO: make path_to_data exactly path_to_data.
# Now script requires to be launched only from location of *dataset_name* dir.
train_loader = DataLoader(ImageDataset(path_to_data='datasets/%s' %
                                       args.dataset_name,
                                       size=(args.img_height, args.img_width),
                                       mode='train'),
                          batch_size=args.batch_size,
                          num_workers=2,
                          shuffle=True)

test_loader = DataLoader(ImageDataset(path_to_data='datasets/%s' %
                                      args.dataset_name,
                                      size=(args.img_height, args.img_width),
                                      mode='test'),
                         batch_size=5,
                         num_workers=2,
                         shuffle=True)

fake_A_buffer = FakeImageBuffer()
fake_B_buffer = FakeImageBuffer()
input_X = Tensor(batch_size, input_nc, image_size, image_size)
input_Y = Tensor(batch_size, output_nc, image_size, image_size)

# labels
real_labels = Variable(Tensor(batch_size).fill_(1.0), requires_grad=False)
fake_labels = Variable(Tensor(batch_size).fill_(0.0), requires_grad=False)

transforms = transforms.Compose([
    transforms.Resize(int(image_size * 1.2), Image.BICUBIC),
    transforms.RandomCrop(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

dataset = ImageDataset(dataroot=dataroot, transforms=transforms, aligned=True)
dataloader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=num_workers)

p = utils.Plotter(['Loss_G', 'Loss_Dx', 'Loss_Dy'])

print('Start training.')
for epoch in range(epochs):
    start_time = time.monotonic()
    for idx, batch in enumerate(dataloader):
        real_X = input_X.copy_(batch['X_trans'])
        real_Y = input_Y.copy_(batch['Y_trans'])

        raw_X = batch['X_raw']
Ejemplo n.º 12
0
# Create output directories for both translation directions.
if not os.path.exists(os.path.join(test_output_path, 'A')):
    os.makedirs(os.path.join(test_output_path, 'A'))
if not os.path.exists(os.path.join(test_output_path, 'B')):
    os.makedirs(os.path.join(test_output_path, 'B'))
# NOTE(review): unconditionally CUDA — this script fails without a GPU.
Tensor = torch.cuda.FloatTensor
# Pre-allocated input buffers reused for every test batch.
input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)

# Dataset loader
# No augmentation at test time; Normalize maps each RGB channel to [-1, 1].
transforms_ = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]
test_dataloader = DataLoader(ImageDataset(dataset_dir,
                                          transforms_=transforms_,
                                          mode='test'),
                             batch_size=opt.batchSize,
                             shuffle=False,
                             num_workers=opt.n_cpu)


def pruning_generate(model, state_dict):
    """Re-apply saved pruning masks to every (transposed) conv layer.

    For each ``nn.Conv2d`` / ``nn.ConvTranspose2d`` module in *model*, looks
    up the mask stored under ``"<module name>.weight_mask"`` in *state_dict*
    and registers it via ``prune.custom_from_mask``.  The module is modified
    in place: its effective ``weight`` becomes ``weight_orig * mask``.

    Args:
        model: network whose conv layers are pruned in place.
        state_dict: mapping containing a ``*.weight_mask`` tensor for every
            conv/deconv module name in *model* (KeyError if one is missing).
    """
    # Removed the unused `parameters_to_prune` accumulator and the pointless
    # rebinding of the loop variable — custom_from_mask mutates the module.
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            prune.custom_from_mask(module,
                                   name='weight',
                                   mask=state_dict[name + ".weight_mask"])
Ejemplo n.º 13
0
def main():
    """Train an autoencoder image codec on CLIC and validate on Kodak.

    Parses CLI options, builds train/val loaders, optionally restores a
    weights checkpoint, then runs the SSIM/MS-SSIM training loop, saving
    the latest model every epoch and the best-scoring model seen so far.
    """
    opts = get_argparser().parse_args()

    # dataset: random crops + flips for training, plain tensors for eval
    train_transform = transforms.Compose([
        transforms.RandomCrop(size=512, pad_if_needed=True),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
    ])

    val_transform = transforms.Compose([transforms.ToTensor()])

    # Train on CLIC train+valid together; validate on Kodak.
    train_loader = data.DataLoader(data.ConcatDataset([
        ImageDataset(root='datasets/data/CLIC/train',
                     transform=train_transform),
        ImageDataset(root='datasets/data/CLIC/valid',
                     transform=train_transform),
    ]),
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=2,
                                   drop_last=True)

    val_loader = data.DataLoader(ImageDataset(root='datasets/data/kodak',
                                              transform=val_transform),
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=1)

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("Train set: %d, Val set: %d" %
          (len(train_loader.dataset), len(val_loader.dataset)))
    model = AutoEncoder(C=128, M=128, in_chan=3, out_chan=3).to(device)

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-4,
                                 weight_decay=1e-5)

    # checkpoint (weights only — optimizer state is not restored)
    best_score = 0.0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        model.load_state_dict(torch.load(opts.ckpt))
    else:
        print("[!] Retrain")

    if opts.loss_type == 'ssim':
        criterion = SSIM_Loss(data_range=1.0, size_average=True, channel=3)
    else:
        criterion = MS_SSIM_Loss(data_range=1.0,
                                 size_average=True,
                                 channel=3,
                                 nonnegative_ssim=True)

    #==========   Train Loop   ==========#
    for cur_epoch in range(opts.total_epochs):
        # =====  Train  =====
        model.train()
        for cur_step, images in enumerate(train_loader):
            images = images.to(device, dtype=torch.float32)
            optimizer.zero_grad()
            outputs = model(images)

            loss = criterion(outputs, images)
            loss.backward()

            optimizer.step()

            if (cur_step) % opts.log_interval == 0:
                print("Epoch %d, Batch %d/%d, loss=%.6f" %
                      (cur_epoch, cur_step, len(train_loader), loss.item()))

        # =====  Save Latest Model  =====
        torch.save(model.state_dict(), 'latest_model.pt')

        # =====  Validation  =====
        print("Val on Kodak dataset...")
        cur_score = test(opts, model, val_loader, criterion, device)
        print("%s = %.6f" % (opts.loss_type, cur_score))
        # =====  Save Best Model  =====
        # BUG FIX: best_score used to be reset to 0.0 at every epoch, so
        # almost every epoch overwrote best_model.pt; track the true
        # best score across all epochs instead.
        if cur_score > best_score:  # save best model
            best_score = cur_score
            torch.save(model.state_dict(), 'best_model.pt')
            print("Best model saved as best_model.pt")
def main():
    """Parse command-line options, build the data/model/optimizer stack and train.

    Side effects: creates ``--save_dir`` and writes the parsed arguments to
    ``<save_dir>/log.txt`` before training starts.
    """
    parser = argparse.ArgumentParser(description='model training')
    parser.add_argument('--save_dir',
                        type=str,
                        default='logs/tmp',
                        help='save model directory')
    # transforms
    parser.add_argument('--train_size',
                        type=int,
                        default=[224],
                        nargs='+',
                        help='train image size')
    parser.add_argument('--test_size',
                        type=int,
                        default=[224, 224],
                        nargs='+',
                        help='test image size')
    parser.add_argument('--h_filp',
                        action='store_true',
                        help='do horizontal flip')
    # dataset
    parser.add_argument('--dataset_dir',
                        type=str,
                        default='datasets',
                        help='datasets path')
    parser.add_argument('--valid_pect',
                        type=float,
                        default=0.2,
                        help='validation percent split from train')
    parser.add_argument('--train_bs',
                        type=int,
                        default=64,
                        help='train images per batch')
    parser.add_argument('--test_bs',
                        type=int,
                        default=128,
                        help='test images per batch')
    # training
    parser.add_argument('--no_gpu',
                        action='store_true',
                        help='whether use gpu')
    parser.add_argument('--gpus',
                        type=str,
                        default='0',
                        help='gpus to use in training')
    parser.add_argument('--opt_func',
                        type=str,
                        default='Adam',
                        help='optimizer function')
    parser.add_argument('--lr',
                        type=float,
                        default=0.1,
                        help='base learning rate')
    parser.add_argument('--steps',
                        type=int,
                        default=(60, 90),
                        nargs='+',
                        help='learning rate decay strategy')
    parser.add_argument('--factor',
                        type=float,
                        default=0.1,
                        help='learning rate decay factor')
    parser.add_argument('--wd', type=float, default=5e-4, help='weight decay')
    # BUGFIX: without type=float, a momentum passed on the command line stayed
    # a string and would make the optimizer constructor fail.
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        help='training momentum')
    parser.add_argument('--max_epoch',
                        type=int,
                        default=120,
                        help='number of training epochs')
    parser.add_argument('--log_interval',
                        type=int,
                        default=50,
                        help='intermediate printing')
    parser.add_argument('--save_step',
                        type=int,
                        default=20,
                        help='save model every save_step')

    args = parser.parse_args()

    # Persist the full configuration next to the checkpoints for reproducibility.
    mkdir_if_missing(args.save_dir)
    log_path = os.path.join(args.save_dir, 'log.txt')
    with open(log_path, 'w') as f:
        f.write('{}'.format(args))

    # NOTE(review): assumes a single GPU id in --gpus; a value like "0,1"
    # would produce an invalid device string — confirm intended usage.
    device = "cuda:{}".format(args.gpus) if not args.no_gpu else "cpu"
    if not args.no_gpu:
        cudnn.benchmark = True

    # define train transforms and test transforms
    totensor = T.ToTensor()
    normalize = T.Normalize(mean=[0.491, 0.482, 0.446],
                            std=[0.202, 0.199, 0.201])
    # A single --train_size value means a square RandomResizedCrop.
    train_tfms = list()
    train_size = args.train_size[0] if len(
        args.train_size) == 1 else args.train_size
    train_tfms.append(T.RandomResizedCrop(train_size))
    if args.h_filp:
        train_tfms.append(T.RandomHorizontalFlip())
    train_tfms.append(totensor)
    train_tfms.append(normalize)
    train_tfms = T.Compose(train_tfms)

    # Test images are only resized (no augmentation) before normalization.
    test_tfms = list()
    test_size = (args.test_size[0], args.test_size[0]) if len(
        args.test_size) == 1 else args.test_size
    test_tfms.append(T.Resize(test_size))
    test_tfms.append(totensor)
    test_tfms.append(normalize)
    test_tfms = T.Compose(test_tfms)

    # get dataloader: split the raw dataset into train/valid by --valid_pect.
    train_list, valid_list, label2name = split_dataset(args.dataset_dir,
                                                       args.valid_pect)
    trainset = ImageDataset(train_list, train_tfms)
    validset = ImageDataset(valid_list, test_tfms)

    train_loader = DataLoader(trainset,
                              batch_size=args.train_bs,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True)
    valid_loader = DataLoader(validset,
                              batch_size=args.test_bs,
                              shuffle=False,
                              num_workers=8,
                              pin_memory=True)

    # define network: one output unit per class label.
    net = get_resnet50(len(label2name), pretrain=True)
    # layer_groups = [nn.Sequential(*flatten_model(net))]

    # define loss
    ce_loss = nn.CrossEntropyLoss()

    # define optimizer and lr scheduler
    # NOTE(review): neither branch passes lr= to the optimizer; the base rate
    # is presumably applied per-epoch by LRScheduler — confirm in train().
    if args.opt_func == 'Adam':
        optimizer = getattr(torch.optim, args.opt_func)(net.parameters(),
                                                        weight_decay=args.wd)
    else:
        optimizer = getattr(torch.optim, args.opt_func)(net.parameters(),
                                                        weight_decay=args.wd,
                                                        momentum=args.momentum)
    lr_scheduler = LRScheduler(base_lr=args.lr,
                               step=args.steps,
                               factor=args.factor)

    train(
        args=args,
        network=net,
        train_data=train_loader,
        valid_data=valid_loader,
        optimizer=optimizer,
        criterion=ce_loss,
        lr_scheduler=lr_scheduler,
        device=device,
        log_path=log_path,
        label2name=label2name,
    )
# Dataset loader
# Resize -> tensor -> map pixel values from [0, 1] to [-1, 1] (mean/std 0.5).
transforms_ = [
    transforms.Resize(image_size, Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]

# For training set
# The `ImageDataset` is in dataset.py. Check it out to see what it does.
# dataloader = DataLoader(ImageDataset(data_path, transforms_=transforms_, unaligned=True),
#                         batch_size=batch_size, shuffle=True, num_workers=n_cpu)

# For test set
# One image per batch; shuffle=True, so the sampled test images vary per run.
dataloader_test = DataLoader(ImageDataset(data_path,
                                          transforms_=transforms_,
                                          mode='test'),
                             batch_size=1,
                             shuffle=True,
                             num_workers=n_cpu)

prev_time = time.time()
# NOTE(review): `iter` shadows the builtin; it resumes the global step counter
# from `load_iter` (checkpoint iteration).
iter = load_iter

# Training
for epoch in range(start_epoch, n_epochs):
    print('current epoch:', epoch)
    dataloader = DataLoader(ImageDataset(data_path,
                                         transforms_=transforms_,
                                         unaligned=True),
                            batch_size=batch_size,
Ejemplo n.º 16
0
def _copy_pruned_generator_weights(src_gen, pruned_gen, cfg_mask):
    """Copy the surviving (mask-selected) channels of ``src_gen`` into ``pruned_gen``.

    ``cfg_mask`` holds one binary channel mask per prunable layer. For every
    Conv2d / ConvTranspose2d pair, the weights and biases of the channels whose
    mask entry is non-zero are copied over; the first input mask is all ones
    (the 3 image channels are never pruned).
    """
    layer_id_in_cfg = 0
    start_mask = torch.ones(3)
    end_mask = cfg_mask[layer_id_in_cfg]

    for m0, m1 in zip(src_gen.modules(), pruned_gen.modules()):

        if isinstance(m0, nn.Conv2d):
            # Indices of the surviving input / output channels of this layer.
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask)))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask)))
            print('In shape: {:d}, Out shape {:d}.'.format(
                idx0.size, idx1.size))

            # Conv2d weight layout: (out_ch, in_ch, kH, kW).
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()

            m1.bias.data = m0.bias.data[idx1.tolist()].clone()

            layer_id_in_cfg += 1
            start_mask = end_mask
            if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                end_mask = cfg_mask[layer_id_in_cfg]
                print(layer_id_in_cfg)
        elif isinstance(m0, nn.ConvTranspose2d):
            print('Into ConvTranspose...')
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask)))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask)))
            print('In shape: {:d}, Out shape {:d}.'.format(
                idx0.size, idx1.size))

            # ConvTranspose2d weight layout: (in_ch, out_ch, kH, kW).
            w1 = m0.weight.data[idx0.tolist(), :, :, :].clone()
            w1 = w1[:, idx1.tolist(), :, :].clone()
            m1.weight.data = w1.clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask
            if layer_id_in_cfg < len(cfg_mask):
                end_mask = cfg_mask[layer_id_in_cfg]


def train_from_mask():
    """Build pruned CycleGAN generators from GA-selected channel masks and fine-tune.

    Loads the best-fitness binary masks found by the GA search, copies the
    surviving weights of the full generators into the pruned architectures,
    then runs a standard CycleGAN training loop (identity + GAN + cycle losses)
    on the pruned generators with the pretrained discriminators, checkpointing
    every 20 epochs.
    """
    # Load best-fitness binary masks produced by the GA search.
    mask_input_A2B = np.loadtxt("/cache/GA/txt/best_fitness_A2B.txt")
    mask_input_B2A = np.loadtxt("/cache/GA/txt/best_fitness_B2A.txt")

    # Expand the flat mask vectors into per-layer channel masks.
    cfg_mask_A2B = compute_layer_mask(mask_input_A2B, mask_chns)
    cfg_mask_B2A = compute_layer_mask(mask_input_B2A, mask_chns)

    # Full (unpruned) generators, pruned generators and the discriminators.
    netG_B2A = Generator(opt.output_nc, opt.input_nc)
    netG_A2B = Generator(opt.output_nc, opt.input_nc)
    model_A2B = Generator_Prune(cfg_mask_A2B)
    model_B2A = Generator_Prune(cfg_mask_B2A)
    netD_A = Discriminator(opt.input_nc)
    netD_B = Discriminator(opt.output_nc)

    # Resume from the pretrained (unpruned) CycleGAN checkpoints.
    netG_A2B.load_state_dict(torch.load('/cache/log/output/netG_A2B.pth'))
    netG_B2A.load_state_dict(torch.load('/cache/log/output/netG_B2A.pth'))

    netD_A.load_state_dict(torch.load('/cache/log/output/netD_A.pth'))
    netD_B.load_state_dict(torch.load('/cache/log/output/netD_B.pth'))

    # Lossess
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    # Transfer the surviving weights from the full generators into the pruned
    # ones (the two directions use the same procedure; see helper above).
    _copy_pruned_generator_weights(netG_A2B, model_A2B, cfg_mask_A2B)
    _copy_pruned_generator_weights(netG_B2A, model_B2A, cfg_mask_B2A)

    # Move the trainable networks to GPU (DataParallel even for a single GPU).
    # NOTE(review): the models are moved to CUDA unconditionally, yet `Tensor`
    # below still honours opt.cuda — with opt.cuda False the CPU inputs would
    # not match the CUDA models. Confirm opt.cuda is always True here.
    netD_A = torch.nn.DataParallel(netD_A).cuda()
    netD_B = torch.nn.DataParallel(netD_B).cuda()
    model_A2B = torch.nn.DataParallel(model_A2B).cuda()
    model_B2A = torch.nn.DataParallel(model_B2A).cuda()

    # Pre-allocated input tensors and the GAN targets (1 = real, 0 = fake).
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
    target_real = Variable(Tensor(opt.batchSize).fill_(1.0),
                           requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize).fill_(0.0),
                           requires_grad=False)
    # History buffers of generated images, used to stabilise the discriminators.
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Loss weights: identity / adversarial / cycle-consistency.
    lamda_loss_ID = 5.0
    lamda_loss_G = 1.0
    lamda_loss_cycle = 10.0
    optimizer_G = torch.optim.Adam(itertools.chain(model_A2B.parameters(),
                                                   model_B2A.parameters()),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))
    # Linear LR decay starting at opt.decay_epoch.
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    # Augmentation: upscale ~12%, random-crop back, flip, normalise to [-1, 1].
    transforms_ = [
        transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
        transforms.RandomCrop(opt.size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]

    dataloader = DataLoader(ImageDataset(opt.dataroot,
                                         transforms_=transforms_,
                                         unaligned=True,
                                         mode='train'),
                            batch_size=opt.batchSize,
                            shuffle=True,
                            drop_last=True)

    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = model_A2B(real_B)
            loss_identity_B = criterion_identity(
                same_B, real_B) * lamda_loss_ID  #initial 5.0
            # G_B2A(A) should equal A if real A is fed
            same_A = model_B2A(real_A)
            loss_identity_A = criterion_identity(
                same_A, real_A) * lamda_loss_ID  #initial 5.0

            # GAN loss
            fake_B = model_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(
                pred_fake, target_real) * lamda_loss_G  #initial 1.0

            fake_A = model_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(
                pred_fake, target_real) * lamda_loss_G  #initial 1.0

            # Cycle loss
            recovered_A = model_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(
                recovered_A, real_A) * lamda_loss_cycle  #initial 10.0

            recovered_B = model_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(
                recovered_B, real_B) * lamda_loss_cycle  #initial 10.0

            # Total loss
            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
            loss_G.backward()

            optimizer_G.step()

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss: sample from the replay buffer, not the latest fake.
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()

            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()

            optimizer_D_B.step()

        print(
            "epoch:%d  Loss G:%4f  LossID_A:%4f LossID_B:%4f  Loss_G_A2B:%4f  Loss_G_B2A:%4f  Loss_Cycle_ABA:%4f  Loss_Cycle_BAB:%4f "
            % (epoch, loss_G, loss_identity_A, loss_identity_B, loss_GAN_A2B,
               loss_GAN_B2A, loss_cycle_ABA, loss_cycle_BAB))

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        if epoch % 20 == 0:

            # Save models checkpoints (unwrap DataParallel via .module).
            torch.save(model_A2B.module.state_dict(),
                       '/cache/log/output/A2B_%d.pth' % (epoch))
            torch.save(model_B2A.module.state_dict(),
                       '/cache/log/output/B2A_%d.pth' % (epoch))
Ejemplo n.º 17
0
def main(args):
    """Train a CycleGAN (netG_A2B/netG_B2A with netD_A/netD_B) on unaligned images.

    Builds a run directory whose name encodes the configuration, optionally
    resumes from ``args.load_iter``, then runs the CycleGAN loop with linearly
    scheduled reconstruction/GAN loss weights, periodically dumping sample
    images, test-set translations and model checkpoints.
    """
    torch.manual_seed(0)
    if args.mb_D:
        # NOTE: the original asserted batch_size > 1 *after* this raise, which
        # was unreachable; the dead assert has been removed.
        raise NotImplementedError('mb_D not implemented')

    if args.img_norm != 'znorm':
        raise NotImplementedError('{} not implemented'.format(args.img_norm))

    assert args.act in ['relu', 'mish'], 'args.act = {}'.format(args.act)

    # Run-directory name encodes the whole hyper-parameter configuration.
    modelarch = 'C_{0}_{1}_{2}_{3}_{4}{5}{6}{7}{8}{9}{10}{11}{12}{13}{14}{15}{16}{17}{18}{19}{20}{21}{22}'.format(
        args.size, args.batch_size, args.lr,  args.n_epochs, args.decay_epoch, # 0, 1, 2, 3, 4
        '_G' if args.G_extra else '',  # 5
        '_D' if args.D_extra else '',  # 6
        '_U' if args.upsample else '',  # 7
        '_S' if args.slow_D else '',  # 8
        # BUGFIX: these tags describe a start->end schedule; the second
        # placeholder previously repeated the start value.
        '_RL{}-{}'.format(args.start_recon_loss_val, args.end_recon_loss_val),  # 9
        '_GL{}-{}'.format(args.start_gan_loss_val, args.end_gan_loss_val),  # 10
        '_prop' if args.keep_prop else '',  # 11
        '_' + args.img_norm,  # 12
        '_WL' if args.wasserstein else '',  # 13
        '_MBD' if args.mb_D else '',  # 14
        '_FM' if args.fm_loss else '',  # 15
        '_BF{}'.format(args.buffer_size) if args.buffer_size != 50 else '',  # 16
        '_N' if args.add_noise else '',  # 17
        '_L{}'.format(args.load_iter) if args.load_iter > 0 else '',  # 18
        '_res{}'.format(args.n_resnet_blocks),  # 19
        '_n{}'.format(args.data_subset) if args.data_subset is not None else '',  # 20
        '_{}'.format(args.optim),  # 21
        '_{}'.format(args.act))  # 22

    samples_path = os.path.join(args.output_dir, modelarch, 'samples')
    safe_mkdirs(samples_path)
    model_path = os.path.join(args.output_dir, modelarch, 'models')
    safe_mkdirs(model_path)
    test_path = os.path.join(args.output_dir, modelarch, 'test')
    safe_mkdirs(test_path)

    # Definition of variables ######
    # Networks
    netG_A2B = Generator(args.input_nc, args.output_nc, img_size=args.size,
                         extra_layer=args.G_extra, upsample=args.upsample,
                         keep_weights_proportional=args.keep_prop,
                         n_residual_blocks=args.n_resnet_blocks,
                         act=args.act)
    netG_B2A = Generator(args.output_nc, args.input_nc, img_size=args.size,
                         extra_layer=args.G_extra, upsample=args.upsample,
                         keep_weights_proportional=args.keep_prop,
                         n_residual_blocks=args.n_resnet_blocks,
                         act=args.act)
    netD_A = Discriminator(args.input_nc, extra_layer=args.D_extra, mb_D=args.mb_D, x_size=args.size)
    netD_B = Discriminator(args.output_nc, extra_layer=args.D_extra, mb_D=args.mb_D, x_size=args.size)

    if args.cuda:
        netG_A2B.cuda()
        netG_B2A.cuda()
        netD_A.cuda()
        netD_B.cuda()

    # Fresh run: random init. Resume: load all four checkpoints at load_iter.
    if args.load_iter == 0:
        netG_A2B.apply(weights_init_normal)
        netG_B2A.apply(weights_init_normal)
        netD_A.apply(weights_init_normal)
        netD_B.apply(weights_init_normal)
    else:
        netG_A2B.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'G_A2B_{}.pth'.format(args.load_iter))))
        netG_B2A.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'G_B2A_{}.pth'.format(args.load_iter))))
        netD_A.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'D_A_{}.pth'.format(args.load_iter))))
        netD_B.load_state_dict(torch.load(os.path.join(args.load_dir, 'models', 'D_B_{}.pth'.format(args.load_iter))))

        netG_A2B.train()
        netG_B2A.train()
        netD_A.train()
        netD_B.train()

    # Lossess
    criterion_GAN = wasserstein_loss if args.wasserstein else torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()
    feat_criterion = torch.nn.HingeEmbeddingLoss()

    # --slow_D trains the discriminators at half the generator rate.
    lr_G = args.lr
    lr_D = args.lr / 2 if args.slow_D else args.lr

    # Optimizers & LR schedulers
    if args.optim == 'adam':
        optim = torch.optim.Adam
    elif args.optim == 'radam':
        optim = RAdam
    elif args.optim == 'ranger':
        optim = Ranger
    elif args.optim == 'rangerlars':
        optim = RangerLars
    else:
        raise NotImplementedError('args.optim = {} not implemented'.format(args.optim))

    optimizer_G = optim(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                        lr=lr_G, betas=(0.5, 0.999))
    # BUGFIX: both discriminators now use lr_D; previously D_A was given lr_G,
    # so --slow_D only slowed down one of the two discriminators.
    optimizer_D_A = optim(netD_A.parameters(), lr=lr_D, betas=(0.5, 0.999))
    optimizer_D_B = optim(netD_B.parameters(), lr=lr_D, betas=(0.5, 0.999))

    # Linear LR decay starting at decay_epoch.
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(args.n_epochs, args.load_iter, args.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(args.n_epochs, args.load_iter, args.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(args.n_epochs, args.load_iter, args.decay_epoch).step)

    # Inputs & targets memory allocation (1 = real, 0 = fake labels).
    Tensor = torch.cuda.FloatTensor if args.cuda else torch.Tensor
    input_A = Tensor(args.batch_size, args.input_nc, args.size, args.size)
    input_B = Tensor(args.batch_size, args.output_nc, args.size, args.size)
    target_real = Variable(Tensor(args.batch_size).fill_(1.0), requires_grad=False)
    target_fake = Variable(Tensor(args.batch_size).fill_(0.0), requires_grad=False)

    # History buffers of generated images, used to stabilise the discriminators.
    fake_A_buffer = ReplayBuffer(args.buffer_size)
    fake_B_buffer = ReplayBuffer(args.buffer_size)

    # Transforms and dataloader for training set
    transforms_ = []
    if args.resize_crop:
        transforms_ += [transforms.Resize(int(args.size*1.12), Image.BICUBIC),
                        transforms.RandomCrop(args.size)]
    else:
        transforms_ += [transforms.Resize(args.size, Image.BICUBIC)]

    if args.horizontal_flip:
        transforms_ += [transforms.RandomHorizontalFlip()]

    transforms_ += [transforms.ToTensor()]

    if args.add_noise:
        transforms_ += [transforms.Lambda(lambda x: x + torch.randn_like(x))]

    transforms_norm = []
    if args.img_norm == 'znorm':
        transforms_norm += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    elif 'scale01' in args.img_norm:
        transforms_norm += [transforms.Lambda(lambda x: x.mul(1/255))]  # TODO this might not preserve the dimensions. is .mul per element?
        if 'flip' in args.img_norm:
            transforms_norm += [transforms.Lambda(lambda x: (x - 1).abs())]  # TODO this might not preserve the dimensions. is .mul per element?
    else:
        raise ValueError('wrong --img_norm. only znorm|scale01|scale01flip')

    transforms_ += transforms_norm

    dataloader = DataLoader(ImageDataset(args.dataroot, transforms_=transforms_, unaligned=True, n=args.data_subset),
                            batch_size=args.batch_size, shuffle=True, num_workers=args.n_cpu)

    # Transforms and dataloader for test set (no augmentation, same norm).
    transforms_test_ = [transforms.Resize(args.size, Image.BICUBIC),
                        transforms.ToTensor()]
    transforms_test_ += transforms_norm

    dataloader_test = DataLoader(ImageDataset(args.dataroot, transforms_=transforms_test_, mode='test'),
                                 batch_size=args.batch_size, shuffle=False, num_workers=args.n_cpu)
    # Training ######
    if args.load_iter == 0 and args.load_epoch != 0:
        print('****** NOTE: args.load_iter == 0 and args.load_epoch != 0 ******')

    # Global step counter (renamed from `iter`, which shadowed the builtin).
    global_step = args.load_iter
    prev_time = time.time()
    n_test = 10e10 if args.n_test is None else args.n_test
    n_sample = 10e10 if args.n_sample is None else args.n_sample

    # Slopes for the linear loss-weight schedules (value per epoch).
    rl_delta_x = args.n_epochs - args.recon_loss_epoch
    rl_delta_y = args.end_recon_loss_val - args.start_recon_loss_val

    gan_delta_x = args.n_epochs - args.gan_loss_epoch
    gan_delta_y = args.end_gan_loss_val - args.start_gan_loss_val

    for epoch in range(args.load_epoch, args.n_epochs):

        rl_effective_epoch = max(epoch - args.recon_loss_epoch, 0)
        recon_loss_rate = args.start_recon_loss_val + rl_effective_epoch * (rl_delta_y / rl_delta_x)

        gan_effective_epoch = max(epoch - args.gan_loss_epoch, 0)
        gan_loss_rate = args.start_gan_loss_val + gan_effective_epoch * (gan_delta_y / gan_delta_x)

        id_loss_rate = 5.0

        for i, batch in enumerate(dataloader):
            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            # Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss
            # G_A2B(B) should equal B if real B is fed
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(same_B, real_B)
            # G_B2A(A) should equal A if real A is fed
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(same_A, real_A)

            # GAN loss
            fake_B = netG_A2B(real_A)
            pred_fake, _ = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)

            fake_A = netG_B2A(real_B)
            pred_fake, _ = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

            # Cycle loss
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A)

            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B)

            # Total loss: each term group has its own (scheduled) weight.
            loss_G = (loss_identity_A + loss_identity_B) * id_loss_rate
            loss_G += (loss_GAN_A2B + loss_GAN_B2A) * gan_loss_rate
            loss_G += (loss_cycle_ABA + loss_cycle_BAB) * recon_loss_rate

            loss_G.backward()

            optimizer_G.step()

            # Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real, _ = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss: sample from the replay buffer, not the latest fake.
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake, _ = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            loss_D_A = (loss_D_real + loss_D_fake) * 0.5

            if args.fm_loss:
                # Feature-matching loss on the discriminator's intermediate maps.
                pred_real, feats_real = netD_A(real_A)
                pred_fake, feats_fake = netD_A(fake_A.detach())

                fm_loss_A = get_fm_loss(feats_real, feats_fake, feat_criterion, args.cuda)

                loss_D_A = loss_D_A * 0.1 + fm_loss_A * 0.9

            loss_D_A.backward()

            optimizer_D_A.step()

            # Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real, _ = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake, _ = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            loss_D_B = (loss_D_real + loss_D_fake)*0.5

            if args.fm_loss:
                pred_real, feats_real = netD_B(real_B)
                pred_fake, feats_fake = netD_B(fake_B.detach())

                fm_loss_B = get_fm_loss(feats_real, feats_fake, feat_criterion, args.cuda)

                loss_D_B = loss_D_B * 0.1 + fm_loss_B * 0.9

            loss_D_B.backward()

            optimizer_D_B.step()

            if global_step % args.log_interval == 0:

                print('---------------------')
                print('GAN loss:', as_np(loss_GAN_A2B), as_np(loss_GAN_B2A))
                print('Identity loss:', as_np(loss_identity_A), as_np(loss_identity_B))
                print('Cycle loss:', as_np(loss_cycle_ABA), as_np(loss_cycle_BAB))
                print('D loss:', as_np(loss_D_A), as_np(loss_D_B))
                if args.fm_loss:
                    print('fm loss:', as_np(fm_loss_A), as_np(fm_loss_B))
                print('recon loss rate:', recon_loss_rate)
                print('time:', time.time() - prev_time)
                prev_time = time.time()

            if global_step % args.plot_interval == 0:
                pass

            if global_step % args.image_save_interval == 0:
                # NOTE(review): true division yields float directory names
                # (e.g. "1.0"); kept as-is for compatibility with old runs.
                samples_path_ = os.path.join(samples_path, str(global_step / args.image_save_interval))
                safe_mkdirs(samples_path_)

                # New savedir
                test_pth_AB = os.path.join(test_path, str(global_step / args.image_save_interval), 'AB')
                test_pth_BA = os.path.join(test_path, str(global_step / args.image_save_interval), 'BA')

                safe_mkdirs(test_pth_AB)
                safe_mkdirs(test_pth_BA)

                for j, batch_ in enumerate(dataloader_test):

                    real_A_test = Variable(input_A.copy_(batch_['A']))
                    real_B_test = Variable(input_B.copy_(batch_['B']))

                    fake_AB_test = netG_A2B(real_A_test)
                    fake_BA_test = netG_B2A(real_B_test)

                    if j < n_sample:
                        # Also save the cycle reconstructions for inspection.
                        recovered_ABA_test = netG_B2A(fake_AB_test)
                        recovered_BAB_test = netG_A2B(fake_BA_test)

                        fn = os.path.join(samples_path_, str(j))
                        imageio.imwrite(fn + '.A.jpg', tensor2image(real_A_test[0], args.img_norm))
                        imageio.imwrite(fn + '.B.jpg', tensor2image(real_B_test[0], args.img_norm))
                        imageio.imwrite(fn + '.BA.jpg', tensor2image(fake_BA_test[0], args.img_norm))
                        imageio.imwrite(fn + '.AB.jpg', tensor2image(fake_AB_test[0], args.img_norm))
                        imageio.imwrite(fn + '.ABA.jpg', tensor2image(recovered_ABA_test[0], args.img_norm))
                        imageio.imwrite(fn + '.BAB.jpg', tensor2image(recovered_BAB_test[0], args.img_norm))

                    if j < n_test:
                        fn_A = os.path.basename(batch_['img_A'][0])
                        imageio.imwrite(os.path.join(test_pth_AB, fn_A), tensor2image(fake_AB_test[0], args.img_norm))

                        fn_B = os.path.basename(batch_['img_B'][0])
                        imageio.imwrite(os.path.join(test_pth_BA, fn_B), tensor2image(fake_BA_test[0], args.img_norm))

            if global_step % args.model_save_interval == 0:
                # Save models checkpoints
                torch.save(netG_A2B.state_dict(), os.path.join(model_path, 'G_A2B_{}.pth'.format(global_step)))
                torch.save(netG_B2A.state_dict(), os.path.join(model_path, 'G_B2A_{}.pth'.format(global_step)))
                torch.save(netD_A.state_dict(), os.path.join(model_path, 'D_A_{}.pth'.format(global_step)))
                torch.save(netD_B.state_dict(), os.path.join(model_path, 'D_B_{}.pth'.format(global_step)))

            global_step += 1

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()
Ejemplo n.º 18
0
# Ground-truth labels for the GAN losses: 1.0 for real, 0.0 for fake images.
target_real = Variable(Tensor(args.batchSize).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(args.batchSize).fill_(0.0), requires_grad=False)

# History buffers of generated images, used to stabilise discriminator training.
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Dataset loader
# Train-time augmentation: upscale by ~12%, random-crop back to `size`,
# random horizontal flip, then normalise pixels from [0, 1] to [-1, 1].
transforms_ = [
    transforms.Resize(int(args.size * 1.12), Image.BICUBIC),
    transforms.RandomCrop(args.size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]
# drop_last avoids a final short batch that would not match the pre-allocated
# input tensors above.
dataloader = DataLoader(ImageDataset(dataset_dir,
                                     transforms_=transforms_,
                                     unaligned=True),
                        batch_size=args.batchSize,
                        shuffle=True,
                        num_workers=args.n_cpu,
                        drop_last=True)

# Test images get no augmentation (and no resize), only normalisation.
test_transforms_ = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]

test_dataloader = DataLoader(ImageDataset(dataset_dir,
                                          transforms_=test_transforms_,
                                          mode='test'),
                             batch_size=1,
Ejemplo n.º 19
0
def train(use_cuda=False):
    """Train ApproxNet on the ./pics image dataset and report a test F1 per epoch.

    Args:
        use_cuda: when True, run the model and all batches on the GPU.

    Side effects: reads images/labels/CSV from ./pics, prints per-step loss
    and a per-epoch F1 score to stdout. Nothing is returned or saved.
    """
    data = "./pics/data"
    labels = "./pics/labels"
    csv_path = "./pics/info.csv"

    batch_size = 100
    num_of_epochs = 20

    # Index split: samples [0, 6000) train, [6000, 10000) test.
    train_dataset = ImageDataset(data,
                                 labels,
                                 csv_path,
                                 lower_bound=0,
                                 upper_bound=6000)
    test_dataset = ImageDataset(data,
                                labels,
                                csv_path,
                                lower_bound=6000,
                                upper_bound=10000)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size,
                                              shuffle=True)

    lr = 0.0002
    criterion = nn.MSELoss()
    model = ApproxNet()

    if use_cuda:
        model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_of_epochs):
        # Training
        model.train()
        for i, (_, _, X, _, y_true) in enumerate(train_loader):
            if use_cuda:
                # Keep the batch on the same device as the model; the
                # original only moved the model, crashing on GPU runs.
                X = X.cuda()
                y_true = y_true.cuda()
            y_predict = model(X)
            optimizer.zero_grad()
            loss = criterion(y_predict, y_true)
            loss.backward()
            optimizer.step()
            # loss.item() replaces the 0.3-era `loss.data[0]`, which raises
            # IndexError on PyTorch >= 0.5 (loss is a 0-dim tensor).
            print("Epoch: [{}/{}], Stage: [{}/{}], Loss: {}".format(
                epoch + 1, num_of_epochs, i + 1, len(train_loader),
                loss.item()))
        # Testing — no gradients needed; eval() disables dropout/batchnorm
        # training behavior if ApproxNet uses them.
        model.eval()
        correct = np.array([])
        predicted = np.array([])
        with torch.no_grad():
            # NOTE(review): train batches unpack as 5-tuples but test
            # batches as 2-tuples — verify ImageDataset's return shape.
            for (X, y_true) in test_loader:
                if use_cuda:
                    X = X.cuda()
                y_out = model(X)
                _, y_pred = torch.max(y_out, 1)
                correct = np.concatenate((correct, y_true.numpy()))
                predicted = np.concatenate((predicted, y_pred.cpu().numpy()))
        print(correct.shape, predicted.shape)
        # Label fixed: f1_score computes an F1 score, not an accuracy.
        # Assumes binary labels — multiclass would need an `average=` arg.
        print("Epoch: {}, F1 score: {}".format(epoch + 1,
                                               f1_score(correct, predicted)))
Ejemplo n.º 20
0
def main(
    fast=False,
    batch_size=None,
    **kwargs,
):
    """Run four normal-estimation graphs over a house-tour frame sequence.

    For each batch of frames, the RGB input is resized on a zoom schedule,
    normals are predicted via the direct path (rgb -> normal) and the
    two-step path (rgb -> principal_curvature -> normal), the per-frame
    consistency energy between the two paths is recorded, and all images
    are written under mount/taskonomy_house_tour/. Finally the collected
    losses/zooms are pickled and an external video script is invoked.

    Args:
        fast: use a small default batch size (4 instead of 32).
        batch_size: explicit batch size; overrides the fast/standard default.
        **kwargs: forwarded to get_energy_loss().
    """

    # CONFIG
    batch_size = batch_size or (4 if fast else 32)
    energy_loss = get_energy_loss(config="consistency_two_path",
                                  mode="standard",
                                  **kwargs)

    # LOGGING
    logger = VisdomLogger("train", env=JOB)

    # DATA LOADING
    # Frames are named image<N>.png; sort numerically by N (chars 5:-4 of
    # the basename), not lexicographically.
    video_dataset = ImageDataset(
        files=sorted(
            glob.glob(f"mount/taskonomy_house_tour/original/image*.png"),
            key=lambda x: int(os.path.basename(x)[5:-4])),
        return_tuple=True,
        resize=720,
    )
    video = RealityTask("video",
                        video_dataset, [
                            tasks.rgb,
                        ],
                        batch_size=batch_size,
                        shuffle=False)

    # GRAPHS
    # Four variants of the task graph; the first two keep their compiled
    # weights, the last two additionally load saved checkpoints.
    graph_baseline = TaskGraph(tasks=energy_loss.tasks + [video],
                               finetuned=False)
    graph_baseline.compile(torch.optim.Adam,
                           lr=3e-5,
                           weight_decay=2e-6,
                           amsgrad=True)

    graph_finetuned = TaskGraph(tasks=energy_loss.tasks + [video],
                                finetuned=True)
    graph_finetuned.compile(torch.optim.Adam,
                            lr=3e-5,
                            weight_decay=2e-6,
                            amsgrad=True)

    graph_conservative = TaskGraph(tasks=energy_loss.tasks + [video],
                                   finetuned=True)
    graph_conservative.compile(torch.optim.Adam,
                               lr=3e-5,
                               weight_decay=2e-6,
                               amsgrad=True)
    graph_conservative.load_weights(
        f"{MODELS_DIR}/conservative/conservative.pth")

    graph_ood_conservative = TaskGraph(tasks=energy_loss.tasks + [video],
                                       finetuned=True)
    graph_ood_conservative.compile(torch.optim.Adam,
                                   lr=3e-5,
                                   weight_decay=2e-6,
                                   amsgrad=True)
    graph_ood_conservative.load_weights(
        f"{SHARED_DIR}/results_2F_grounded_1percent_gt_twopath_512_256_crop_7/graph_grounded_1percent_gt_twopath.pth"
    )

    graphs = {
        "baseline": graph_baseline,
        "finetuned": graph_finetuned,
        "conservative": graph_conservative,
        "ood_conservative": graph_ood_conservative,
    }

    inv_transform = transforms.ToPILImage()
    # Per-graph records of consistency energies and the zoom used per frame.
    data = {key: {"losses": [], "zooms": []} for key in graphs}
    size = 256
    for batch in range(0, 700):

        # Stop once the whole frame list has been consumed.
        if batch * batch_size > len(video_dataset.files): break

        # Zoom schedule over the fraction of the video processed:
        # shrink 256 -> 128 over the first 30%, grow back to 256 by 50%,
        # then grow to 720 by the end.
        frac = (batch * batch_size * 1.0) / len(video_dataset.files)
        if frac < 0.3:
            size = int(256.0 - 128 * frac / 0.3)
        elif frac < 0.5:
            size = int(128.0 + 128 * (frac - 0.3) / 0.2)
        else:
            size = int(256.0 + (720 - 256) * (frac - 0.5) / 0.5)
        print(size)
        # video.reload()
        # Snap to a multiple of 32 — presumably to match the networks'
        # downsampling stride; confirm against the model definitions.
        size = (size // 32) * 32
        print(size)
        video.step()
        video.task_data[tasks.rgb] = resize(
            video.task_data[tasks.rgb].to(DEVICE), size).data
        print(video.task_data[tasks.rgb].shape)

        with torch.no_grad():

            # Save the zoomed ("distorted") input frames.
            for i, img in enumerate(video.task_data[tasks.rgb]):
                inv_transform(img.clamp(min=0, max=1.0).data.cpu()).save(
                    f"mount/taskonomy_house_tour/distorted/image{batch*batch_size + i}.png"
                )

            for name, graph in graphs.items():
                # Direct path and curvature-mediated path to normals.
                normals = graph.sample_path([tasks.rgb, tasks.normal],
                                            reality=video)
                normals2 = graph.sample_path(
                    [tasks.rgb, tasks.principal_curvature, tasks.normal],
                    reality=video)

                for i, img in enumerate(normals):
                    # Consistency energy between the two paths, per frame.
                    energy, _ = tasks.normal.norm(normals[i:(i + 1)],
                                                  normals2[i:(i + 1)])
                    data[name]["losses"] += [energy.data.cpu().numpy().mean()]
                    data[name]["zooms"] += [size]
                    inv_transform(img.clamp(min=0, max=1.0).data.cpu()).save(
                        f"mount/taskonomy_house_tour/normals_{name}/image{batch*batch_size + i}.png"
                    )

                for i, img in enumerate(normals2):
                    inv_transform(img.clamp(min=0, max=1.0).data.cpu()).save(
                        f"mount/taskonomy_house_tour/path2_{name}/image{batch*batch_size + i}.png"
                    )

    # Persist the collected curves, then stitch the frames into videos.
    pickle.dump(data, open(f"mount/taskonomy_house_tour/data.pkl", 'wb'))
    os.system("bash ~/scaling/scripts/create_vids.sh")
Ejemplo n.º 21
0
             'normal',
             0.02,
             gpu_id=device,
             use_ce=True,
             unet=False),
    # define_G(channels, num_classes, 64, 'batch', False, 'normal', 0.02, gpu_id=device, use_ce=True, unet=False),
    # define_G(channels, num_classes, 64, 'batch', False, 'normal', 0.02, gpu_id=device, use_ce=True, unet=False),
]

# Load the pretrained generator weights for each model / result-folder pair.
for model, path in zip(models, result_folders):
    model.load_state_dict(torch.load(path + 'generator.pt'))

# NOTE(review): batch_size is set to 6 here, but the DataLoader below uses
# batch_size=1 — confirm which is intended.
batch_size = 6

val_data_loader = DataLoader(ImageDataset(dataset_dir,
                                          mode='val',
                                          img_size=img_size),
                             batch_size=1,
                             shuffle=True,
                             num_workers=1)

#criterionL1 = nn.L1Loss().to(device)
criterionMSE = nn.MSELoss().to(device)
#criterionCE = nn.CrossEntropyLoss().to(device)
# Per-model accumulators, one slot per loaded generator; presumably filled
# by the evaluation loop below (outside this view).
avg_psnr = [0] * len(models)
number_of_indents = [0] * len(models)
success_rate = [0] * len(models)
'''
for i, batch in enumerate(val_data_loader):
    #if i > 10: break
    input, target = batch[0].to(device), batch[1].to(device)
Ejemplo n.º 22
0
def main(args):
    """Train a two-network image-completion model in three phases.

    Phase 1 trains the completion network (CN) alone with a reconstruction
    loss on masked inputs. Phase 2 trains the context discriminator (CD) —
    which sees a (local crop, full image) pair — against the frozen CN.
    Phase 3 trains both adversarially. Sample grids and checkpoints are
    written to args.result_dir/phase_{1,2,3}.

    Requires at least one CUDA device; with args.num_gpus > 1 the CN and CD
    are placed on separate GPUs.
    """

    # ================================================
    # Preparation
    # ================================================
    args.data_dir = os.path.expanduser(args.data_dir)
    args.result_dir = os.path.expanduser(args.result_dir)

    if torch.cuda.is_available() == False:
        raise Exception('At least one gpu must be available.')
    if args.num_gpus == 1:
        # train both models on a single gpu
        gpu_cn = torch.device('cuda:0')
        gpu_cd = gpu_cn
    else:
        # train the two models on two different gpus
        gpu_cn = torch.device('cuda:0')
        gpu_cd = torch.device('cuda:1')

    # create result directory (if necessary)
    if os.path.exists(args.result_dir) == False:
        os.makedirs(args.result_dir)
    for s in ['phase_1', 'phase_2', 'phase_3']:
        if os.path.exists(os.path.join(args.result_dir, s)) == False:
            os.makedirs(os.path.join(args.result_dir, s))

    # dataset
    trnsfm = transforms.Compose([
        transforms.Resize(args.cn_input_size),
        transforms.RandomCrop((args.cn_input_size, args.cn_input_size)),
        transforms.ToTensor(),
    ])
    print('loading dataset... (it may take a few minutes)')
    train_dset = ImageDataset(os.path.join(args.data_dir, 'train'), trnsfm)
    test_dset = ImageDataset(os.path.join(args.data_dir, 'test'), trnsfm)
    train_loader = DataLoader(train_dset, batch_size=args.bsize, shuffle=True)

    # compute the mean pixel value of train dataset
    # (mpv is later used to fill the masked holes in the input)
    mean_pv = 0.
    imgpaths = train_dset.imgpaths[:min(args.max_mpv_samples, len(train_dset))]
    if args.comp_mpv:
        pbar = tqdm(total=len(imgpaths), desc='computing the mean pixel value')
        for imgpath in imgpaths:
            img = Image.open(imgpath)
            x = np.array(img, dtype=np.float32) / 255.
            mean_pv += x.mean()
            pbar.update()
        mean_pv /= len(imgpaths)
        pbar.close()
    mpv = torch.tensor(mean_pv).to(gpu_cn)

    # save training config
    args_dict = vars(args)
    args_dict['mean_pv'] = mean_pv
    with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f:
        json.dump(args_dict, f)

    # ================================================
    # Training Phase 1
    # ================================================
    # model & optimizer
    model_cn = CompletionNetwork()
    model_cn = model_cn.to(gpu_cn)
    if args.optimizer == 'adadelta':
        opt_cn = Adadelta(model_cn.parameters())
    else:
        opt_cn = Adam(model_cn.parameters())

    # training
    pbar = tqdm(total=args.steps_1)
    while pbar.n < args.steps_1:
        for x in train_loader:

            opt_cn.zero_grad()

            # generate hole area
            hole_area = gen_hole_area(
                size=(args.ld_input_size, args.ld_input_size),
                mask_size=(x.shape[3], x.shape[2]),
            )

            # create mask
            msk = gen_input_mask(
                shape=x.shape,
                hole_size=(
                    (args.hole_min_w, args.hole_max_w),
                    (args.hole_min_h, args.hole_max_h),
                ),
                hole_area=hole_area,
                max_holes=args.max_holes,
            )

            # merge x, mask, and mpv
            # (the hole region, where msk == 1, is filled with mpv)
            msg = 'phase 1 |'
            x = x.to(gpu_cn)
            msk = msk.to(gpu_cn)
            input = x - x * msk + mpv * msk
            output = model_cn(input)

            # optimize
            loss = completion_network_loss(x, output, msk)
            loss.backward()
            opt_cn.step()

            msg += ' train loss: %.5f' % loss.cpu()
            pbar.set_description(msg)
            pbar.update()

            # test
            if pbar.n % args.snaperiod_1 == 0:
                with torch.no_grad():

                    # NOTE(review): msk is reused from the last training
                    # batch — assumes the sampled test batch has the same
                    # shape; confirm when args.bsize varies.
                    x = sample_random_batch(test_dset, batch_size=args.bsize)
                    x = x.to(gpu_cn)
                    input = x - x * msk + mpv * msk
                    output = model_cn(input)
                    completed = poisson_blend(input, output, msk)
                    imgs = torch.cat((input.cpu(), completed.cpu()), dim=0)
                    save_image(imgs,
                               os.path.join(args.result_dir, 'phase_1',
                                            'step%d.png' % pbar.n),
                               nrow=len(x))
                    torch.save(
                        model_cn.state_dict(),
                        os.path.join(args.result_dir, 'phase_1',
                                     'model_cn_step%d' % pbar.n))

            if pbar.n >= args.steps_1:
                break
    pbar.close()

    # ================================================
    # Training Phase 2
    # ================================================
    # model, optimizer & criterion
    model_cd = ContextDiscriminator(
        local_input_shape=(3, args.ld_input_size, args.ld_input_size),
        global_input_shape=(3, args.cn_input_size, args.cn_input_size),
    )
    model_cd = model_cd.to(gpu_cd)
    if args.optimizer == 'adadelta':
        opt_cd = Adadelta(model_cd.parameters())
    else:
        opt_cd = Adam(model_cd.parameters())
    criterion_cd = BCELoss()

    # training (discriminator only; the CN is frozen via .detach())
    pbar = tqdm(total=args.steps_2)
    while pbar.n < args.steps_2:
        for x in train_loader:

            x = x.to(gpu_cn)
            opt_cd.zero_grad()

            # ================================================
            # fake
            # ================================================
            hole_area = gen_hole_area(
                size=(args.ld_input_size, args.ld_input_size),
                mask_size=(x.shape[3], x.shape[2]),
            )

            # create mask
            msk = gen_input_mask(
                shape=x.shape,
                hole_size=(
                    (args.hole_min_w, args.hole_max_w),
                    (args.hole_min_h, args.hole_max_h),
                ),
                hole_area=hole_area,
                max_holes=args.max_holes,
            )

            fake = torch.zeros((len(x), 1)).to(gpu_cd)
            msk = msk.to(gpu_cn)
            input_cn = x - x * msk + mpv * msk
            output_cn = model_cn(input_cn)
            # detach: no gradients flow back into the completion network
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area)
            input_fake = (input_ld_fake.to(gpu_cd), input_gd_fake.to(gpu_cd))
            output_fake = model_cd(input_fake)
            loss_fake = criterion_cd(output_fake, fake)

            # ================================================
            # real
            # ================================================
            hole_area = gen_hole_area(
                size=(args.ld_input_size, args.ld_input_size),
                mask_size=(x.shape[3], x.shape[2]),
            )

            real = torch.ones((len(x), 1)).to(gpu_cd)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area)
            input_real = (input_ld_real.to(gpu_cd), input_gd_real.to(gpu_cd))
            output_real = model_cd(input_real)
            loss_real = criterion_cd(output_real, real)

            # ================================================
            # optimize
            # ================================================
            loss = (loss_fake + loss_real) / 2.
            loss.backward()
            opt_cd.step()

            msg = 'phase 2 |'
            msg += ' train loss: %.5f' % loss.cpu()
            pbar.set_description(msg)
            pbar.update()

            # test
            if pbar.n % args.snaperiod_2 == 0:
                with torch.no_grad():

                    # NOTE(review): msk reused from the last training batch,
                    # as in phase 1 — same shape assumption applies.
                    x = sample_random_batch(test_dset, batch_size=args.bsize)
                    x = x.to(gpu_cn)
                    input = x - x * msk + mpv * msk
                    output = model_cn(input)
                    completed = poisson_blend(input, output, msk)
                    imgs = torch.cat((input.cpu(), completed.cpu()), dim=0)
                    save_image(imgs,
                               os.path.join(args.result_dir, 'phase_2',
                                            'step%d.png' % pbar.n),
                               nrow=len(x))
                    torch.save(
                        model_cd.state_dict(),
                        os.path.join(args.result_dir, 'phase_2',
                                     'model_cd_step%d' % pbar.n))

            if pbar.n >= args.steps_2:
                break
    pbar.close()

    # ================================================
    # Training Phase 3
    # ================================================
    # training (CN and CD trained jointly; alpha weights the adversarial term)
    alpha = torch.tensor(args.alpha).to(gpu_cd)
    pbar = tqdm(total=args.steps_3)
    while pbar.n < args.steps_3:
        for x in train_loader:

            x = x.to(gpu_cn)

            # ================================================
            # train model_cd
            # ================================================
            opt_cd.zero_grad()

            # fake
            hole_area = gen_hole_area(
                size=(args.ld_input_size, args.ld_input_size),
                mask_size=(x.shape[3], x.shape[2]),
            )

            # create mask
            msk = gen_input_mask(
                shape=x.shape,
                hole_size=(
                    (args.hole_min_w, args.hole_max_w),
                    (args.hole_min_h, args.hole_max_h),
                ),
                hole_area=hole_area,
                max_holes=args.max_holes,
            )

            fake = torch.zeros((len(x), 1)).to(gpu_cd)
            msk = msk.to(gpu_cn)
            input_cn = x - x * msk + mpv * msk
            output_cn = model_cn(input_cn)
            # detached for the CD update; the non-detached output_cn is
            # reused below for the CN update
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area)
            input_fake = (input_ld_fake.to(gpu_cd), input_gd_fake.to(gpu_cd))
            output_fake = model_cd(input_fake)
            loss_cd_1 = criterion_cd(output_fake, fake)

            # real
            hole_area = gen_hole_area(
                size=(args.ld_input_size, args.ld_input_size),
                mask_size=(x.shape[3], x.shape[2]),
            )

            real = torch.ones((len(x), 1)).to(gpu_cd)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area)
            input_real = (input_ld_real.to(gpu_cd), input_gd_real.to(gpu_cd))
            output_real = model_cd(input_real)
            loss_cd_2 = criterion_cd(output_real, real)

            # optimize
            loss_cd = (loss_cd_1 + loss_cd_2) * alpha / 2.
            loss_cd.backward()
            opt_cd.step()

            # ================================================
            # train model_cn
            # ================================================
            opt_cn.zero_grad()

            # reconstruction loss plus adversarial loss (CN tries to make
            # the CD label its output as real)
            loss_cn_1 = completion_network_loss(x, output_cn, msk).to(gpu_cd)
            input_gd_fake = output_cn
            input_ld_fake = crop(input_gd_fake, hole_area)
            input_fake = (input_ld_fake.to(gpu_cd), input_gd_fake.to(gpu_cd))
            output_fake = model_cd(input_fake)
            loss_cn_2 = criterion_cd(output_fake, real)

            # optimize
            loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2.
            loss_cn.backward()
            opt_cn.step()

            msg = 'phase 3 |'
            msg += ' train loss (cd): %.5f' % loss_cd.cpu()
            msg += ' train loss (cn): %.5f' % loss_cn.cpu()
            pbar.set_description(msg)
            pbar.update()

            # test
            if pbar.n % args.snaperiod_3 == 0:
                with torch.no_grad():

                    x = sample_random_batch(test_dset, batch_size=args.bsize)
                    x = x.to(gpu_cn)
                    input = x - x * msk + mpv * msk
                    output = model_cn(input)
                    completed = poisson_blend(input, output, msk)
                    imgs = torch.cat((input.cpu(), completed.cpu()), dim=0)
                    save_image(imgs,
                               os.path.join(args.result_dir, 'phase_3',
                                            'step%d.png' % pbar.n),
                               nrow=len(x))
                    torch.save(
                        model_cn.state_dict(),
                        os.path.join(args.result_dir, 'phase_3',
                                     'model_cn_step%d' % pbar.n))
                    torch.save(
                        model_cd.state_dict(),
                        os.path.join(args.result_dir, 'phase_3',
                                     'model_cd_step%d' % pbar.n))

            if pbar.n >= args.steps_3:
                break
    pbar.close()
Ejemplo n.º 23
0
        torch.load(folder_model + model_name, map_location=device), )
    G_AB.eval()
    return G_AB


in_shape = (512, 512)
# Resize to 512x512, then map RGB from [0, 1] to [-1, 1].
transforms_used = transforms.Compose([
    transforms.Resize(in_shape, Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (.5, .5, .5))
])

# NOTE(review): transforms_used is built above but not passed to the
# dataset (transforms_A/B are None) — confirm this is intentional.
data_set = ImageDataset("../data/%s" % dataset_name,
                        transforms_A=None,
                        transforms_B=None,
                        mode="train",
                        unaligned=False,
                        HPC_run=0,
                        Convert_B2_mask=0,
                        channels=output_channels)

cuda = False  # this will definitely work on the cpu if it is false
if cuda:
    cuda = True if torch.cuda.is_available() else False
device = torch.device('cuda' if cuda else 'cpu')

G_AB = get_GAN_AB_model(path2model, model_name, device)  # load the model

# Random sample index; note this is a 1-element tensor, not a plain int —
# assumes ImageDataset.__getitem__ accepts a tensor index (verify).
img_id = torch.randint(len(data_set),
                       (1, ))  # getting some image, here index 100
PIL_A_img = data_set[img_id]['A']
PIL_B_img = data_set[img_id]['B']
Ejemplo n.º 24
0
def main():
    """Train a CycleGAN: two generators (A<->B) and two discriminators.

    Builds the networks, losses, optimizers, and LR schedulers from the
    command-line options, then runs the adversarial + cycle-consistency
    (+ optional identity) training loop. Model weights, optimizer state,
    and per-epoch loss curves are checkpointed under model/.
    """
    # Get training options
    opt = get_opt()

    # Define the networks
    # netG_A: used to transfer image from domain A to domain B
    # netG_B: used to transfer image from domain B to domain A
    netG_A = networks.Generator(opt.input_nc, opt.output_nc, opt.ngf,
                                opt.n_res, opt.dropout)
    netG_B = networks.Generator(opt.output_nc, opt.input_nc, opt.ngf,
                                opt.n_res, opt.dropout)
    if opt.u_net:
        # optionally replace the ResNet generators with U-nets
        netG_A = networks.U_net(opt.input_nc, opt.output_nc, opt.ngf)
        netG_B = networks.U_net(opt.output_nc, opt.input_nc, opt.ngf)

    # netD_A: used to test whether an image is from domain B
    # netD_B: used to test whether an image is from domain A
    netD_A = networks.Discriminator(opt.input_nc, opt.ndf)
    netD_B = networks.Discriminator(opt.output_nc, opt.ndf)

    # Initialize the networks
    if opt.cuda:
        netG_A.cuda()
        netG_B.cuda()
        netD_A.cuda()
        netD_B.cuda()
    utils.init_weight(netG_A)
    utils.init_weight(netG_B)
    utils.init_weight(netD_A)
    utils.init_weight(netD_B)

    if opt.pretrained:
        netG_A.load_state_dict(torch.load('pretrained/netG_A.pth'))
        netG_B.load_state_dict(torch.load('pretrained/netG_B.pth'))
        netD_A.load_state_dict(torch.load('pretrained/netD_A.pth'))
        netD_B.load_state_dict(torch.load('pretrained/netD_B.pth'))

    # Define the loss functions
    criterion_GAN = utils.GANLoss()
    if opt.cuda:
        criterion_GAN.cuda()

    criterion_cycle = torch.nn.L1Loss()
    # Alternatively, can try MSE cycle consistency loss
    #criterion_cycle = torch.nn.MSELoss()
    criterion_identity = torch.nn.L1Loss()

    # Define the optimizers
    # (both generators share one optimizer, as in the original CycleGAN)
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A.parameters(),
                                                   netG_B.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.beta1, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.beta1, 0.999))

    # Create learning rate schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=utils.Lambda_rule(opt.epoch, opt.n_epochs,
                                    opt.n_epochs_decay).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=utils.Lambda_rule(opt.epoch, opt.n_epochs,
                                    opt.n_epochs_decay).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=utils.Lambda_rule(opt.epoch, opt.n_epochs,
                                    opt.n_epochs_decay).step)

    # Preallocated input buffers, filled in-place per batch via copy_()
    # (legacy pre-0.4 PyTorch pattern).
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batch_size, opt.input_nc, opt.sizeh, opt.sizew)
    input_B = Tensor(opt.batch_size, opt.output_nc, opt.sizeh, opt.sizew)

    # Define two image pools to store generated images
    fake_A_pool = utils.ImagePool()
    fake_B_pool = utils.ImagePool()

    # Define the transform, and load the data
    transform = transforms.Compose([
        transforms.Resize((opt.sizeh, opt.sizew)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, ))
    ])
    dataloader = DataLoader(ImageDataset(opt.rootdir,
                                         transform=transform,
                                         mode='train'),
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.n_cpu)

    # numpy arrays to store the loss of epoch
    loss_G_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)
    loss_D_A_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)
    loss_D_B_array = np.zeros(opt.n_epochs + opt.n_epochs_decay)

    # Training
    for epoch in range(opt.epoch, opt.n_epochs + opt.n_epochs_decay):
        start = time.strftime("%H:%M:%S")
        print("current epoch :", epoch, " start time :", start)
        # Empty list to store the loss of each mini-batch
        loss_G_list = []
        loss_D_A_list = []
        loss_D_B_list = []

        for i, batch in enumerate(dataloader):
            if i % 50 == 1:
                print("current step: ", i)
                current = time.strftime("%H:%M:%S")
                print("current time :", current)
                print("last loss G:", loss_G_list[-1], "last loss D_A",
                      loss_D_A_list[-1], "last loss D_B", loss_D_B_list[-1])
            real_A = input_A.copy_(batch['A'])
            real_B = input_B.copy_(batch['B'])

            # Train the generator
            optimizer_G.zero_grad()

            # Compute fake images and reconstructed images
            fake_B = netG_A(real_A)
            fake_A = netG_B(real_B)

            if opt.identity_loss != 0:
                # identity mapping: feeding a generator an image already in
                # its target domain should change nothing
                same_B = netG_A(real_B)
                same_A = netG_B(real_A)

            # discriminators require no gradients when optimizing generators
            utils.set_requires_grad([netD_A, netD_B], False)

            # Identity loss
            if opt.identity_loss != 0:
                loss_identity_A = criterion_identity(
                    same_A, real_A) * opt.identity_loss
                loss_identity_B = criterion_identity(
                    same_B, real_B) * opt.identity_loss

            # GAN loss
            prediction_fake_B = netD_B(fake_B)
            loss_gan_B = criterion_GAN(prediction_fake_B, True)
            prediction_fake_A = netD_A(fake_A)
            loss_gan_A = criterion_GAN(prediction_fake_A, True)

            # Cycle consistent loss
            recA = netG_B(fake_B)
            recB = netG_A(fake_A)
            loss_cycle_A = criterion_cycle(recA, real_A) * opt.cycle_loss
            loss_cycle_B = criterion_cycle(recB, real_B) * opt.cycle_loss

            # total loss without the identity loss
            loss_G = loss_gan_B + loss_gan_A + loss_cycle_A + loss_cycle_B

            if opt.identity_loss != 0:
                loss_G += loss_identity_A + loss_identity_B

            loss_G_list.append(loss_G.item())
            loss_G.backward()
            optimizer_G.step()

            # Train the discriminator
            utils.set_requires_grad([netD_A, netD_B], True)

            # Train the discriminator D_A
            optimizer_D_A.zero_grad()
            # real images
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, True)

            # fake images (drawn from the history pool; detached so no
            # gradients reach the generators)
            fake_A = fake_A_pool.query(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, False)

            #total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A_list.append(loss_D_A.item())
            loss_D_A.backward()
            optimizer_D_A.step()

            # Train the discriminator D_B
            optimizer_D_B.zero_grad()
            # real images
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, True)

            # fake images
            fake_B = fake_B_pool.query(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, False)

            # total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B_list.append(loss_D_B.item())
            loss_D_B.backward()
            optimizer_D_B.step()

        # Update the learning rate
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        # Save models checkpoints
        torch.save(netG_A.state_dict(), 'model/netG_A.pth')
        torch.save(netG_B.state_dict(), 'model/netG_B.pth')
        torch.save(netD_A.state_dict(), 'model/netD_A.pth')
        torch.save(netD_B.state_dict(), 'model/netD_B.pth')

        # Save other checkpoint information
        checkpoint = {
            'epoch': epoch,
            'optimizer_G': optimizer_G.state_dict(),
            'optimizer_D_A': optimizer_D_A.state_dict(),
            'optimizer_D_B': optimizer_D_B.state_dict(),
            'lr_scheduler_G': lr_scheduler_G.state_dict(),
            'lr_scheduler_D_A': lr_scheduler_D_A.state_dict(),
            'lr_scheduler_D_B': lr_scheduler_D_B.state_dict()
        }
        torch.save(checkpoint, 'model/checkpoint.pth')

        # Update the numpy arrays that record the loss
        loss_G_array[epoch] = sum(loss_G_list) / len(loss_G_list)
        loss_D_A_array[epoch] = sum(loss_D_A_list) / len(loss_D_A_list)
        loss_D_B_array[epoch] = sum(loss_D_B_list) / len(loss_D_B_list)
        np.savetxt('model/loss_G.txt', loss_G_array)
        np.savetxt('model/loss_D_A.txt', loss_D_A_array)
        # NOTE(review): lowercase 'b' is inconsistent with loss_D_A.txt —
        # renaming would change the on-disk filename, so it is left as-is.
        np.savetxt('model/loss_D_b.txt', loss_D_B_array)

        # Extra numbered snapshots every 10 epochs
        if epoch % 10 == 9:
            torch.save(netG_A.state_dict(),
                       'model/netG_A' + str(epoch) + '.pth')
            torch.save(netG_B.state_dict(),
                       'model/netG_B' + str(epoch) + '.pth')
            torch.save(netD_A.state_dict(),
                       'model/netD_A' + str(epoch) + '.pth')
            torch.save(netD_B.state_dict(),
                       'model/netD_B' + str(epoch) + '.pth')

        end = time.strftime("%H:%M:%S")
        print("current epoch :", epoch, " end time :", end)
        print("G loss :", loss_G_array[epoch], "D_A loss :",
              loss_D_A_array[epoch], "D_B loss :", loss_D_B_array[epoch])
Ejemplo n.º 25
0
def _mask_generator_channels(model, cfg_masks, gpu_id):
    """Zero out pruned channels of *model* in place.

    ``cfg_masks`` is a list of per-layer keep-masks (1 = keep channel,
    0 = prune).  ``nn.Conv2d`` weights are laid out (out, in, kH, kW);
    ``nn.ConvTranspose2d`` weights are laid out (in, out, kH, kW), so the
    roles of the start/end masks swap between the two branches.
    """
    cfg_id = 0
    start_mask = np.ones(3)  # network input keeps all 3 (RGB) channels
    end_mask = cfg_masks[cfg_id]

    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            mask = np.ones(m.weight.data.shape)
            mask_bias = np.ones(m.bias.data.shape)

            # Invert the keep-masks: a 1 now marks a channel to zero out.
            cfg_mask_start = np.ones(start_mask.shape) - start_mask
            cfg_mask_end = np.ones(end_mask.shape) - end_mask
            idx0 = np.squeeze(np.argwhere(np.asarray(cfg_mask_start)))
            idx1 = np.squeeze(np.argwhere(np.asarray(cfg_mask_end)))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1, ))

            mask[:, idx0.tolist(), :, :] = 0  # pruned input channels
            mask[idx1.tolist(), :, :, :] = 0  # pruned output channels
            mask_bias[idx1.tolist()] = 0

            m.weight.data = m.weight.data * torch.FloatTensor(mask).cuda(
                gpu_id)
            m.bias.data = m.bias.data * torch.FloatTensor(mask_bias).cuda(
                gpu_id)

            # NOTE(review): fancy-indexing ``.data`` returns a copy, so the
            # following requires_grad assignments do not affect the actual
            # parameters; kept for parity with the original script.
            m.weight.data[:, idx0.tolist(), :, :].requires_grad = False
            m.weight.data[idx1.tolist(), :, :, :].requires_grad = False
            m.bias.data[idx1.tolist()].requires_grad = False

            cfg_id += 1
            start_mask = end_mask
            # BUG FIX: the original checked ``len(cfg_mask)`` — an undefined
            # name — which raised NameError when the last layer was reached.
            if cfg_id < len(cfg_masks):
                end_mask = cfg_masks[cfg_id]
        elif isinstance(m, nn.ConvTranspose2d):
            mask = np.ones(m.weight.data.shape)
            mask_bias = np.ones(m.bias.data.shape)

            cfg_mask_start = np.ones(start_mask.shape) - start_mask
            cfg_mask_end = np.ones(end_mask.shape) - end_mask
            idx0 = np.squeeze(np.argwhere(np.asarray(cfg_mask_start)))
            idx1 = np.squeeze(np.argwhere(np.asarray(cfg_mask_end)))

            # Transposed convs store weights as (in, out, kH, kW): the
            # index roles are swapped relative to the Conv2d branch.
            mask[idx0.tolist(), :, :, :] = 0
            mask[:, idx1.tolist(), :, :] = 0
            mask_bias[idx1.tolist()] = 0

            m.weight.data = m.weight.data * torch.FloatTensor(mask).cuda(
                gpu_id)
            m.bias.data = m.bias.data * torch.FloatTensor(mask_bias).cuda(
                gpu_id)

            m.weight.data[idx0.tolist(), :, :, :].requires_grad = False
            m.weight.data[:, idx1.tolist(), :, :].requires_grad = False
            m.bias.data[idx1.tolist()].requires_grad = False

            cfg_id += 1
            start_mask = end_mask
            # Bounds-checked for the same reason as the Conv2d branch
            # (the original indexed unconditionally here).
            if cfg_id < len(cfg_masks):
                end_mask = cfg_masks[cfg_id]


def caculate_fitness(mask_input_A2B, mask_input_B2A, gpu_id, fitness_id,
                     A2B_or_B2A):
    """Evaluate the fitness of a pair of channel-pruning masks.

    Loads the pretrained CycleGAN generators/discriminators, zeroes the
    channels selected by the masks, briefly fine-tunes the pruned
    generators with the usual CycleGAN losses, then scores the requested
    direction against the full (unpruned) generator on the validation
    split.  The score is written into the module-level
    ``current_fitness_A2B`` / ``current_fitness_B2A`` array slot
    ``fitness_id``.

    NOTE(review): "caculate" is a typo for "calculate"; the name is kept
    because external callers reference it.

    Args:
        mask_input_A2B, mask_input_B2A: flat mask vectors, expanded to
            per-layer masks by ``compute_layer_mask``.
        gpu_id: CUDA device index to run on.
        fitness_id: slot in the shared fitness array to fill.
        A2B_or_B2A: direction to score, ``'A2B'`` or ``'B2A'``.
    """

    torch.cuda.set_device(gpu_id)

    model_A2B = Generator(opt.input_nc, opt.output_nc)
    model_B2A = Generator(opt.input_nc, opt.output_nc)

    netD_A = Discriminator(opt.input_nc)
    netD_B = Discriminator(opt.output_nc)

    netD_A.cuda(gpu_id)
    netD_B.cuda(gpu_id)
    model_A2B.cuda(gpu_id)
    model_B2A.cuda(gpu_id)

    # Every candidate starts from the same pretrained weights.
    model_A2B.load_state_dict(torch.load('/cache/models/netG_A2B.pth'))
    model_B2A.load_state_dict(torch.load('/cache/models/netG_B2A.pth'))
    netD_A.load_state_dict(torch.load('/cache/models/netD_A.pth'))
    netD_B.load_state_dict(torch.load('/cache/models/netD_B.pth'))

    # Losses
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    fitness = 0
    cfg_mask_A2B = compute_layer_mask(mask_input_A2B, mask_chns)
    cfg_mask_B2A = compute_layer_mask(mask_input_B2A, mask_chns)
    # Flattened copies are used below for the pruned-channel reward term.
    cfg_full_mask_A2B = np.array([y for x in cfg_mask_A2B for y in x])
    cfg_full_mask_B2A = np.array([y for x in cfg_mask_B2A for y in x])

    # Zero out the pruned channels of both generators in place.
    _mask_generator_channels(model_A2B, cfg_mask_A2B, gpu_id)
    _mask_generator_channels(model_B2A, cfg_mask_B2A, gpu_id)

    # Dataset loader
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
    target_real = Variable(Tensor(opt.batchSize).fill_(1.0),
                           requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize).fill_(0.0),
                           requires_grad=False)
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Loss weights: identity / adversarial / cycle-consistency.
    lamda_loss_ID = 5.0
    lamda_loss_G = 1.0
    lamda_loss_cycle = 10.0
    # Only parameters that still require gradients are fine-tuned.
    optimizer_G = torch.optim.Adam(itertools.chain(
        filter(lambda p: p.requires_grad, model_A2B.parameters()),
        filter(lambda p: p.requires_grad, model_B2A.parameters())),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))
    transforms_ = [
        transforms.Resize(int(opt.size * 1.12), Image.BICUBIC),
        transforms.RandomCrop(opt.size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]

    dataloader = DataLoader(ImageDataset(opt.dataroot,
                                         transforms_=transforms_,
                                         unaligned=True,
                                         mode='train'),
                            batch_size=opt.batchSize,
                            shuffle=True,
                            drop_last=True)

    # Short fine-tuning of the pruned generators (standard CycleGAN step).
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss: G_A2B(B) should equal B if real B is fed.
            same_B = model_A2B(real_B)
            loss_identity_B = criterion_identity(same_B,
                                                 real_B) * lamda_loss_ID
            # G_B2A(A) should equal A if real A is fed.
            same_A = model_B2A(real_A)
            loss_identity_A = criterion_identity(same_A,
                                                 real_A) * lamda_loss_ID

            # GAN loss
            fake_B = model_A2B(real_A)
            pred_fake = netD_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake,
                                         target_real) * lamda_loss_G

            fake_A = model_B2A(real_B)
            pred_fake = netD_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake,
                                         target_real) * lamda_loss_G

            # Cycle-consistency loss
            recovered_A = model_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(recovered_A,
                                             real_A) * lamda_loss_cycle

            recovered_B = model_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(recovered_B,
                                             real_B) * lamda_loss_cycle

            # Total generator loss
            loss_G = (loss_identity_A + loss_identity_B + loss_GAN_A2B +
                      loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB)
            loss_G.backward()

            optimizer_G.step()

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = netD_A(real_A)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss (sample drawn from the replay buffer).
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()

            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = netD_B(real_B)
            loss_D_real = criterion_GAN(pred_real, target_real)

            # Fake loss (sample drawn from the replay buffer).
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach())
            loss_D_fake = criterion_GAN(pred_fake, target_fake)

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()

            optimizer_D_B.step()

    # Score the pruned generator against the full model on the val split.
    with torch.no_grad():

        transforms_ = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]

        dataloader = DataLoader(ImageDataset(opt.dataroot,
                                             transforms_=transforms_,
                                             mode='val'),
                                batch_size=opt.batchSize,
                                shuffle=False,
                                drop_last=True)

        Loss_resemble_G = 0
        # Weight of the pruned-channel reward (hoisted out of the batch
        # loop — it is loop-invariant, and the original left it undefined
        # when the validation loader produced no batches).
        lambda_prune = 0.001

        if A2B_or_B2A == 'A2B':
            netG_A2B = Generator(opt.output_nc, opt.input_nc)
            netD_B = Discriminator(opt.output_nc)

            netG_A2B.cuda(gpu_id)
            netD_B.cuda(gpu_id)

            model_A2B.eval()
            netD_B.eval()
            netG_A2B.eval()

            netD_B.load_state_dict(torch.load('/cache/models/netD_B.pth'))
            netG_A2B.load_state_dict(torch.load('/cache/models/netG_A2B.pth'))

            for i, batch in enumerate(dataloader):

                real_A = Variable(input_A.copy_(batch['A']))

                fake_B = model_A2B(real_A)
                fake_B_full_model = netG_A2B(real_A)
                recovered_A = model_B2A(fake_B)

                pred_fake = netD_B(fake_B.detach())

                pred_fake_full = netD_B(fake_B_full_model.detach())

                # Resemblance to the full model at the discriminator output,
                # plus cycle-consistency of the pruned pair.
                loss_D_fake = criterion_GAN(pred_fake.detach(),
                                            pred_fake_full.detach())
                cycle_loss = criterion_cycle(recovered_A,
                                             real_A) * lamda_loss_cycle
                Loss_resemble_G = Loss_resemble_G + loss_D_fake + cycle_loss

            # Fitness = inverse resemblance loss + reward for pruned channels.
            fitness = 500 / Loss_resemble_G.detach() + sum(
                np.ones(cfg_full_mask_A2B.shape) -
                cfg_full_mask_A2B) * lambda_prune

            print('A2B')
            print("GPU_ID is %d" % (gpu_id))
            print("channel num is: %d" % (sum(cfg_full_mask_A2B)))
            print("Loss_resemble_G is %f prune_loss is %f " %
                  (500 / Loss_resemble_G,
                   sum(np.ones(cfg_full_mask_A2B.shape) - cfg_full_mask_A2B)))
            print("fitness is %f \n" % (fitness))

            current_fitness_A2B[fitness_id] = fitness.item()

        if A2B_or_B2A == 'B2A':
            netG_B2A = Generator(opt.output_nc, opt.input_nc)
            netD_A = Discriminator(opt.output_nc)

            netG_B2A.cuda(gpu_id)
            netD_A.cuda(gpu_id)

            model_B2A.eval()
            netD_A.eval()
            netG_B2A.eval()

            netD_A.load_state_dict(torch.load('/cache/models/netD_A.pth'))
            netG_B2A.load_state_dict(torch.load('/cache/models/netG_B2A.pth'))

            for i, batch in enumerate(dataloader):

                real_B = Variable(input_B.copy_(batch['B']))

                fake_A = model_B2A(real_B)
                fake_A_full_model = netG_B2A(real_B)
                recovered_B = model_A2B(fake_A)

                pred_fake = netD_A(fake_A.detach())

                pred_fake_full = netD_A(fake_A_full_model.detach())

                # Same scoring as the A2B branch, with roles swapped.
                loss_D_fake = criterion_GAN(pred_fake.detach(),
                                            pred_fake_full.detach())
                cycle_loss = criterion_cycle(recovered_B,
                                             real_B) * lamda_loss_cycle
                Loss_resemble_G = Loss_resemble_G + loss_D_fake + cycle_loss

            fitness = 500 / Loss_resemble_G.detach() + sum(
                np.ones(cfg_full_mask_B2A.shape) -
                cfg_full_mask_B2A) * lambda_prune

            print('B2A')
            print("GPU_ID is %d" % (gpu_id))
            print("channel num is: %d" % (sum(cfg_full_mask_B2A)))
            print("Loss_resemble_G is %f prune_loss is %f " %
                  (500 / Loss_resemble_G,
                   sum(np.ones(cfg_full_mask_B2A.shape) - cfg_full_mask_B2A)))
            print("fitness is %f \n" % (fitness))

            current_fitness_B2A[fitness_id] = fitness.item()
Ejemplo n.º 26
0
        im = transform(y[0].cpu())
        im.save("%s/%d.png" % (opt.data_path, i), "PNG")

    for i in range(n_samples):
        _sample_one(i)


# sample_center()
sample_text_layout(n_samples=100)

# Loader over the freshly sampled layout images; the x-channel is not
# needed here, only the rendered targets.
layout_transforms = [
    # transforms.Resize(opt.img_size),
    transforms.ToTensor(),
    # transforms.Normalize([0.5], [0.5]),
]
layout_dataset = ImageDataset(
    opt.data_path,
    transforms_=layout_transforms,
    has_x=False,
)
dataloader = torch.utils.data.DataLoader(
    layout_dataset,
    batch_size=opt.batch_size,
    shuffle=True,
)

# -------------------------------
# Training GAN
# -------------------------------


def train_wgan():
    """Train the GAN with the WGAN-GP objective.

    NOTE(review): the body appears truncated in this excerpt; only the
    first assignment is visible here.
    """
    # Gradient-penalty coefficient (the name follows the lambda term of
    # WGAN-GP; presumably used further down — TODO confirm in full source).
    lambda_gp = 10
Ejemplo n.º 27
0
# Set models to test mode (disables dropout / uses running BN statistics).
netG_A2B.eval()
netG_B2A.eval()

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)

# Dataset loader: map images into [-1, 1] to match generator training.
transforms_ = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]
dataloader = DataLoader(ImageDataset(opt.dataroot,
                                     transforms_=transforms_,
                                     mode='test'),
                        batch_size=opt.batchSize,
                        shuffle=False,
                        num_workers=opt.n_cpu)
###################################

###### Testing######

# Create output dirs if they don't exist.
# exist_ok=True replaces the racy exists()-then-makedirs() pattern.
os.makedirs(out_path + '/A', exist_ok=True)
os.makedirs(out_path + '/B', exist_ok=True)
for i, batch in enumerate(dataloader):
Ejemplo n.º 28
0
def main(args):
    """Run three-phase training of an image-completion GAN.

    Phase 1 trains the completion network (CN) on reconstruction loss
    alone; phase 2 trains the context discriminator (CD) against the
    frozen CN; phase 3 trains both adversarially.  Sample completions
    and model snapshots are written under ``args.result_dir``.

    Args:
        args: parsed command-line namespace (paths, step counts, hole
            sizes, optimizer choice, etc.); ``args.mpv`` is filled in and
            the full config is dumped to ``config.json``.

    Raises:
        Exception: if no CUDA device is available.
    """

    # ================================================
    # Preparation
    # ================================================
    args.data_dir = os.path.expanduser(args.data_dir)
    args.result_dir = os.path.expanduser(args.result_dir)
    # `is (not) None` instead of `== None` / `!= None` (PEP 8).
    if args.init_model_cn is not None:
        args.init_model_cn = os.path.expanduser(args.init_model_cn)
    if args.init_model_cd is not None:
        args.init_model_cd = os.path.expanduser(args.init_model_cd)
    if not torch.cuda.is_available():
        raise Exception('At least one gpu must be available.')
    gpu = torch.device('cuda:0')

    # create result directories (exist_ok avoids the exists/makedirs race)
    os.makedirs(args.result_dir, exist_ok=True)
    for s in ['phase_1', 'phase_2', 'phase_3']:
        os.makedirs(os.path.join(args.result_dir, s), exist_ok=True)

    # dataset
    trnsfm = transforms.Compose([
        transforms.Resize(args.cn_input_size),
        transforms.RandomCrop((args.cn_input_size, args.cn_input_size)),
        transforms.ToTensor(),
    ])
    print('loading dataset... (it may take a few minutes)')
    train_dset = ImageDataset(os.path.join(args.data_dir, 'train'),
                              trnsfm,
                              recursive_search=args.recursive_search)
    test_dset = ImageDataset(os.path.join(args.data_dir, 'test'),
                             trnsfm,
                             recursive_search=args.recursive_search)
    train_loader = DataLoader(train_dset,
                              batch_size=(args.bsize // args.bdivs),
                              shuffle=True)

    # compute mean pixel value of training dataset
    mpv = np.zeros(shape=(3, ))
    if args.mpv is None:
        pbar = tqdm(total=len(train_dset.imgpaths),
                    desc='computing mean pixel value for training dataset...')
        for imgpath in train_dset.imgpaths:
            img = Image.open(imgpath)
            x = np.array(img, dtype=np.float32) / 255.
            mpv += x.mean(axis=(0, 1))
            pbar.update()
        mpv /= len(train_dset.imgpaths)
        pbar.close()
    else:
        mpv = np.array(args.mpv)

    # save training config
    mpv_json = [float(mpv[i]) for i in range(3)]  # json-serializable floats
    args_dict = vars(args)
    args_dict['mpv'] = mpv_json
    with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f:
        json.dump(args_dict, f)

    # make mpv & alpha tensor
    mpv = torch.tensor(mpv.astype(np.float32).reshape(1, 3, 1, 1)).to(gpu)
    alpha = torch.tensor(args.alpha).to(gpu)

    # ================================================
    # Training Phase 1
    # ================================================
    model_cn = CompletionNetwork()
    if args.data_parallel:
        model_cn = DataParallel(model_cn)
    if args.init_model_cn is not None:
        model_cn.load_state_dict(
            torch.load(args.init_model_cn, map_location='cpu'))
    if args.optimizer == 'adadelta':
        opt_cn = Adadelta(model_cn.parameters())
    else:
        opt_cn = Adam(model_cn.parameters())
    model_cn = model_cn.to(gpu)

    # training: gradients are accumulated over args.bdivs mini-batches
    # before each optimizer step (gradient accumulation).
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_1)
    while pbar.n < args.steps_1:
        for x in train_loader:

            # forward
            x = x.to(gpu)
            mask = gen_input_mask(
                shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                hole_size=((args.hole_min_w, args.hole_max_w),
                           (args.hole_min_h, args.hole_max_h)),
                hole_area=gen_hole_area(
                    (args.ld_input_size, args.ld_input_size),
                    (x.shape[3], x.shape[2])),
                max_holes=args.max_holes,
            ).to(gpu)
            # fill holes with the mean pixel value before completion
            x_mask = x - x * mask + mpv * mask
            input = torch.cat((x_mask, mask), dim=1)
            output = model_cn(input)
            loss = completion_network_loss(x, output, mask)

            # backward
            loss.backward()
            cnt_bdivs += 1

            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                # optimize
                opt_cn.step()
                # clear grads
                opt_cn.zero_grad()
                # update progbar
                pbar.set_description('phase 1 | train loss: %.5f' % loss.cpu())
                pbar.update()
                # test
                if pbar.n % args.snaperiod_1 == 0:
                    with torch.no_grad():
                        x = sample_random_batch(
                            test_dset,
                            batch_size=args.num_test_completions).to(gpu)
                        mask = gen_input_mask(
                            shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                            hole_size=((args.hole_min_w, args.hole_max_w),
                                       (args.hole_min_h, args.hole_max_h)),
                            hole_area=gen_hole_area(
                                (args.ld_input_size, args.ld_input_size),
                                (x.shape[3], x.shape[2])),
                            max_holes=args.max_holes,
                        ).to(gpu)
                        x_mask = x - x * mask + mpv * mask
                        input = torch.cat((x_mask, mask), dim=1)
                        output = model_cn(input)
                        completed = poisson_blend(x_mask, output, mask)
                        imgs = torch.cat(
                            (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0)
                        imgpath = os.path.join(args.result_dir, 'phase_1',
                                               'step%d.png' % pbar.n)
                        model_cn_path = os.path.join(
                            args.result_dir, 'phase_1',
                            'model_cn_step%d' % pbar.n)
                        save_image(imgs, imgpath, nrow=len(x))
                        if args.data_parallel:
                            torch.save(model_cn.module.state_dict(),
                                       model_cn_path)
                        else:
                            torch.save(model_cn.state_dict(), model_cn_path)
                # terminate
                if pbar.n >= args.steps_1:
                    break
    pbar.close()

    # ================================================
    # Training Phase 2
    # ================================================
    model_cd = ContextDiscriminator(
        local_input_shape=(3, args.ld_input_size, args.ld_input_size),
        global_input_shape=(3, args.cn_input_size, args.cn_input_size),
        arc=args.arc,
    )
    if args.data_parallel:
        model_cd = DataParallel(model_cd)
    if args.init_model_cd is not None:
        model_cd.load_state_dict(
            torch.load(args.init_model_cd, map_location='cpu'))
    if args.optimizer == 'adadelta':
        opt_cd = Adadelta(model_cd.parameters())
    else:
        opt_cd = Adam(model_cd.parameters())
    model_cd = model_cd.to(gpu)
    bceloss = BCELoss()

    # training: CD learns to separate CN completions (fake) from originals.
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_2)
    while pbar.n < args.steps_2:
        for x in train_loader:

            # fake forward
            x = x.to(gpu)
            hole_area_fake = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            mask = gen_input_mask(
                shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                hole_size=((args.hole_min_w, args.hole_max_w),
                           (args.hole_min_h, args.hole_max_h)),
                hole_area=hole_area_fake,
                max_holes=args.max_holes,
            ).to(gpu)
            fake = torch.zeros((len(x), 1)).to(gpu)
            x_mask = x - x * mask + mpv * mask
            input_cn = torch.cat((x_mask, mask), dim=1)
            output_cn = model_cn(input_cn)
            # detach: phase 2 trains the discriminator only
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd(
                (input_ld_fake.to(gpu), input_gd_fake.to(gpu)))
            loss_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = gen_hole_area(size=(args.ld_input_size,
                                                 args.ld_input_size),
                                           mask_size=(x.shape[3], x.shape[2]))
            real = torch.ones((len(x), 1)).to(gpu)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area_real)
            output_real = model_cd((input_ld_real, input_gd_real))
            loss_real = bceloss(output_real, real)

            # reduce
            loss = (loss_fake + loss_real) / 2.

            # backward
            loss.backward()
            cnt_bdivs += 1

            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                # optimize
                opt_cd.step()
                # clear grads
                opt_cd.zero_grad()
                # update progbar
                pbar.set_description('phase 2 | train loss: %.5f' % loss.cpu())
                pbar.update()
                # test
                if pbar.n % args.snaperiod_2 == 0:
                    with torch.no_grad():
                        x = sample_random_batch(
                            test_dset,
                            batch_size=args.num_test_completions).to(gpu)
                        mask = gen_input_mask(
                            shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                            hole_size=((args.hole_min_w, args.hole_max_w),
                                       (args.hole_min_h, args.hole_max_h)),
                            hole_area=gen_hole_area(
                                (args.ld_input_size, args.ld_input_size),
                                (x.shape[3], x.shape[2])),
                            max_holes=args.max_holes,
                        ).to(gpu)
                        x_mask = x - x * mask + mpv * mask
                        input = torch.cat((x_mask, mask), dim=1)
                        output = model_cn(input)
                        completed = poisson_blend(x_mask, output, mask)
                        imgs = torch.cat(
                            (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0)
                        imgpath = os.path.join(args.result_dir, 'phase_2',
                                               'step%d.png' % pbar.n)
                        model_cd_path = os.path.join(
                            args.result_dir, 'phase_2',
                            'model_cd_step%d' % pbar.n)
                        save_image(imgs, imgpath, nrow=len(x))
                        if args.data_parallel:
                            torch.save(model_cd.module.state_dict(),
                                       model_cd_path)
                        else:
                            torch.save(model_cd.state_dict(), model_cd_path)
                # terminate
                if pbar.n >= args.steps_2:
                    break
    pbar.close()

    # ================================================
    # Training Phase 3
    # ================================================
    # training: CN and CD are updated together (adversarial phase).
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_3)
    while pbar.n < args.steps_3:
        for x in train_loader:

            # forward model_cd
            x = x.to(gpu)
            hole_area_fake = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            mask = gen_input_mask(
                shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                hole_size=((args.hole_min_w, args.hole_max_w),
                           (args.hole_min_h, args.hole_max_h)),
                hole_area=hole_area_fake,
                max_holes=args.max_holes,
            ).to(gpu)

            # fake forward
            fake = torch.zeros((len(x), 1)).to(gpu)
            x_mask = x - x * mask + mpv * mask
            input_cn = torch.cat((x_mask, mask), dim=1)
            output_cn = model_cn(input_cn)
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, input_gd_fake))
            loss_cd_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = gen_hole_area(size=(args.ld_input_size,
                                                 args.ld_input_size),
                                           mask_size=(x.shape[3], x.shape[2]))
            real = torch.ones((len(x), 1)).to(gpu)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area_real)
            output_real = model_cd((input_ld_real, input_gd_real))
            loss_cd_real = bceloss(output_real, real)

            # reduce
            loss_cd = (loss_cd_fake + loss_cd_real) * alpha / 2.

            # backward model_cd
            loss_cd.backward()

            # NOTE: cnt_bdivs is reset only after the CN step below so that
            # both optimizers step on the same accumulation cadence.
            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                # optimize
                opt_cd.step()
                # clear grads
                opt_cd.zero_grad()

            # forward model_cn (no detach: gradients flow into the CN)
            loss_cn_1 = completion_network_loss(x, output_cn, mask)
            input_gd_fake = output_cn
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, (input_gd_fake)))
            loss_cn_2 = bceloss(output_fake, real)

            # reduce
            loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2.

            # backward model_cn
            loss_cn.backward()

            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                # optimize
                opt_cn.step()
                # clear grads
                opt_cn.zero_grad()
                # update progbar
                pbar.set_description(
                    'phase 3 | train loss (cd): %.5f (cn): %.5f' %
                    (loss_cd.cpu(), loss_cn.cpu()))
                pbar.update()
                # test
                if pbar.n % args.snaperiod_3 == 0:
                    with torch.no_grad():
                        x = sample_random_batch(
                            test_dset,
                            batch_size=args.num_test_completions).to(gpu)
                        mask = gen_input_mask(
                            shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                            hole_size=((args.hole_min_w, args.hole_max_w),
                                       (args.hole_min_h, args.hole_max_h)),
                            hole_area=gen_hole_area(
                                (args.ld_input_size, args.ld_input_size),
                                (x.shape[3], x.shape[2])),
                            max_holes=args.max_holes,
                        ).to(gpu)
                        x_mask = x - x * mask + mpv * mask
                        input = torch.cat((x_mask, mask), dim=1)
                        output = model_cn(input)
                        completed = poisson_blend(x_mask, output, mask)
                        imgs = torch.cat(
                            (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0)
                        imgpath = os.path.join(args.result_dir, 'phase_3',
                                               'step%d.png' % pbar.n)
                        model_cn_path = os.path.join(
                            args.result_dir, 'phase_3',
                            'model_cn_step%d' % pbar.n)
                        model_cd_path = os.path.join(
                            args.result_dir, 'phase_3',
                            'model_cd_step%d' % pbar.n)
                        save_image(imgs, imgpath, nrow=len(x))
                        if args.data_parallel:
                            torch.save(model_cn.module.state_dict(),
                                       model_cn_path)
                            torch.save(model_cd.module.state_dict(),
                                       model_cd_path)
                        else:
                            torch.save(model_cn.state_dict(), model_cn_path)
                            torch.save(model_cd.state_dict(), model_cd_path)
                # terminate
                if pbar.n >= args.steps_3:
                    break
    pbar.close()
Ejemplo n.º 29
0
# Set model's test mode
# netG_A2B.eval()
# netG_B2A.eval()

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)

# Dataset loader: map images into [-1, 1] to match generator training.
transforms_ = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]

images_ds = ImageDataset(opt.dataroot, transforms_=transforms_, mode='train')
dataloader = DataLoader(images_ds,
                        batch_size=opt.batchSize,
                        shuffle=False,
                        num_workers=opt.n_cpu)
###################################

###### Testing######

# Create output dirs if they don't exist.
# exist_ok=True replaces the racy exists()-then-makedirs() pattern and
# removes the triplicated check.
for _out_dir in ('datasets/horse2zebra/distilA2B',
                 'datasets/horse2zebra/distilB2A',
                 'datasets/horse2zebra/distilA2B/A'):
    os.makedirs(_out_dir, exist_ok=True)
Ejemplo n.º 30
0
def launch(*args, **kwargs):
    """Build the training DataLoader and start training.

    The original ignored all arguments; they are now honoured
    backward-compatibly as optional keyword overrides:

    Keyword Args:
        batch_size: mini-batch size for the loader (default 16).
        data_dir: root directory of the training images (defaults to the
            previously hard-coded path).
    """
    batch_size = kwargs.get("batch_size", 16)
    data_dir = kwargs.get(
        "data_dir",
        "/home/users/gkiar/ace_mount/ace_home/data/nv_filtered/")
    training = ImageDataset(data_dir, mode="train")
    training_loader = DataLoader(training, batch_size=batch_size)
    train(training_loader)