def optimize(args):
    """Stylize a content image by optimizing its pixels directly.

    Reference: Gatys et al., "Image Style Transfer Using Convolutional
    Neural Networks", CVPR 2016.

    The output image starts as the (ImageNet-mean-subtracted) content image
    and is updated with Adam so that:
      * its VGG-16 feature at index 1 matches that of the content image, and
      * the Gram matrix of every feature map returned by VGG-16 matches the
        corresponding Gram matrix of the style image.
    Before every step the image is clamped with ``utils.imagenet_clamp_batch``.
    The result has the mean added back and is written to ``args.output_image``.

    Args:
        args: parsed options; reads ``cuda``, ``content_image``,
            ``content_size``, ``style_image``, ``style_size``, ``lr``,
            ``iters``, ``content_weight``, ``style_weight``,
            ``log_interval`` and ``output_image``.
    """
    ctx = mx.gpu(0) if args.cuda else mx.cpu(0)

    def _load_preprocessed(path, **load_kwargs):
        # Load an RGB image onto ``ctx`` and subtract the ImageNet mean.
        raw = utils.tensor_load_rgbimage(path, ctx, **load_kwargs)
        return utils.subtract_imagenet_mean_preprocess_batch(raw)

    content_image = _load_preprocessed(args.content_image,
                                       size=args.content_size,
                                       keep_asp=True)
    style_image = _load_preprocessed(args.style_image, size=args.style_size)

    # Pre-trained VGG-16 used as a fixed feature extractor.
    vgg = net.Vgg16()
    utils.init_vgg_params(vgg, 'models', ctx=ctx)

    # Targets are computed once, outside autograd.record().
    target_content = vgg(content_image)[1]
    target_grams = [net.gram_matrix(feat) for feat in vgg(style_image)]

    # The image being optimized is itself the only trainable parameter.
    output = Parameter('output', shape=content_image.shape)
    output.initialize(ctx=ctx)
    output.set_data(content_image)

    trainer = gluon.Trainer([output], 'adam', {'learning_rate': args.lr})
    mse_loss = gluon.loss.L2Loss()
    content_scale = 2 * args.content_weight
    style_scale = 2 * args.style_weight

    for step in range(1, args.iters + 1):
        utils.imagenet_clamp_batch(output.data(), 0, 255)
        with autograd.record():
            feats = vgg(output.data())
            content_loss = content_scale * mse_loss(feats[1], target_content)
            style_loss = 0.
            for idx, feat in enumerate(feats):
                style_loss = style_loss + style_scale * mse_loss(
                    net.gram_matrix(feat), target_grams[idx])
            total_loss = content_loss + style_loss
            total_loss.backward()

        trainer.step(1)
        if step % args.log_interval == 0:
            print('loss:{:.2f}'.format(total_loss.asnumpy()[0]))

    # Undo the mean subtraction and write the first (only) image of the batch.
    stylized = utils.add_imagenet_mean_batch(output.data())
    utils.tensor_save_bgrimage(stylized[0], args.output_image, args.cuda)
# Example #2
0
def test_output():
    """Smoke-test ``net.Vgg16``.

    Prints the network, runs a random batch of shape (20, 3, 224, 224)
    through it on the CPU, and prints the shape of every output it returns.
    """
    ctx = mx.cpu(0)
    X = mx.ndarray.random.normal(shape=(20, 3, 224, 224), ctx=ctx)
    vgg = net.Vgg16()
    print(vgg)
    # Initialise on the same context as the input instead of relying on the
    # default context happening to match ``ctx``.
    vgg.initialize(ctx=ctx)

    # Call the block itself rather than ``.forward`` so Gluon's ``__call__``
    # machinery (e.g. forward hooks) runs exactly as in real use.
    for feature in vgg(X):
        print(feature.shape)
