Example no. 1
0
def evaluate(args):
    """Stylize a single content image with a trained transformer network.

    Loads the content and style images, extracts Gram-matrix style
    targets from a pre-trained VGG-16, conditions the transformer on
    one of those targets, and writes the stylized image to disk.
    """
    # Load the content image as a 1-image batch.
    content = utils.tensor_load_rgbimage(args.content_image,
                                         size=args.content_size,
                                         keep_asp=True).unsqueeze(0)
    # Load and preprocess the style reference image.
    style_img = utils.tensor_load_rgbimage(args.style_image,
                                           size=args.style_size).unsqueeze(0)
    style_img = utils.preprocess_batch(style_img)

    # Pre-trained VGG-16 is used only as a fixed feature extractor.
    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    # The trained style-transfer network.
    style_model = Net()
    style_model.load_state_dict(torch.load(args.model))

    if args.cuda:
        style_model.cuda()
        vgg.cuda()
        content = content.cuda()
        style_img = style_img.cuda()

    # Style targets: Gram matrices of the VGG feature maps.
    style_var = Variable(style_img, volatile=True)
    style_var = utils.subtract_imagenet_mean_batch(style_var)
    gram_style = [utils.gram_matrix(f) for f in vgg(style_var)]

    content_var = Variable(utils.preprocess_batch(content), volatile=True)
    # Condition the transformer on the third feature level's Gram target.
    style_model.setTarget(gram_style[2].data)

    output = style_model(content_var)
    utils.tensor_save_bgrimage(output.data[0], args.output_image, args.cuda)
Example no. 2
0
def optimize(args):
    """Optimization-based neural style transfer.

    Gatys et al., "Image Style Transfer Using Convolutional Neural
    Networks": directly optimize the pixels of an output image so that
    its VGG-16 features match the content image and its Gram matrices
    match the style image, then save the result.
    """
    # --- load and preprocess the content / style targets ---
    content = utils.tensor_load_rgbimage(args.content_image,
                                         size=args.content_size,
                                         keep_asp=True).unsqueeze(0)
    content = Variable(utils.preprocess_batch(content), requires_grad=False)
    content = utils.subtract_imagenet_mean_batch(content)

    style = utils.tensor_load_rgbimage(args.style_image,
                                       size=args.style_size).unsqueeze(0)
    style = Variable(utils.preprocess_batch(style), requires_grad=False)
    style = utils.subtract_imagenet_mean_batch(style)

    # --- fixed, pre-trained VGG-16 feature extractor ---
    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))
    if args.cuda:
        content = content.cuda()
        style = style.cuda()
        vgg.cuda()

    # Precompute the content target (second feature map) and the
    # Gram-matrix style targets.
    f_xc_c = Variable(vgg(content)[1].data, requires_grad=False)
    gram_style = [utils.gram_matrix(f) for f in vgg(style)]

    # The optimized variable is the image itself, initialized from content.
    output = Variable(content.data, requires_grad=True)
    optimizer = Adam([output], lr=args.lr)
    mse_loss = torch.nn.MSELoss()

    # --- optimize the pixels ---
    for step in range(args.iters):
        # Keep the mean-subtracted pixels inside the valid [0, 255] range.
        utils.imagenet_clamp_batch(output, 0, 255)
        optimizer.zero_grad()
        features_y = vgg(output)

        content_loss = args.content_weight * mse_loss(features_y[1], f_xc_c)

        style_loss = 0.
        for layer, feat in enumerate(features_y):
            gram_y = utils.gram_matrix(feat)
            gram_s = Variable(gram_style[layer].data, requires_grad=False)
            style_loss += args.style_weight * mse_loss(gram_y, gram_s)

        total_loss = content_loss + style_loss

        if (step + 1) % args.log_interval == 0:
            print(total_loss.data.cpu().numpy()[0])
        total_loss.backward()

        optimizer.step()

    # Undo the mean subtraction and save the image.
    output = utils.add_imagenet_mean_batch(output)
    utils.tensor_save_bgrimage(output.data[0], args.output_image, args.cuda)
