Example 1
    def forward(self, img, mask, output, gt):

        loss_dict = {}
        # Composite image: keep known pixels from the input, fill the holes with the prediction.
        comp = mask * img + (1 - mask) * output

        # L1 reconstruction losses on the hole and valid (known) regions separately.
        loss_dict['hole'] = self.l1((1 - mask) * output, (1 - mask) * gt)
        loss_dict['valid'] = self.l1(mask * output, mask * gt)

        # The feature extractor expects 3-channel input; replicate grayscale images across channels.
        if output.size(1) == 3:
            feat_comp = self.extractor(comp)
            feat_output = self.extractor(output)
            feat_gt = self.extractor(gt)
        elif output.size(1) == 1:
            feat_comp = self.extractor(torch.cat([comp] * 3, dim=1))
            feat_output = self.extractor(torch.cat([output] * 3, dim=1))
            feat_gt = self.extractor(torch.cat([gt] * 3, dim=1))
        else:
            raise ValueError('only grayscale (1-channel) or RGB (3-channel) inputs are supported')

        # Perceptual loss over the three extracted feature maps, for both the raw output and the composite.
        loss_dict['perceptual'] = 0.0
        for i in range(3):
            loss_dict['perceptual'] += self.l1(feat_output[i], feat_gt[i])
            loss_dict['perceptual'] += self.l1(feat_comp[i], feat_gt[i])

        # Style loss: L1 distance between Gram matrices of the feature maps.
        loss_dict['style'] = 0.0
        for i in range(3):
            loss_dict['style'] += self.l1(gram_matrix(feat_output[i]),
                                          gram_matrix(feat_gt[i]))
            loss_dict['style'] += self.l1(gram_matrix(feat_comp[i]),
                                          gram_matrix(feat_gt[i]))

        # Total variation regularizer on the composited image.
        loss_dict['tv'] = total_variation_loss(comp)

        return loss_dict
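
Example 1 relies on gram_matrix and total_variation_loss helpers that are not shown above. A minimal PyTorch sketch of how such helpers are commonly written (an assumption; the source's definitions may differ, e.g. in normalization):

import torch

def gram_matrix(feat):
    # (B, C, H, W) features -> per-sample (C, C) Gram matrix of channel correlations,
    # normalized by the number of elements in each channel map.
    b, ch, h, w = feat.size()
    feat = feat.view(b, ch, h * w)
    return torch.bmm(feat, feat.transpose(1, 2)) / (ch * h * w)

def total_variation_loss(image):
    # Mean absolute difference between neighboring pixels, encouraging local smoothness.
    return torch.mean(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + \
           torch.mean(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))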
Example 2
 def __fit_one(self, link, content_layers, style_grams):
     xp = self.xp
     link.zerograds()
     layers = self.model(link.x)
     if self.keep_color:
         # With keep_color, style features come from a grayscale copy so only luminance is matched.
         trans_layers = self.model(util.gray(link.x))
     else:
         trans_layers = layers
     loss_info = []
     loss = Variable(xp.zeros((), dtype=np.float32))
     # Content losses: MSE between generated and target activations at each content layer.
     for name, content_layer in content_layers:
         layer = layers[name]
         content_loss = self.content_weight * F.mean_squared_error(layer, Variable(content_layer.data))
         loss_info.append(('content_' + name, float(content_loss.data)))
         loss += content_loss
     # Style losses: MSE between Gram matrices at each style layer.
     for name, style_gram in style_grams:
         gram = util.gram_matrix(trans_layers[name])
         style_loss = self.style_weight * F.mean_squared_error(gram, Variable(style_gram.data))
         loss_info.append(('style_' + name, float(style_loss.data)))
         loss += style_loss
     # Total variation regularizer on the generated image.
     tv_loss = self.tv_weight * util.total_variation(link.x)
     loss_info.append(('tv', float(tv_loss.data)))
     loss += tv_loss
     loss.backward()
     self.optimizer.update()
     return loss_info
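
Example 2 calls util.gram_matrix and util.total_variation, which are not shown. A minimal Chainer sketch of such utilities (assumed definitions; the actual ones may normalize differently):

import numpy as np
import chainer.functions as F

def gram_matrix(x):
    # (B, C, H, W) feature map -> (B, C, C) Gram matrix.
    b, ch, h, w = x.data.shape
    v = F.reshape(x, (b, ch, h * w))
    return F.batch_matmul(v, v, transb=True) / np.float32(ch * h * w)

def total_variation(x):
    # Squared differences between vertically and horizontally adjacent pixels.
    return F.sum((x[:, :, 1:, :] - x[:, :, :-1, :]) ** 2) + \
           F.sum((x[:, :, :, 1:] - x[:, :, :, :-1]) ** 2)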
Example 3
def make_style_targets(model, style_input):
    # Run the style image through the model once and freeze each layer's Gram matrix
    # as a constant Input tensor, to be reused as the style target during training.
    result = model.predict(np.expand_dims(style_input, 0))
    grams = []

    for i, r in enumerate(result):
        var = K.variable(K.eval(gram_matrix(r)))
        grams.append(Input(tensor=var, name=f'style_target_{i}'))

    return grams
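
Example 3 assumes a gram_matrix that works on a single activation returned by model.predict. A minimal Keras-backend sketch, assuming channels-last activations of shape (1, H, W, C) (the source's version may differ):

from keras import backend as K

def gram_matrix(x):
    # Drop the batch axis, move channels first, flatten the spatial dims,
    # then take channel-by-channel inner products.
    features = K.batch_flatten(K.permute_dimensions(x[0], (2, 0, 1)))
    return K.dot(features, K.transpose(features))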
Example 4
 def __fit(self, content_image, style_image, epoch_num, callback=None):
     xp = self.xp
     input_image = None
     height, width = content_image.shape[-2:]
     base_epoch = 0
     old_link = None
     # Coarse-to-fine: optimize at stride 4, then 2, then 1, reusing each result
     # as the starting point for the next (finer) scale.
     for stride in [4, 2, 1][-self.resolution_num:]:
         if width // stride < 64:
             continue
         content_x = Variable(xp.asarray(
             content_image[:, :, ::stride, ::stride]),
                              volatile=True)
         if self.keep_color:
             style_x = Variable(util.luminance_only(
                 xp.asarray(style_image[:, :, ::stride, ::stride]),
                 content_x.data),
                                volatile=True)
         else:
             style_x = Variable(xp.asarray(
                 style_image[:, :, ::stride, ::stride]),
                                volatile=True)
         content_layer_names = self.content_layer_names
         content_layers = self.model(content_x)
         content_layers = [(name, content_layers[name])
                           for name in content_layer_names]
         style_layer_names = self.style_layer_names
         style_layers = self.model(style_x)
         style_grams = [(name, util.gram_matrix(style_layers[name]))
                        for name in style_layer_names]
         if input_image is None:
             # Initialize from the downsampled content image or from small Gaussian noise.
             if self.initial_image == 'content':
                 input_image = xp.asarray(
                     content_image[:, :, ::stride, ::stride])
             else:
                 input_image = xp.random.normal(
                     0, 1, size=content_x.data.shape).astype(
                         np.float32) * 0.001
         else:
             # Upsample the previous scale's result by 2x and crop to the new size.
             input_image = input_image.repeat(2, 2).repeat(2, 3)
             h, w = content_x.data.shape[-2:]
             input_image = input_image[:, :, :h, :w]
         # A bare Link whose only parameter x is the image being optimized.
         link = chainer.Link(x=input_image.shape)
         if self.device_id >= 0:
             link.to_gpu()
         link.x.data[:] = xp.asarray(input_image)
         self.optimizer.setup(link)
         for epoch in six.moves.range(epoch_num):
             loss_info = self.__fit_one(link, content_layers, style_grams)
             if callback:
                 callback(base_epoch + epoch, link.x, loss_info)
         base_epoch += epoch_num
         input_image = link.x.data
     return link.x
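
Example 4 optimizes the image coarse-to-fine over strides [4, 2, 1], skipping any scale whose downsampled width would fall under 64 pixels. A small standalone sketch of that selection logic, for illustration only:

def select_strides(width, resolution_num):
    # Keep the last resolution_num strides and drop scales narrower than 64 pixels.
    return [s for s in [4, 2, 1][-resolution_num:] if width // s >= 64]

# e.g. select_strides(320, 3) -> [4, 2, 1]; select_strides(200, 3) -> [2, 1]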
Example 5
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        # Keep images in the 0-255 range; normalization for VGG happens later in normalize_batch.
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    # Precompute the style image's VGG features and their Gram matrices once, outside the training loop.
    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            # Normalize both images with ImageNet statistics before feeding them to VGG.
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            # Content loss: match the relu2_2 activations of the stylized and original images.
            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                # Slice the style Grams in case the last batch is smaller than batch_size.
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "ckpt.epoch_" + str(args.epochs) + ".ckpt"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example 6
    def backward_G(self, val=False):
        I_o = self.I_o
        I_g = self.I_g
        L_o = self.L_o
        L_g = self.L_g

        # Verified: for square masks, letting D discriminate only the masked patch improves the results.
        if self.opt.mask_type == 'center' or self.opt.mask_sub_type == 'rect':
            # Use the cropped patches of I_o, I_g, L_o and L_g as the inputs of the discriminators.
            I_o = self.I_o[:, :, self.rand_t:self.rand_t +
                           self.opt.fineSize // 2 - 2 * self.opt.overlap,
                           self.rand_l:self.rand_l + self.opt.fineSize // 2 -
                           2 * self.opt.overlap]
            I_g = self.I_g[:, :, self.rand_t:self.rand_t +
                           self.opt.fineSize // 2 - 2 * self.opt.overlap,
                           self.rand_l:self.rand_l + self.opt.fineSize // 2 -
                           2 * self.opt.overlap]
            L_o = L_o[:, :, self.rand_t:self.rand_t + self.opt.fineSize // 2 -
                      2 * self.opt.overlap, self.rand_l:self.rand_l +
                      self.opt.fineSize // 2 - 2 * self.opt.overlap]
            L_g = L_g[:, :, self.rand_t:self.rand_t + self.opt.fineSize // 2 -
                      2 * self.opt.overlap, self.rand_l:self.rand_l +
                      self.opt.fineSize // 2 - 2 * self.opt.overlap]

        pred_I_o = self.netD(I_o)
        pred_L_o = self.netD2(L_o)

        self.loss_G_GAN = self.criterionGAN(pred_I_o,
                                            True) * self.opt.gan_weight
        self.loss_G_GAN += self.criterionGAN(pred_L_o,
                                             True) * self.opt.gan_weight

        self.loss_G_L2 = 0
        self.loss_G_L2 += self.criterionL2(self.I_o, self.I_g) * 10
        self.loss_G_L2 += self.criterionL2(self.L_o,
                                           self.L_g) * self.opt.lambda_A

        # Style and perceptual losses over the first three VGG16 feature maps.
        vgg_ft_I_o = self.vgg16_extractor(I_o)
        vgg_ft_I_g = self.vgg16_extractor(I_g)
        self.loss_style = 0
        self.loss_perceptual = 0
        for i in range(3):
            self.loss_style += self.criterionL2_style_loss(
                util.gram_matrix(vgg_ft_I_o[i]),
                util.gram_matrix(vgg_ft_I_g[i]))
            self.loss_perceptual += self.criterionL2_perceptual_loss(
                vgg_ft_I_o[i], vgg_ft_I_g[i])

        self.loss_style *= self.opt.style_weight
        self.loss_perceptual *= self.opt.content_weight

        # L2 feature-matching term over the stored feature lists I_fea / I_FEA.
        self.loss_multi = 0
        for i in range(len(self.I_fea)):
            self.loss_multi += self.criterionL2(self.I_fea[i],
                                                self.I_FEA[i]) * 0.01

        self.loss_G = self.loss_G_L2 + self.loss_G_GAN + self.loss_style + self.loss_perceptual + self.loss_multi

        if val:
            return
        self.loss_G.backward()
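
Example 6's self.vgg16_extractor returns a list of three intermediate VGG16 feature maps. A minimal PyTorch sketch of such an extractor; the choice of relu1_2 / relu2_2 / relu3_3 slices is an assumption, and the source may cut the network elsewhere:

import torch.nn as nn
from torchvision import models

class VGG16Extractor(nn.Module):
    def __init__(self):
        super().__init__()
        layers = list(models.vgg16(pretrained=True).features.children())
        # Slices ending at the relu1_2, relu2_2 and relu3_3 activations.
        self.slice1 = nn.Sequential(*layers[:4])
        self.slice2 = nn.Sequential(*layers[4:9])
        self.slice3 = nn.Sequential(*layers[9:16])
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        h1 = self.slice1(x)
        h2 = self.slice2(h1)
        h3 = self.slice3(h2)
        return [h1, h2, h3]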
Example 7
def train(**kwargs):
    opt = Config()
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device('cuda') if opt.use_gpu else t.device("cpu")
    vis = util.Visualizer(opt.env)

    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        # Keep images in the 0-255 range; normalization happens later in normalize_batch.
        tv.transforms.Lambda(lambda x: x * 255)
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, transforms)
    dataloader = data.DataLoader(dataset, opt.batch_size)

    transform = TransformerNet()

    if opt.model_path:
        transform.load_state_dict(
            t.load(opt.model_path, map_location=lambda _s, _: _s))
    transform = transform.to(device)

    vgg = Vgg16().eval()
    vgg.to(device)
    for param in vgg.parameters():
        param.requires_grad = False

    optimizer = t.optim.Adam(transform.parameters(), opt.lr)

    style = util.get_style_data(opt.style_path)
    # Roughly undo ImageNet normalization (std ~0.225, mean ~0.45) for display only.
    vis.img("style", (style.data[0] * 0.225 + 0.45).clamp(min=0, max=1))
    style = style.to(device)

    # Precompute the style image's Gram matrices once; no gradients are needed here.
    with t.no_grad():
        features_style = vgg(style)
        gram_style = [util.gram_matrix(y) for y in features_style]

    style_meter = tnt.meter.AverageValueMeter()
    content_meter = tnt.meter.AverageValueMeter()

    for epoch in range(opt.epoches):
        content_meter.reset()
        style_meter.reset()

        for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):

            # Training step.
            optimizer.zero_grad()
            x = x.to(device)
            # data_parallel here assumes two GPUs (device ids 0 and 1).
            y = t.nn.parallel.data_parallel(transform, x, [0, 1])
            y = util.normalize_batch(y)
            x = util.normalize_batch(x)
            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = opt.content_weight * F.mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gram_y = util.gram_matrix(ft_y)
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            content_meter.add(content_loss.item())
            style_meter.add(style_loss.item())

            if (ii + 1) % opt.plot_every == 0:
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                vis.plot("content_loss", content_meter.value()[0])
                vis.plot("style_loss", style_meter.value()[0])
                vis.img("output",
                        (y.data.cpu()[0] * 0.255 + 0.45).clamp(min=0, max=1))
                vis.img("input", (x.data.cpu()[0] * 0.255 + 0.45).clamp(min=0,
                                                                        max=1))

        vis.save([opt.env])
        t.save(transform.state_dict(),
               'checkpoints/' + time.ctime() + '%s_style.pth' % epoch)
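
Example 7's train accepts keyword overrides for the Config defaults (each keyword is set on opt via setattr). A typical invocation, with placeholder paths:

train(data_root='data/',        # ImageFolder root with training images
      style_path='style.jpg',   # style image passed to util.get_style_data
      batch_size=8,
      use_gpu=True)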