# Example #3
0
def train(args):
    """Train a ``net.Net`` style-transfer generator with perceptual losses.

    For every batch of content images from ``args.dataset``, the generator is
    given the style target returned by ``style_loader.get(batch_id)``. It is
    trained on the sum of two terms:
      * a content term: L2 between VGG-16 feature index 1 of the input and of
        the generated image;
      * a style term: L2 between the Gram matrices of every VGG-16 feature
        map of the generated image and of the style image.

    A checkpoint is saved every ``4 * args.log_interval`` batches, and a final
    model after the last epoch, all under ``args.save_model_dir``.

    Args:
        args: parsed options; reads ``seed``, ``cuda``, ``image_size``,
            ``dataset``, ``batch_size``, ``style_folder``, ``style_size``,
            ``ngf``, ``resume``, ``lr``, ``epochs``, ``content_weight``,
            ``style_weight``, ``log_interval`` and ``save_model_dir``.
    """
    np.random.seed(args.seed)
    # Parameter initialisation (MSRAPrelu) draws from MXNet's own RNG, which
    # np.random.seed does not touch; seed it as well for reproducible runs.
    mx.random.seed(args.seed)
    ctx = mx.gpu(0) if args.cuda else mx.cpu(0)

    # dataloader
    transform = utils.Compose([utils.Scale(args.image_size),
                               utils.CenterCrop(args.image_size),
                               utils.ToTensor(ctx),
                               ])
    train_dataset = data.ImageFolder(args.dataset, transform)
    # last_batch='discard' drops the final incomplete batch.
    train_loader = gluon.data.DataLoader(train_dataset, batch_size=args.batch_size, last_batch='discard')
    style_loader = utils.StyleLoader(args.style_folder, args.style_size, ctx=ctx)
    print('len(style_loader):', style_loader.size())

    # models
    vgg = net.Vgg16()
    utils.init_vgg_params(vgg, 'models', ctx=ctx)
    style_model = net.Net(ngf=args.ngf)
    style_model.initialize(init=mx.initializer.MSRAPrelu(), ctx=ctx)
    if args.resume is not None:
        print('Resuming, initializing using weight from {}.'.format(args.resume))
        style_model.collect_params().load(args.resume, ctx=ctx)
    print('style_model:', style_model)

    # optimizer and loss
    trainer = gluon.Trainer(style_model.collect_params(), 'adam',
                            {'learning_rate': args.lr})
    mse_loss = gluon.loss.L2Loss()

    def _save_model(prefix):
        # Save the generator's parameters; the filename encodes ``prefix``,
        # the wall-clock time and the two loss weights. Returns the path.
        save_model_filename = prefix + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".params"
        save_model_path = os.path.join(args.save_model_dir, save_model_filename)
        style_model.collect_params().save(save_model_path)
        return save_model_path

    for e in range(args.epochs):
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            # prepare data: style_v (mean-subtracted) feeds VGG, while the
            # preprocessed style_image is the generator's target.
            style_image = style_loader.get(batch_id)
            style_v = utils.subtract_imagenet_mean_preprocess_batch(style_image.copy())
            style_image = utils.preprocess_batch(style_image)

            # Targets are computed outside autograd.record(): no gradient
            # flows into them.
            features_style = vgg(style_v)
            gram_style = [net.gram_matrix(y) for y in features_style]

            xc = utils.subtract_imagenet_mean_preprocess_batch(x.copy())
            f_xc_c = vgg(xc)[1]
            with autograd.record():
                style_model.setTarget(style_image)
                y = style_model(x)

                y = utils.subtract_imagenet_mean_batch(y)
                features_y = vgg(y)

                content_loss = 2 * args.content_weight * mse_loss(features_y[1], f_xc_c)

                style_loss = 0.
                for m in range(len(features_y)):
                    gram_y = net.gram_matrix(features_y[m])
                    _, C, _ = gram_style[m].shape
                    # Repeat the style Gram matrix once per sample of this batch.
                    gram_s = F.expand_dims(gram_style[m], 0).broadcast_to((n_batch, 1, C, C))
                    style_loss = style_loss + 2 * args.style_weight * mse_loss(gram_y, gram_s)

                total_loss = content_loss + style_loss
                total_loss.backward()

            # Normalise gradients by the real number of samples in this batch.
            trainer.step(n_batch)
            mx.nd.waitall()

            # L2Loss returns one value per sample: accumulate the batch mean
            # (previously only the first sample's loss, index [0], was logged).
            agg_content_loss += content_loss.mean().asscalar()
            agg_style_loss += style_loss.mean().asscalar()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.3f}\tstyle: {:.3f}\ttotal: {:.3f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1),
                )
                print(mesg)

            if (batch_id + 1) % (4 * args.log_interval) == 0:
                # save model (checkpoint filename format kept unchanged)
                save_model_path = _save_model("Epoch_" + str(e) + "iters_" + str(count))
                print("\nCheckpoint, trained model saved at", save_model_path)

    # save model
    save_model_path = _save_model("Final_epoch_" + str(args.epochs))
    print("\nDone, trained model saved at", save_model_path)