Example no. 3
0
def init_vgg16(model_folder):
    """Copy torchvision's pre-trained VGG-16 weights into our Vgg16
    module and cache them as ``vgg16.weight`` in *model_folder*."""
    from torchvision import models
    pretrained = models.vgg16(pretrained=True)
    pretrained.eval()
    target = Vgg16()
    # The two networks list their parameters in the same order, so a
    # positional copy is sufficient.
    for src_param, dst_param in zip(pretrained.parameters(),
                                    target.parameters()):
        dst_param.data[:] = src_param
    torch.save(target.state_dict(),
               os.path.join(model_folder, 'vgg16.weight'))
Example no. 4
0
def init_vgg16(model_folder):
    """Ensure a converted VGG-16 weight file exists in *model_folder*.

    If ``vgg16.weight`` is missing, download Johnson's Lua/Torch
    ``vgg16.t7`` checkpoint (if it is not already on disk), copy its
    parameters into our Vgg16 module, and save them as a PyTorch state
    dict named ``vgg16.weight``.
    """
    weight_path = os.path.join(model_folder, 'vgg16.weight')
    if os.path.exists(weight_path):
        return  # already converted; nothing to do
    t7_path = os.path.join(model_folder, 'vgg16.t7')
    if not os.path.exists(t7_path):
        # Download with urllib instead of shelling out to wget: the
        # original os.system('wget ... -O ' + path) broke on paths with
        # spaces/shell metacharacters, required wget to be installed,
        # and silently ignored download failures.
        import urllib.request
        urllib.request.urlretrieve(
            'http://cs.stanford.edu/people/jcjohns/fast-neural-style/models/vgg16.t7',
            t7_path)
    vgglua = torchfile.load(t7_path)
    vgg = Vgg16()
    # The t7 file and Vgg16 list parameters in the same order, so a
    # positional copy transfers the weights.
    for src, dst in zip(vgglua.parameters()[0], vgg.parameters()):
        dst.data[:] = src
    torch.save(vgg.state_dict(), weight_path)
Example no. 5
0
def extract_feats(args):
    """Run VGG-16 forward passes over every image in COCO train2014.

    NOTE(review): the computed features are neither stored nor returned
    here — the loop only executes the forward passes.
    """
    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    # Bug fix: the original never loaded the converted weights, so the
    # features came from a randomly initialized VGG-16.  Load them the
    # same way every other entry point in this file does.
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    cap = dset.CocoCaptions(
        root='/Pulsar1/Datasets/coco/train2014/train2014',
        annFile=
        '/Neutron9/sahil.c/datasets/annotations/captions_train2014.json',
        transform=transforms.ToTensor())

    print('Number of samples: ', len(cap))
    # CocoCaptions yields (image, captions) pairs; the original named
    # these `i, t`, which misread as an index.
    for image, _captions in cap:
        batch = image.unsqueeze(0)
        batch = Variable(utils.preprocess_batch(batch), requires_grad=False)
        batch = utils.subtract_imagenet_mean_batch(batch)
        features_content = vgg(batch)
Example no. 6
0
#haze_train = haze_train.cuda()
#free_train = free_train.cuda()
# Move the discriminator, generator, and loss modules to the GPU.
netD_A.cuda()
#netD_B.cuda()
netG.cuda()
criterionBCE.cuda()
criterionMSE.cuda()
criterionCycle.cuda()
print('done')

# Loss-weight hyperparameters taken from the command-line options.
lamdaA = opt.lamdaA
lamdaP = opt.lamdaP

# Initialize VGG-16
# (used as a fixed feature extractor; weights are loaded from the
# converted ./models/vgg16.weight file produced by init_vgg16).
vgg = Vgg16()
utils.init_vgg16('./models/')
vgg.load_state_dict(torch.load(os.path.join('./models/', "vgg16.weight")))
vgg.cuda()

# pdb.set_trace()
''' set optimizer'''
# Adam optimizers for the discriminator and generator, with separate
# learning rates lrD / lrG and a shared beta1 momentum setting.
#optimizerD = optim.Adam(itertools.chain(netD_A.parameters(),netD_B.parameters()), lr = opt.lrD, betas = (opt.beta1, 0.999))
optimizerD = optim.Adam(netD_A.parameters(),
                        lr=opt.lrD,
                        betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(),
                        lr=opt.lrG,
                        betas=(opt.beta1, 0.999))
'''test data'''
#val_target = torch.FloatTensor(opt.valBatchSize, outputChannelSize, opt.imageSize, opt.imageSize)
def train(args):
    """Train the multi-style transfer network (Net) on an image folder.

    Every batch is stylized toward a style image cycled from
    ``args.style_folder``; the loss is a VGG-16 perceptual content loss
    plus a Gram-matrix style loss.  Checkpoints are saved every
    ``4 * args.log_interval`` batches and a final model at the end.
    """
    check_paths(args)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    # Scale/crop to a fixed size and bring pixel values to [0, 255].
    transform = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              **kwargs)

    style_model = Net(ngf=args.ngf)
    if args.resume is not None:
        print('Resuming, initializing using weight from {}.'.format(
            args.resume))
        style_model.load_state_dict(torch.load(args.resume))
    print(style_model)
    optimizer = Adam(style_model.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    # Fixed VGG-16 used only to compute perceptual features.
    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        style_model.cuda()
        vgg.cuda()

    # Cycles through the style images in args.style_folder.
    style_loader = StyleLoader(args.style_folder, args.style_size)

    for e in range(args.epochs):
        style_model.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(utils.preprocess_batch(x))
            if args.cuda:
                x = x.cuda()

            # Pick this batch's style and condition the model on it
            # (before the mean subtraction below mutates style_v).
            style_v = style_loader.get(batch_id)
            style_model.setTarget(style_v)

            # Style targets: Gram matrices of the style image's VGG features.
            style_v = utils.subtract_imagenet_mean_batch(style_v)
            features_style = vgg(style_v)
            gram_style = [utils.gram_matrix(y) for y in features_style]

            y = style_model(x)
            # xc: detached copy of the input, used only as content target.
            xc = Variable(x.data.clone(), volatile=True)

            y = utils.subtract_imagenet_mean_batch(y)
            xc = utils.subtract_imagenet_mean_batch(xc)

            features_y = vgg(y)
            features_xc = vgg(xc)

            # Content target taken from the second VGG feature map.
            f_xc_c = Variable(features_xc[1].data, requires_grad=False)

            content_loss = args.content_weight * mse_loss(
                features_y[1], f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_y = utils.gram_matrix(features_y[m])
                # Repeat the style Gram across the batch, then slice to
                # the actual batch size (the final batch may be smaller).
                gram_s = Variable(gram_style[m].data,
                                  requires_grad=False).repeat(
                                      args.batch_size, 1, 1, 1)
                style_loss += args.style_weight * mse_loss(
                    gram_y, gram_s[:n_batch, :, :])

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if (batch_id + 1) % (4 * args.log_interval) == 0:
                # save model
                style_model.eval()
                style_model.cpu()
                save_model_filename = "Epoch_" + str(e) + "iters_" + str(
                    count) + "_" + str(time.ctime()).replace(
                        ' ', '_') + "_" + str(args.content_weight) + "_" + str(
                            args.style_weight) + ".model"
                save_model_path = os.path.join(args.save_model_dir,
                                               save_model_filename)
                torch.save(style_model.state_dict(), save_model_path)
                # Back to training mode on the GPU after checkpointing.
                style_model.train()
                style_model.cuda()
                print("\nCheckpoint, trained model saved at", save_model_path)

    # save model
    style_model.eval()
    style_model.cpu()
    save_model_filename = "Final_epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(style_model.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example no. 8
0
def train(print_every=10):
    checkpaths(opt)

    train_set = DatasetFromFolder(opt, True)
    test_set = DatasetFromFolder(opt, False)
    training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
    testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=1, shuffle=False)

    norm_layer = get_norm_layer(norm_type='batch')

    netD = NLayerDiscriminator(opt.input_nc, opt.ndf, n_layers=1, norm_layer=norm_layer,use_sigmoid=False, gpu_ids=opt.gpu_ids)
    netG = MyUnetGenerator(opt.input_nc, opt.output_nc, 8, opt.ngf, norm_layer=norm_layer,   use_dropout=False, gpu_ids=opt.gpu_ids)
    netE = MyEncoder(opt.input_nc, opt.output_nc, 8, opt.ngf, norm_layer=norm_layer,use_dropout=False, gpu_ids=opt.gpu_ids)

    netVGG = Vgg16()
    # utils.init_vgg16(opt.model_dir)
    netVGG.load_state_dict(torch.load(os.path.join(opt.model_dir, "vgg16.weight")))

    VGG = make_encoder(model_file=opt.model_vgg)

    perceptual_loss = PerceptualLoss(VGG, 3)


    VGG.cuda()
    netG.cuda()
    netD.cuda()
    netE.cuda()

    netG.apply(weights_init)
    netD.apply(weights_init)
    netE.apply(weights_init)



    criterionGAN = GANLoss(use_lsgan=not opt.no_lsgan)
    criterionL1 = torch.nn.L1Loss()
    mse_loss = torch.nn.MSELoss()
    criterionCEL = nn.CrossEntropyLoss()

    # initialize optimizers
    optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    optimizer_E = torch.optim.Adam(netE.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

    print('=========== Networks initialized ============')
    print_network(netG)
    print_network(netD)
    print('=============================================')

    f = open('./checkpoint/loss.txt', 'w')
    f2 = open('./checkpoint/recognition.txt', 'w')
    strat_time = time.time()
    for epoch in range(1, opt.n_epoch + 1):
        D_running_loss = 0.0
        G_running_loss = 0.0
        G2_running_loss = 0.0

        for (i, batch) in enumerate(training_data_loader, 1):
            real_p, real_s, identity = Variable(batch[0]), Variable(batch[1]), Variable(batch[2].squeeze(1))
            location = batch[3]

            real_p, real_s, identity = real_p.cuda(), real_s.cuda(), identity.cuda()

            optimizer_D.zero_grad()
            # fake
            parsing_feature = netE.forward(real_p[:, 3:, :, :])
            fake_s = netG.forward(real_p[:, 0:3, :, :], parsing_feature)
            fake_ps = torch.cat((fake_s, real_p), 1)
            pred_fake = netD.forward(fake_ps.detach())
            loss_D_fake = criterionGAN(pred_fake, False)
            # real
            real_ps = torch.cat((real_s, real_p), 1)
            pred_real = netD.forward(real_ps)
            loss_D_real = criterionGAN(pred_real, True)

            loss_D = (loss_D_real + loss_D_fake) * 0.5
            loss_D.backward()
            optimizer_D.step()

            optimizer_G.zero_grad()
            optimizer_E.zero_grad()
            pred_fake = netD.forward(fake_ps)
            loss_G_GAN = criterionGAN(pred_fake, True)
            # loss_G_L1 = criterionL1(fake_s, real_s) * opt.lambda1
            # !!!!!!!-------- a2b need modified cirterionL1 -----------------!!!
            loss_global = criterionL1(fake_s, real_s)
            loss_local = localLossL1(fake_s, real_s, real_p, criterionL1)
            loss_G_L1 = opt.alpha1 * loss_global + (1 - opt.alpha1) * loss_local
            loss_G_L1 *= opt.lambda1
            b,c,w,h = fake_s.shape
            yh = fake_s.expand(b,3,w,h)
            ys = real_s.expand(b,3,w,h)
            _mean = Variable(torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).expand_as(yh)).cuda()
            _var = Variable(torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).expand_as(yh)).cuda()
            yh = yh / 2 + 0.5
            ys = ys / 2 + 0.5
            yh = (yh - _mean) / _var
            ys = (ys - _mean) / _var
            loss_recog = perceptual_loss(yh, ys)



            loss_G = loss_G_GAN + loss_G_L1 + opt.styleParam * loss_recog

            loss_G.backward()
            optimizer_G.step()
            optimizer_E.step()

            '''======================================================================='''

            D_running_loss += loss_D.data[0]
            G_running_loss += loss_G.data[0]
            G2_running_loss += loss_G.data[0]
            if i % print_every == 0:
                end_time = time.time()
                time_delta = usedtime(strat_time, end_time)
                print('[%s-%d, %5d] D loss: %.3f ; G loss: %.3f' % (time_delta, epoch, i + 1, D_running_loss / print_every, G_running_loss / print_every))
                f.write('%d,%d,D_loss:%.5f,GAN_loss:%.5f,L1Loss:%.5f\r\n' % (epoch, i + 1, loss_D.data[0], loss_G_GAN.data[0],loss_G_L1.data[0]))
                f2.write('%d,%d,loss_recog_loss:%.5f\r\n' % (epoch, i + 1, loss_recog.data[0]))
                D_running_loss = 0.0
                G_running_loss = 0.0
                G2_running_loss = 0.0
        f.flush()
        f2.flush()
        if epoch >= 500 and epoch % 50 == 0:
            test(epoch, netG, netE, testing_data_loader, opt)

            checkpoint(epoch, netD, netG, netE)
    f.close()
    f2.close()