# Example #4
0
 def __init__(self):
     # NOTE(review): fragment of a class whose header is not shown, and the
     # body appears truncated after the CUDA check -- confirm against the
     # full source before relying on it.
     # Create the VGG-16 network held by this object.
     self.vgg_model = net.Vgg16()
     # Only reports that CUDA is available; nothing in the visible lines
     # moves the model to the GPU.
     if torch.cuda.is_available():
         print('=> Use CUDA')
# Training hyper-parameters and paths.
# NOTE(review): `batch_size` is used below but is not defined in this
# snippet; it must be defined earlier in the script -- confirm.
epochs = 5000
learning_rate = 0.0001
report_epoch = 100
summary_path = '/media/data_cifs/yuwei/summary'
loss_history = []

generator = ImageGenerator('/media/data_cifs/yuwei/data', batch_size=batch_size, valid_num=1000)
valid_data, valid_label = generator.get_valid()

tf.reset_default_graph()

# Inputs: batches of 640x640x3 float images, 4-column targets, and a flag
# passed to the network's build step.
images = tf.placeholder(tf.float32, [None, 640, 640, 3])
true_out = tf.placeholder(tf.float32, [None, 4])
train_mode = tf.placeholder(tf.bool)

# Build the model and training ops on the second GPU. Previously the
# `with tf.device(...)` statement had no indented body (an IndentationError).
with tf.device('/gpu:1'):
    network = net.Vgg16()
    network.build(images, train_mode)

    cost = tf.reduce_sum((network.prob - true_out)**2)
    train = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
    correct = tf.equal(tf.argmax(network.prob, 1), tf.argmax(true_out, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

tf.summary.scalar('cost', cost)
tf.summary.scalar('accuracy', accuracy)

# allow_soft_placement lets ops without a GPU kernel fall back to the CPU
# instead of failing the whole device-scoped graph.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

# Initialise only after every variable-creating op (model and optimizer)
# exists, so optimizers that add slot variables (e.g. Adam) are covered too.
sess.run(tf.global_variables_initializer())
    def initialize(self, opt):
        """Build the networks, losses, optimizers and schedulers of the model.

        Always creates the two generators (G_A, G_B) and a VGG-16 network
        whose weights are loaded from ``opt.vgg_model_dir``. When
        ``self.isTrain`` is set, it also creates the two discriminators
        (D_A, D_B), fake-image pools, loss criteria, Adam optimizers and
        learning-rate schedulers. Saved networks for ``opt.which_epoch`` are
        loaded when testing or when ``opt.continue_train`` is set.

        Args:
            opt: option namespace; fields read here include batchSize,
                fineSize, input_nc, output_nc, ngf, ndf, which_model_netG,
                which_model_netD, n_layers_D, norm, no_dropout, init_type,
                no_lsgan, vgg_model_dir, continue_train, which_epoch,
                pool_size, lr and beta1.
        """
        BaseModel.initialize(self, opt)

        nb = opt.batchSize
        size = opt.fineSize
        # Pre-allocated input buffers; self.Tensor is presumably set by
        # BaseModel.initialize -- confirm.
        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
        self.input_B = self.Tensor(nb, opt.output_nc, size, size)

        # load/define networks
        # Code (paper): G_A (G), G_B (F), Vgg, D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type,
                                        self.gpu_ids)
        self.vggNet = net.Vgg16()
        net.init_vgg16(opt.vgg_model_dir)
        print(opt.vgg_model_dir)
        # map_location keeps the loaded tensors on the CPU whatever device
        # they were saved from, so loading also works on CPU-only hosts.
        self.vggNet.load_state_dict(
            torch.load(os.path.join(opt.vgg_model_dir, "vgg16.weight"),
                       map_location=lambda storage, loc: storage))
        # VGG is a fixed feature extractor: it is not given to any optimizer
        # below, so its weights need no gradients.
        for param in self.vggNet.parameters():
            param.requires_grad = False
        # Move to the GPU only when GPU ids were given, matching how the
        # other networks receive self.gpu_ids (the previous unconditional
        # .cuda() broke or misplaced the model when running on CPU).
        if self.gpu_ids:
            self.vggNet.cuda()

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type,
                                            self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                # D_A / D_B only exist in training mode (created above).
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionContent = torch.nn.MSELoss()
            # initialize optimizers: one shared optimizer for both generators,
            # one per discriminator.
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=opt.lr,
                                                  betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        networks.print_network(self.vggNet)

        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')