def loss(self, images, stylized, inplace=True):
        """
        Compute the perceptual loss between content images and stylized output.

        Both batches go through ``utils.normalize_batch`` and the VGG feature
        extractor. The content term is the weighted MSE between the
        ``relu2_2`` activations. The style term sums, over every extracted
        feature map, the MSE between the Gram matrix of the stylized features
        and the matching entry of ``self.gram_style``, then applies
        ``hparams.style_weight``.

        :param images: batch of content images.
        :param stylized: batch of stylized images generated for ``images``.
        :param inplace: forwarded unchanged to ``utils.normalize_batch`` for
            both batches. NOTE(review): presumably allows the helper to
            overwrite its argument -- confirm against ``utils``.
        :return: tuple ``(total_loss, content_loss, style_loss)`` where
            ``total_loss`` is the sum of the two weighted terms.
        """
        stylized = utils.normalize_batch(stylized, inplace=inplace)
        images = utils.normalize_batch(images, inplace=inplace)

        features_stylized = self.vgg_extractor(stylized)
        features_images = self.vgg_extractor(images)

        content_loss = self.hparams.content_weight * \
            F.mse_loss(features_stylized.relu2_2, features_images.relu2_2)

        style_loss = 0.
        for ft_stylized, gm_s in zip(features_stylized, self.gram_style):
            gm_stylized = utils.gram_matrix(ft_stylized)
            gm_s = gm_s.type_as(ft_stylized)
            # Only the first len(images) entries of the stored style Gram
            # matrix are compared against the current batch.
            style_loss += F.mse_loss(gm_stylized, gm_s[:len(images), :, :])
        style_loss *= self.hparams.style_weight

        total_loss = content_loss + style_loss
        return total_loss, content_loss, style_loss
Example #2
0
    def step(engine, batch):
        """Run one optimisation step of ``transformer`` on ``batch``.

        ``batch`` is unpacked as ``(images, _)``; the second element is
        ignored. Returns a dict with the scalar ``content_loss``,
        ``style_loss`` and ``total_loss`` of this step. ``engine`` is not
        used by the body.
        """
        images, _ = batch
        images = images.to(device)
        batch_len = len(images)

        optimizer.zero_grad()

        stylized = transformer(images)

        images = utils.normalize_batch(images)
        stylized = utils.normalize_batch(stylized)

        content_features = vgg(images)
        stylized_features = vgg(stylized)

        content_loss = args.content_weight * mse_loss(
            stylized_features.relu2_2, content_features.relu2_2)

        # Sum the Gram-matrix MSE over every feature map returned by vgg.
        style_loss = 0.0
        for feature_map, style_gram in zip(stylized_features, gram_style):
            style_loss += mse_loss(utils.gram_matrix(feature_map),
                                   style_gram[:batch_len, :, :])
        style_loss *= args.style_weight

        total_loss = content_loss + style_loss
        total_loss.backward()
        optimizer.step()

        keys = ("content_loss", "style_loss", "total_loss")
        losses = (content_loss, style_loss, total_loss)
        return {key: value.item() for key, value in zip(keys, losses)}
Example #3
0
    def calc_spatial_loss(self, x, y):
        """Add the content, style and TV terms for ``(x, y)`` to the running losses.

        Nothing is returned; ``self.content_loss``, ``self.style_loss`` and
        ``self.tv_loss`` are incremented in place.
        """
        normalized_x = utils.normalize_batch(x)
        feature_x = self.vgg(normalized_x)
        normalized_y = utils.normalize_batch(y)
        feature_y = self.vgg(normalized_y)

        # Content term: relu2_2 activations of output vs. input.
        self.content_loss += self.l2_criterion(feature_y.relu2_2,
                                               feature_x.relu2_2)

        # Style term: one Gram-matrix comparison per extracted feature map.
        for layer_feature, style_gram in zip(feature_y, self.gram_style):
            target_gram = style_gram[:self.batch_size, :, :]
            self.style_loss += self.l2_criterion(utils.GramMatrix(layer_feature),
                                                 target_gram)

        self.tv_loss += self.tv_criterion(y)
    def _run_one_epoch(self):
        """Train the transfer network over ``self._train_loader`` once.

        Stops early after ``self._batch_iterations`` batches. Logs the
        accumulated losses every ``self._log_interval`` batches, saves an
        intermediate model every ``self._save_interval`` batches and a final
        model at the end.
        """
        elapsed_hours = ((time.time() - self._start_time) / 60.) / 60.
        print('{:.2f}h Epoch {} starts'.format(elapsed_hours, self._current_epoch), flush=True)

        running_content = 0.
        running_style = 0.
        for batch_idx, content_batch in enumerate(self._train_loader):
            batches_done = batch_idx + 1
            self._optimizer.zero_grad()

            ## The transfer network sees the un-normalized batch
            content_batch = content_batch.to(self._device)
            stylized = self._transfer_net(content_batch)

            ## Both batches are normalized before going through vgg
            content_batch = utils.normalize_batch(content_batch)
            stylized = utils.normalize_batch(stylized)

            features_y = self._vgg(stylized)
            features_x = self._vgg(content_batch)

            ## Losses
            content_loss = self._content_loss(features_y, features_x)
            style_loss = self._style_loss(features_y)
            running_content += content_loss.item()
            running_style += style_loss.item()

            ## Backward and optimization
            (content_loss + style_loss).backward()
            self._optimizer.step()

            if batches_done % self._log_interval == 0:
                self._log(batches_done, running_content, running_style)
                running_content = running_style = 0.
            if batches_done % self._save_interval == 0:
                suffix = '_i{}-{}'.format(batches_done * self._batch_size,
                                          len(self._train_loader) * self._batch_size)
                self._save_model(suffix)
            if batches_done >= self._batch_iterations:
                break

        elapsed_hours = ((time.time() - self._start_time) / 60.) / 60.
        print('{:.2f}h Epoch {} over'.format(elapsed_hours, self._current_epoch), flush=True)
        self._save_model(final=True)
        print('', flush=True)
Example #5
0
    def calc_style(self):
        """Compute the style image's Gram matrices and store them in ``self.gram_style``.

        The style image is loaded from ``self.opt.style_image``, scaled by
        255, moved to CUDA and repeated ``self.opt.batch_size`` times before
        being fed to ``self.vgg``.
        """
        to_scaled_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda img: img.mul(255))
        ])

        style_img = utils.load_image(self.opt.style_image,
                                     size=self.opt.style_size)
        style_batch = to_scaled_tensor(style_img).cuda()
        style_batch = style_batch.repeat(self.opt.batch_size, 1, 1, 1)

        style_features = self.vgg(utils.normalize_batch(style_batch))
        self.gram_style = [utils.GramMatrix(layer) for layer in style_features]
 def _compute_gram_from_style(self):
     """Return the Gram matrices of the style image's vgg features.

     The image at ``self._style_image_path`` is resized to 224x224, repeated
     ``self._batch_size`` times, moved to ``self._device`` and normalized
     before being passed through ``self._vgg``. One Gram matrix is returned
     per value of the feature mapping produced by ``self._vgg``.
     """
     ## Open the style image. Converting to RGB guarantees three channels
     ## (RGBA / palette files would otherwise yield a different channel
     ## count), and the context manager releases the file handle promptly.
     with Image.open(self._style_image_path) as image_file:
         style_image = image_file.convert('RGB')

     ## Transform the style image to a full batch on the selected device
     transform_pipeline = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
     style_image_batch = transform_pipeline(style_image)
     style_image_batch = style_image_batch.repeat(self._batch_size, 1, 1, 1).to(self._device)

     ## Normalize for vgg input and get features
     features_style = self._vgg(utils.normalize_batch(style_image_batch))

     ## Compute the style gram matrix: y_s
     return [utils.gram_matrix(y) for y in features_style.values()]
    def __init__(self, hparams):
        """
        Build the style-transfer system from parsed command-line arguments.

        :param hparams: parsed ArgumentParser namespace; this constructor
            reads ``seed``, ``batch_size``, ``model``, ``image_size``,
            ``content_image``, ``content_scale`` and ``style_image``.
        """
        super(FastNeuralStyleSystem, self).__init__()
        self.hparams = hparams

        # Seed torch and numpy before any network is instantiated.
        torch.manual_seed(hparams.seed)
        np.random.seed(hparams.seed)

        self.batch_size = hparams.batch_size
        self.style_model = HRNet() if hparams.model == "hrnet" else TransformerNet()
        self.vgg_extractor = Vgg16(requires_grad=False)

        def build_transform():
            # Resize + center-crop to image_size, then scale pixels by 255.
            return transforms.Compose([
                transforms.Resize(hparams.image_size),
                transforms.CenterCrop(hparams.image_size),
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x.mul(255))
            ])

        self.transform = build_transform()
        self.style_transform = build_transform()

        raw_content = utils.load_image(self.hparams.content_image,
                                       scale=self.hparams.content_scale)
        self.content_image = self.style_transform(raw_content)

        style_path = os.path.join('images', 'style-images',
                                  f'{hparams.style_image}.jpg')
        raw_style = utils.load_image(style_path, scale=0.5)
        style_tensor = self.style_transform(raw_style).requires_grad_(False)
        # One copy of the style image per batch element.
        self.style_image = style_tensor.repeat(hparams.batch_size, 1, 1, 1)

        normalized_style = utils.normalize_batch(self.style_image)
        self.features_style = self.vgg_extractor(normalized_style)
        self.gram_style = [utils.gram_matrix(feat) for feat in self.features_style]
    def optimize(self):
        """
        Fit the model to the single stored content image and dump the results.

        Creates ``<output_dir>/<run name>/steps``, saves the content and style
        images, then runs 30000 Adam steps on ``self.parameters()`` using
        ``self.loss``. Losses are printed every 200 steps and the current
        prediction is written to disk every 400 steps; the saved frames are
        finally assembled into ``0_optimization.gif`` and the last prediction
        is saved as ``0_final.png``.

        Side effect: ``self.content_image`` is replaced by a normalized,
        batched copy on the ``'cuda'`` device.
        """
        import os
        # Run directory name encodes style, model, weights and loss weights.
        temp_dir = os.path.join(
            self.hparams.output_dir,
            f"{self.hparams.style_image}_{self.hparams.model}_w_{self.hparams.weights}_c_{self.hparams.content_weight}_s_{self.hparams.style_weight}"
        )
        temper_dir = os.path.join(temp_dir, f"steps")
        os.makedirs(temper_dir, exist_ok=True)
        # Prints the CUDA memory summary for device index 0 only.
        for i in range(1):
            print(torch.cuda.memory_summary(i))

        # torch.cuda.ipc_collect()
        single_opt = torch.optim.Adam(self.parameters(),
                                      lr=self.hparams.learning_rate)
        # NOTE(review): the scheduler is created but never stepped (the
        # scheduler.step() call in the loop below is commented out).
        scheduler = torch.optim.lr_scheduler.StepLR(single_opt,
                                                    step_size=200,
                                                    gamma=0.9)

        # Move the content image to the GPU and add a leading batch dimension.
        self.content_image = self.content_image.to('cuda').unsqueeze(
            0).requires_grad_(False)
        self.content_image = utils.normalize_batch(self.content_image,
                                                   inplace=True)
        # NOTE(review): this initial value is overwritten by self.forward()
        # in the first loop iteration before it is ever read.
        prediction = self.content_image.clone().requires_grad_(True)
        # prediction = torch.randn_like(self.content_image).requires_grad_(True)

        torchvision.utils.save_image(self.content_image,
                                     f"{temp_dir}/0_content.png",
                                     normalize=True)
        torchvision.utils.save_image(self.style_image,
                                     f"{temp_dir}/0_style.png",
                                     normalize=True)
        # Per-step loss histories (style, content, total) and the list of
        # intermediate image paths used to build the GIF afterwards.
        s = []
        c = []
        total = []
        saved_images = []
        for step in range(30000):
            prediction = self.forward(self.content_image)
            prediction.requires_grad_(True)

            # inplace=False is passed so the loss helper's normalization is
            # not requested in place on self.content_image / prediction.
            loss_val, content_loss, style_loss = self.loss(self.content_image,
                                                           prediction,
                                                           inplace=False)
            total.append(loss_val.item())
            c.append(content_loss.item())
            s.append(style_loss.item())

            # Progress report every 200 steps.
            if (step + 1) % 200 == 0:
                print("After %d criterions, learning:" % (step + 1))
                print('Total loss: ', loss_val.item())
                print('Content loss: ', content_loss.item())
                print('Style loss: ', style_loss.item())
                print(prediction.shape)
                print(torch.unique(prediction))
            # Snapshot of the current prediction every 400 steps.
            if (step + 1) % 400 == 0:
                new_filename = os.path.join(
                    temper_dir,
                    f"optim_{(step+1)}_{self.hparams.style_image}.png")
                saved_images.append(new_filename)
                torchvision.utils.save_image(prediction,
                                             new_filename,
                                             normalize=True)

            single_opt.zero_grad()

            loss_val.backward()

            single_opt.step()
            # scheduler.step()
        # Re-read the saved snapshots from disk and assemble them into a GIF.
        gif_images = []
        for step_img in saved_images:
            gif_images.append(imageio.imread(step_img))
        imageio.mimsave(os.path.join(temp_dir, '0_optimization.gif'),
                        gif_images)
        torchvision.utils.save_image(prediction,
                                     os.path.join(temp_dir, '0_final.png'),
                                     normalize=True)
Example #9
0
def train():
    """Train a ``TransformerNet`` for style transfer, configured through ``DC``.

    Builds the content data loader, optionally restores transformer weights
    from ``DC.load_model``, precomputes the style Gram matrices with a frozen
    ``Vgg16`` and then optimises the transformer for ``DC.max_epoch`` epochs.
    The transformer's state dict is saved as ``<epoch>_style.pth`` after each
    epoch.
    """
    train_gpu_id = DC.train_gpu_id
    if DC.use_gpu:
        device = t.device('cuda', train_gpu_id)
    else:
        device = t.device('cpu')

    # Resize/crop to the configured input size and scale pixels by 255.
    preprocess = T.Compose([
        T.Resize(DC.input_size),
        T.CenterCrop(DC.input_size),
        T.ToTensor(),
        T.Lambda(lambda img: img * 255)
    ])

    train_dir = DC.train_content_dir
    batch_size = DC.train_batch_size

    dataset = ImageFolder(train_dir, transform=preprocess)
    num_train_data = len(dataset)

    loader = t.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=DC.num_workers,
                                     drop_last=True)

    # Transformation network, optionally warm-started from a checkpoint.
    transformer = TransformerNet()
    if DC.load_model:
        state = t.load(DC.load_model,
                       map_location=lambda storage, loc: storage)
        transformer.load_state_dict(state)
    transformer.to(device)

    # Frozen loss network.
    vgg = Vgg16().eval()
    vgg.to(device)
    for param in vgg.parameters():
        param.requires_grad = False

    optimizer = t.optim.Adam(transformer.parameters(), DC.base_lr)

    # Style image data and its Gram matrices (no gradients needed).
    style_batch = utils.get_style_data(DC.style_img)
    style_batch = style_batch.to(device)
    with t.no_grad():
        style_features = vgg(style_batch)
        style_grams = [utils.gram_matrix(feat) for feat in style_features]

    images_seen = 0
    iteration = 0
    for epoch in range(DC.max_epoch):
        for _, (data, _label) in tqdm.tqdm(enumerate(loader)):
            images_seen += batch_size
            iteration += 1

            optimizer.zero_grad()

            content = data.to(device)
            stylized = transformer(content)

            content = utils.normalize_batch(content)
            stylized = utils.normalize_batch(stylized)

            features_y = vgg(stylized)
            features_yc = vgg(content)

            # Content loss on the relu2_2 activations.
            content_mse = nn.functional.mse_loss(features_y.relu2_2,
                                                 features_yc.relu2_2)
            content_loss = DC.content_weight * content_mse

            # Style loss: Gram-matrix MSE summed over all feature maps.
            style_loss = 0.0
            for feat, style_gram in zip(features_y, style_grams):
                gram = utils.gram_matrix(feat)
                style_loss += nn.functional.mse_loss(
                    gram, style_gram.expand_as(gram))
            style_loss *= DC.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            if iteration % DC.show_iter == 0:
                print('\ncontent loss: ', content_loss.data)
                print('style loss: ', style_loss.data)
                print('total loss: ', total_loss.data)
                print()

        t.save(transformer.state_dict(), '{}_style.pth'.format(epoch))
Example #10
0
def train(config):
    """Train a ``TransferNet`` on ``config.dataset`` and save its weights.

    The network is optimised with Adam against a content loss (relu2_2
    features) plus a style loss (Gram matrices of every vgg feature map) for
    ``config.epochs`` epochs, then moved to the CPU and written to
    ``config.save_model_path``.
    """
    device = torch.device('cuda', config.device_id)
    set_seed(config.seed)
    check_paths(config)

    # Content images: resize, center-crop, tensor, scale by 255.
    content_pipeline = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.CenterCrop(config.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img.mul(255))
    ])
    train_dataset = datasets.ImageFolder(config.dataset, content_pipeline)
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size)

    transfer_net = TransferNet().to(device)
    optimizer = optim.Adam(transfer_net.parameters(), lr=config.lr)
    mse_loss = nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)

    # Style image: tensor scaled by 255, repeated once per batch element.
    style_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img.mul(255))
    ])
    style_img = utils.load_image(config.style_image, size=config.style_size)
    style_batch = style_pipeline(style_img)
    style_batch = style_batch.repeat(config.batch_size, 1, 1, 1).to(device)

    style_features = vgg(utils.normalize_batch(style_batch))
    gram_style = [utils.gram_matrix(feat) for feat in style_features]

    for _epoch in range(config.epochs):
        transfer_net.train()
        running_content = 0.
        running_style = 0.
        seen = 0
        for _batch_id, (content, _) in enumerate(train_loader):
            batch_len = len(content)
            seen += batch_len
            optimizer.zero_grad()

            content = content.to(device)
            stylized = transfer_net(content)

            stylized = utils.normalize_batch(stylized)
            content = utils.normalize_batch(content)

            features_y = vgg(stylized)
            features_x = vgg(content)

            content_loss = config.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for feat, style_gram in zip(features_y, gram_style):
                style_loss += mse_loss(utils.gram_matrix(feat),
                                       style_gram[:batch_len, :, :])
            style_loss *= config.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            running_content += content_loss.item()
            running_style += style_loss.item()

    transfer_net.eval().cpu()
    save_model_path = config.save_model_path
    torch.save(transfer_net.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #11
0
def train(args):
    """Train a ``TransformerNet`` on ``args.dataset`` and save the final model.

    Running averages of the content/style/total loss are printed every
    ``args.log_interval`` batches. When ``args.checkpoint_model_dir`` is set,
    a checkpoint is written every ``args.checkpoint_interval`` batches. The
    final weights are stored under ``args.save_model_dir``.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Content images: resize, center-crop, tensor, scale by 255.
    content_pipeline = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, content_pipeline)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)

    # Style image: tensor scaled by 255, repeated once per batch element.
    style_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img.mul(255))
    ])
    style_img = utils.load_image(args.style_image, size=args.style_size)
    style_batch = style_pipeline(style_img)
    style_batch = style_batch.repeat(args.batch_size, 1, 1, 1).to(device)

    style_features = vgg(utils.normalize_batch(style_batch))
    gram_style = [utils.gram_matrix(feat) for feat in style_features]

    for e in range(args.epochs):
        transformer.train()
        running_content = 0.
        running_style = 0.
        seen = 0
        for batch_id, (content, _) in enumerate(train_loader):
            batches_done = batch_id + 1
            batch_len = len(content)
            seen += batch_len
            optimizer.zero_grad()

            content = content.to(device)
            stylized = transformer(content)

            stylized = utils.normalize_batch(stylized)
            content = utils.normalize_batch(content)

            features_y = vgg(stylized)
            features_x = vgg(content)

            content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for feat, style_gram in zip(features_y, gram_style):
                style_loss += mse_loss(utils.gram_matrix(feat),
                                       style_gram[:batch_len, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            running_content += content_loss.item()
            running_style += style_loss.item()

            if batches_done % args.log_interval == 0:
                # Averages are taken over all batches of the current epoch.
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, seen, len(train_dataset),
                    running_content / batches_done,
                    running_style / batches_done,
                    (running_content + running_style) / batches_done
                )
                print(mesg)

            if args.checkpoint_model_dir is not None:
                if batches_done % args.checkpoint_interval == 0:
                    # Save from CPU in eval mode, then resume training on device.
                    transformer.eval().cpu()
                    ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batches_done) + ".pth"
                    ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
                    torch.save(transformer.state_dict(), ckpt_model_path)
                    transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    name_parts = [
        str(args.epochs),
        str(time.ctime()).replace(' ', '_'),
        str(args.content_weight),
        str(args.style_weight),
    ]
    save_model_filename = "epoch_" + "_".join(name_parts) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #12
0
def train(args):
    """Train a ``TransformerNet`` on ``args.dataset`` (legacy ``Variable`` API).

    Seeds numpy/torch, builds the content data loader, precomputes the style
    Gram matrices with a frozen ``Vgg16`` and optimises the transformer for
    ``args.epochs`` epochs. Running averages of the content, style and total
    loss are printed every ``args.log_interval`` batches.

    NOTE(review): the losses are read with ``.data[0]``, which matches the
    ``Variable``-era API used throughout this function; confirm the targeted
    PyTorch version before switching to ``.item()``.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    # Bound as train_loader: that is the name the epoch loop iterates over.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    # One copy of the style image per batch element.
    style = style.repeat(args.batch_size, 1, 1, 1)

    if args.cuda:
        transformer.cuda()
        vgg.cuda()
        style = style.cuda()

    style_v = Variable(style)
    style_v = utils.normalize_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(x)
            if args.cuda:
                x = x.cuda()

            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                # Only the first n_batch entries of the style Gram matrix are
                # compared against the current batch.
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                                  agg_content_loss / (batch_id + 1),
                                  agg_style_loss / (batch_id + 1),
                                  (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(mesg)
Example #13
0
def train(**kwargs):
    """Train a ``TransformerNet`` using ``Config`` defaults overridden by ``kwargs``.

    Every keyword argument is set as an attribute on a fresh ``Config``.
    Losses and sample images are pushed to the visdom-style ``Visualizer``
    every ``opt.plot_every`` batches, and the transformer weights are saved
    to ``checkpoints/<epoch>_style.pth`` after each epoch.
    """
    opt = Config()
    for key, value in kwargs.items():
        setattr(opt, key, value)

    # Visualisation handle.
    vis = utils.Visualizer(opt.env)

    def to_unit_range(img):
        # Maps a normalized image back to [0, 1] for display.
        return (img * 0.225 + 0.45).clamp(min=0, max=1)

    # Data loading: resize, center-crop, convert to a tensor and rescale
    # the pixel values by 255.
    preprocess = tv.transforms.Compose([
        tv.transforms.Scale(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda img: img * 255)
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, preprocess)
    dataloader = data.DataLoader(dataset, opt.batch_size)

    # Style-transformation network, optionally restored from a checkpoint.
    transformer = TransformerNet()
    if opt.model_path:
        state = t.load(opt.model_path, map_location=lambda _s, _: _s)
        transformer.load_state_dict(state)

    # Loss network (Vgg16) in evaluation mode.
    vgg = Vgg16().eval()

    # Only the transformer's parameters are optimised.
    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    # Style image data; presumably already normalized by the helper
    # -- TODO confirm against utils.get_style_data.
    style = utils.get_style_data(opt.style_path)
    vis.img('style', to_unit_range(style[0]))

    if opt.use_gpu:
        transformer.cuda()
        style = style.cuda()
        vgg.cuda()

    # Gram matrices of the style image's vgg features.
    style_v = Variable(style, volatile=True)
    features_style = vgg(style_v)
    gram_style = [Variable(utils.gram_matrix(feat.data)) for feat in features_style]

    # Meters that average the losses over the batches of one epoch.
    style_meter = tnt.meter.AverageValueMeter()
    content_meter = tnt.meter.AverageValueMeter()

    for epoch in range(opt.epoches):
        content_meter.reset()
        style_meter.reset()

        for batch_idx, (x, _) in tqdm.tqdm(enumerate(dataloader)):

            optimizer.zero_grad()
            if opt.use_gpu:
                x = x.cuda()
            # x is the real input image, y the stylized prediction.
            x = Variable(x)
            y = transformer(x)

            # Normalize both batches before feeding them to vgg.
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)
            features_y = vgg(y)
            features_x = vgg(x)

            # Content loss: compares only the relu2_2 activations.
            content_mse = F.mse_loss(features_y.relu2_2, features_x.relu2_2)
            content_loss = opt.content_weight * content_mse

            # Style loss: Gram-matrix MSE summed over all vgg feature maps.
            style_loss = 0.
            for feat, style_gram in zip(features_y, gram_style):
                gram_y = utils.gram_matrix(feat)
                style_loss += F.mse_loss(gram_y, style_gram.expand_as(gram_y))
            style_loss *= opt.style_weight

            # Total loss = content loss + style loss.
            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            # Feed the meters so the plotted curves are smoothed.
            content_meter.add(content_loss.data[0])
            style_meter.add(style_loss.data[0])

            # Visualise every plot_every batches.
            if (batch_idx + 1) % opt.plot_every == 0:
                # Drop into the debugger when the debug file exists.
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                vis.plot('content_loss', content_meter.value()[0])
                vis.plot('style_loss', style_meter.value()[0])
                # x and y were normalized above, so undo that for display.
                vis.img('output', to_unit_range(y.data.cpu()[0]))
                vis.img('input', to_unit_range(x.data.cpu()[0]))

        # Persist the visdom environment and the model after every epoch.
        vis.save([opt.env])
        t.save(transformer.state_dict(), 'checkpoints/%s_style.pth' % epoch)
Example #14
0
def train(args):
    """Train a ``TransformerNet`` with an ``Engine``-driven training loop.

    Precomputes the style Gram matrices with a frozen ``Vgg16``, defines the
    per-batch update as a closure and runs it for ``args.epochs`` epochs.
    A checkpoint handler fires every ``args.checkpoint_interval`` completed
    epochs and a progress bar handler fires after every iteration.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    train_loader = check_dataset(args)
    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)

    # Style image: tensor scaled by 255, repeated once per batch element.
    style_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img.mul(255))
    ])
    style_img = utils.load_image(args.style_image, size=args.style_size)
    style_batch = style_pipeline(style_img)
    style_batch = style_batch.repeat(args.batch_size, 1, 1, 1).to(device)

    style_features = vgg(utils.normalize_batch(style_batch))
    gram_style = [utils.gram_matrix(feat) for feat in style_features]

    running_avgs = OrderedDict()

    def step(engine, batch):
        """Run one optimisation step and return the scalar losses as a dict."""
        images, _ = batch
        images = images.to(device)
        batch_len = len(images)

        optimizer.zero_grad()

        stylized = transformer(images)

        images = utils.normalize_batch(images)
        stylized = utils.normalize_batch(stylized)

        content_features = vgg(images)
        stylized_features = vgg(stylized)

        content_loss = args.content_weight * mse_loss(
            stylized_features.relu2_2, content_features.relu2_2)

        style_loss = 0.0
        for feat, style_gram in zip(stylized_features, gram_style):
            style_loss += mse_loss(utils.gram_matrix(feat),
                                   style_gram[:batch_len, :, :])
        style_loss *= args.style_weight

        total_loss = content_loss + style_loss
        total_loss.backward()
        optimizer.step()

        keys = ("content_loss", "style_loss", "total_loss")
        losses = (content_loss, style_loss, total_loss)
        return {key: value.item() for key, value in zip(keys, losses)}

    trainer = Engine(step)

    # Periodic checkpointing of the transformer.
    checkpoint_handler = ModelCheckpoint(args.checkpoint_model_dir,
                                         "checkpoint",
                                         n_saved=10,
                                         require_empty=False,
                                         create_dir=True)
    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED(every=args.checkpoint_interval),
        handler=checkpoint_handler,
        to_save={"net": transformer},
    )

    # Progress reporting after each iteration.
    progress_bar = Progbar(loader=train_loader, metrics=running_avgs)
    trainer.add_event_handler(event_name=Events.ITERATION_COMPLETED,
                              handler=progress_bar)

    trainer.run(train_loader, max_epochs=args.epochs)
Example #15
0
def fast_train(args):
    """Fast training.

    Optionally restores a ``TransformerNet`` from ``args.model``, then runs
    ``args.update_step`` optimisation steps against a single style image
    using content batches drawn from an infinite sampler. When
    ``args.only_in`` is set, only parameters whose name contains ``"in"``
    are optimised. The resulting weights are saved to
    ``<args.save_model_dir>/<style image base name>.pth``.

    Side effect: rebinds the module-level ``mse_loss``.
    """

    device = torch.device("cuda" if args.cuda else "cpu")

    transformer = TransformerNet().to(device)
    if args.model:
        transformer.load_state_dict(torch.load(args.model))
    vgg = Vgg16(requires_grad=False).to(device)
    global mse_loss
    mse_loss = torch.nn.MSELoss()

    content_weight = args.content_weight
    style_weight = args.style_weight
    lr = args.lr

    content_transform = transforms.Compose([
        transforms.Resize(args.content_size),
        transforms.CenterCrop(args.content_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))])
    content_dataset = datasets.ImageFolder(args.content_dataset, content_transform)
    content_loader = DataLoader(content_dataset,
                                batch_size=args.iter_batch_size,
                                sampler=InfiniteSamplerWrapper(content_dataset),
                                num_workers=args.n_workers)
    # Keep a single iterator so batches are drawn step by step below.
    content_loader = iter(content_loader)
    style_transform = transforms.Compose([
            transforms.Resize((args.style_size, args.style_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))])

    style_image = utils.load_image(args.style_image)
    style_image = style_transform(style_image)
    style_image = style_image.unsqueeze(0).to(device)
    # One copy of the style image per element of a content batch.
    features_style = vgg(utils.normalize_batch(style_image.repeat(args.iter_batch_size, 1, 1, 1)))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    if args.only_in:
        optimizer = Adam([param for (name, param) in transformer.named_parameters() if "in" in name], lr=lr)
    else:
        optimizer = Adam(transformer.parameters(), lr=lr)

    for i in trange(args.update_step):
        # Built-in next() works on any iterator; the ``.next()`` method is a
        # Python 2 convention that iterators are not required to provide.
        contents = next(content_loader)[0].to(device)
        features_contents = vgg(utils.normalize_batch(contents))

        transformed = transformer(contents)
        # NOTE(review): the output goes through utils.standardize_batch while
        # the inputs use utils.normalize_batch -- confirm this is intended.
        features_transformed = vgg(utils.standardize_batch(transformed))
        loss, c_loss, s_loss = loss_fn(features_transformed, features_contents, gram_style, content_weight, style_weight)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # save model
    transformer.eval().cpu()
    style_name = os.path.basename(args.style_image).split(".")[0]
    save_model_filename = style_name + ".pth"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
Example #16
0
def train(args):
    """Meta-train the style transformer network.

    Each outer iteration draws one content batch and one query batch. Then,
    for each of ``args.meta_batch_size`` sampled styles, it

    1. adapts a set of "fast weights" (the transformer parameters selected by
       the name regex below) with ``args.meta_step`` differentiable gradient
       steps on the content batch, and
    2. evaluates the adapted weights on the query batch and records the
       gradient of that loss with respect to every transformer parameter.

    The recorded gradients are handed to ``meta_updates`` (defined elsewhere)
    together with a dummy loss computed on the query batch.

    Args:
        args: namespace providing ``cuda``, ``seed``, ``lr``, ``meta_lr``,
            ``content_weight``, ``style_weight``, ``log_dir``, ``max_iter``,
            ``meta_batch_size``, ``meta_step``, ``iter_batch_size``,
            ``checkpoint_model_dir``, ``checkpoint_interval`` and
            ``save_model_dir``. It is also forwarded to ``get_data_loader``.

    Side effects:
        Rebinds the module-level globals ``optimizer`` and ``mse_loss``,
        writes scalars to a ``SummaryWriter`` at ``args.log_dir``, and saves
        checkpoints and the final model with ``torch.save``.
    """

    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # first move parameters to GPU
    transformer = TransformerNet().to(device)
    vgg = Vgg16(requires_grad=False).to(device)
    # NOTE(review): these are published as globals, presumably because
    # meta_updates() reads them -- confirm before making them local.
    global optimizer
    optimizer = Adam(transformer.parameters(), args.meta_lr)
    global mse_loss
    mse_loss = torch.nn.MSELoss()

    content_loader, style_loader, query_loader = get_data_loader(args)

    content_weight = args.content_weight
    style_weight = args.style_weight

    writer = SummaryWriter(args.log_dir)

    for iteration in trange(args.max_iter):
        transformer.train()

        # bookkeeping
        # using state_dict causes problems, use named_parameters instead
        all_meta_grads = []
        avg_train_c_loss = 0.0
        avg_train_s_loss = 0.0
        avg_train_loss = 0.0
        avg_eval_c_loss = 0.0
        avg_eval_s_loss = 0.0
        avg_eval_loss = 0.0

        # The builtin next() works on both Python 2 and Python 3 iterators;
        # the ``.next()`` method only exists as a Python 2 alias.
        contents = next(content_loader)[0].to(device)
        features_contents = vgg(utils.normalize_batch(contents))
        querys = next(query_loader)[0].to(device)
        features_querys = vgg(utils.normalize_batch(querys))

        # learning rate scheduling: inverse decay of both the inner (fast
        # weight) rate and the meta rate
        lr = args.lr / (1.0 + iteration * 2.5e-5)
        meta_lr = args.meta_lr / (1.0 + iteration * 2.5e-5)
        for param_group in optimizer.param_groups:
            param_group['lr'] = meta_lr

        for _ in range(args.meta_batch_size):
            # sample a style and tile it to the inner batch size
            style = next(style_loader)[0].to(device)
            style = style.repeat(args.iter_batch_size, 1, 1, 1)
            features_style = vgg(utils.normalize_batch(style))
            gram_style = [utils.gram_matrix(y) for y in features_style]

            # only parameters whose name matches the regex are adapted
            fast_weights = OrderedDict((name, param) for (name, param) in transformer.named_parameters() if re.search(r'in\d+\.', name))
            for _ in range(args.meta_step):
                # run forward transformation on contents
                transformed = transformer(contents, fast_weights)

                # compute loss
                features_transformed = vgg(utils.standardize_batch(transformed))
                loss, c_loss, s_loss = loss_fn(features_transformed, features_contents, gram_style, content_weight, style_weight)

                # compute grad; create_graph=True keeps the inner step
                # differentiable so the query loss can backprop through it
                grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)

                # update fast weights
                fast_weights = OrderedDict((name, param - lr * grad) for ((name, param), grad) in zip(fast_weights.items(), grads))

            # losses of the last inner step
            avg_train_c_loss += c_loss.item()
            avg_train_s_loss += s_loss.item()
            avg_train_loss += loss.item()

            # run forward transformation on querys
            transformed = transformer(querys, fast_weights)

            # compute loss
            features_transformed = vgg(utils.standardize_batch(transformed))
            loss, c_loss, s_loss = loss_fn(features_transformed, features_querys, gram_style, content_weight, style_weight)

            # gradient of the (averaged) query loss w.r.t. all parameters
            grads = torch.autograd.grad(loss / args.meta_batch_size, transformer.parameters())
            all_meta_grads.append({name: g for ((name, _), g) in zip(transformer.named_parameters(), grads)})

            avg_eval_c_loss += c_loss.item()
            avg_eval_s_loss += s_loss.item()
            avg_eval_loss += loss.item()

        writer.add_scalar("Avg_Train_C_Loss", avg_train_c_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Train_S_Loss", avg_train_s_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Train_Loss", avg_train_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Eval_C_Loss", avg_eval_c_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Eval_S_Loss", avg_eval_s_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Eval_Loss", avg_eval_loss / args.meta_batch_size, iteration + 1)

        # compute dummy loss to refresh buffer
        transformed = transformer(querys)
        features_transformed = vgg(utils.standardize_batch(transformed))
        dummy_loss, _, _ = loss_fn(features_transformed, features_querys, gram_style, content_weight, style_weight)

        meta_updates(transformer, dummy_loss, all_meta_grads)

        if args.checkpoint_model_dir is not None and (iteration + 1) % args.checkpoint_interval == 0:
            # save from CPU in eval mode, then restore the training state
            transformer.eval().cpu()
            ckpt_model_filename = "iter_" + str(iteration + 1) + ".pth"
            ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
            torch.save(transformer.state_dict(), ckpt_model_path)
            transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "Final_iter_" + str(args.max_iter) + "_" + \
                          str(args.content_weight) + "_" + \
                          str(args.style_weight) + "_" + \
                          str(args.lr) + "_" + \
                          str(args.meta_lr) + "_" + \
                          str(args.meta_batch_size) + "_" + \
                          str(args.meta_step) + "_" + \
                          time.ctime() + ".pth"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    # Parenthesised call form is valid on both Python 2 and Python 3.
    print("Done, trained model saved at {}".format(save_model_path))
Example #17
0
def train():
    """Train a TransformerNet with content, depth and style losses.

    All configuration is read from module-level names (``random_seed``,
    ``image_size``, ``dataset_path``, ``batch_size``, ``lr``, ``epochs``,
    ``content_weight``, ``depth_weight``, ``style_weight``, ``log_interval``,
    ``checkpoint_model_dir``, ``checkpoint_interval``, ``save_model_dir``,
    ``style_image_path``, ``FCRN_path``, ``TransformerNet_path`` and
    ``resume_TransformerNet_from_file``).

    The loss is the weighted sum of
      * a content loss on the VGG ``relu2_2`` features,
      * a depth loss between the ``FCRN_for_transfer`` outputs for the
        stylized image and for the input, and
      * a style loss between Gram matrices of the VGG features.

    Side effects: prints progress, writes periodic checkpoints and the final
    model with ``torch.save``. Always runs on CUDA.
    """
    device = torch.device("cuda")

    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ])

    train_dataset = datasets.ImageFolder(dataset_path, transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), lr)
    mse_loss = torch.nn.MSELoss()

    if resume_TransformerNet_from_file:
        if os.path.isfile(TransformerNet_path):
            print("=> loading checkpoint '{}'".format(TransformerNet_path))
            TransformerNet_par = torch.load(TransformerNet_path)
            # drop the running statistics of the ``inN`` layers before
            # loading; iterate over a copy because entries are deleted
            for k in list(TransformerNet_par.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del TransformerNet_par[k]
            transformer.load_state_dict(TransformerNet_par)
            print("=> loaded checkpoint '{}'".format(TransformerNet_path))
        else:
            print("=> no checkpoint found at '{}'".format(TransformerNet_path))

    vgg = Vgg16(requires_grad=False).to(device)
    style = Image.open(style_image_path)
    style = transform(style)
    # tile the single style image to a full batch
    style = style.repeat(batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    model_fcrn = FCRN_for_transfer(batch_size=batch_size,
                                   requires_grad=False).to(device)
    model_fcrn_par = torch.load(FCRN_path)
    #start_epoch = model_fcrn_par['epoch']
    model_fcrn.load_state_dict(model_fcrn_par['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})".format(
        FCRN_path, model_fcrn_par['epoch']))

    for e in range(epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_depth_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            depth_y = model_fcrn(y)
            depth_x = model_fcrn(x)

            content_loss = content_weight * mse_loss(features_y.relu2_2,
                                                     features_x.relu2_2)
            depth_loss = depth_weight * mse_loss(depth_y, depth_x)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                # slice so a smaller final batch still matches
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= style_weight

            total_loss = content_loss + depth_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_depth_loss += depth_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % log_interval == 0:
                # "total" must include the depth term, because the optimised
                # loss is content + depth + style
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tdepth: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_depth_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_depth_loss + agg_style_loss) /
                    (batch_id + 1))
                print(mesg)

            if checkpoint_model_dir is not None and (
                    batch_id + 1) % checkpoint_interval == 0:
                # save from CPU in eval mode, then restore the training state
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            content_weight) + "_" + str(style_weight) + ".model"
    save_model_path = os.path.join(save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #18
0
def train_model():
    """Train one level of the PyNET generator.

    Configuration comes from module-level names (``level``, ``dataset_dir``,
    ``dslr_scale``, ``batch_size``, ``learning_rate``, ``restore_epoch``,
    ``num_train_epochs``, ``TRAIN_SIZE``, ``TEST_SIZE``).

    The training loss depends on ``level``:
      * 4 or 5: MSE only;
      * 1, 2 or 3: ``10 * MSE + VGG content loss``;
      * 0: ``MSE + VGG content loss + 0.4 * (1 - MS-SSIM)``.

    After the first optimisation step of every epoch the generator is saved
    under ``models/``, sample outputs are written under ``results/`` and the
    test set is evaluated (MSE, PSNR and, depending on ``level``, VGG loss
    and MS-SSIM). Always runs on CUDA.
    """
    torch.backends.cudnn.deterministic = True
    device = torch.device("cuda")

    print("CUDA visible devices: " + str(torch.cuda.device_count()))
    print("CUDA Device Name: " + str(torch.cuda.get_device_name(device)))

    # ---- datasets and loaders ----
    train_dataset = LoadData(dataset_dir, TRAIN_SIZE, dslr_scale, test=False)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=1, pin_memory=True,
                              drop_last=True)

    test_dataset = LoadData(dataset_dir, TEST_SIZE, dslr_scale, test=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=1,
                             shuffle=False, num_workers=1, pin_memory=True,
                             drop_last=False)

    visual_dataset = LoadVisualData(dataset_dir, 10, dslr_scale, level)
    visual_loader = DataLoader(dataset=visual_dataset, batch_size=1,
                               shuffle=False, num_workers=0, pin_memory=True,
                               drop_last=False)

    # ---- network and optimiser ----
    generator = PyNET(level=level,
                      instance_norm=True,
                      instance_norm_level_1=True).to(device)
    generator = torch.nn.DataParallel(generator)

    optimizer = Adam(params=generator.parameters(), lr=learning_rate)

    # Warm-start from the checkpoint of the next-higher level; strict=False
    # tolerates keys that differ between the two levels.
    if level < 5:
        previous_ckpt = ("models/pynet_level_" + str(level + 1) + "_epoch_" +
                         str(restore_epoch) + ".pth")
        generator.load_state_dict(torch.load(previous_ckpt), strict=False)

    # ---- loss modules ----
    vgg_features = vgg_19(device)
    mse_criterion = torch.nn.MSELoss()
    ms_ssim = MSSSIM()

    for epoch in range(num_train_epochs):

        torch.cuda.empty_cache()

        for i, (x, y) in enumerate(train_loader):

            optimizer.zero_grad()

            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            enhanced = generator(x)

            # pixel loss
            loss_mse = mse_criterion(enhanced, y)

            # perceptual (VGG) loss, computed for every level below 5
            if level < 5:
                enhanced_vgg = vgg_features(normalize_batch(enhanced))
                target_vgg = vgg_features(normalize_batch(y))
                loss_content = mse_criterion(enhanced_vgg, target_vgg)

            # combine the terms according to the current level
            if level in (4, 5):
                total_loss = loss_mse
            elif level in (1, 2, 3):
                total_loss = loss_mse * 10 + loss_content
            elif level == 0:
                loss_ssim = ms_ssim(enhanced, y)
                total_loss = loss_mse + loss_content + (1 - loss_ssim) * 0.4

            total_loss.backward()
            optimizer.step()

            # everything below runs once per epoch, right after step one
            if i != 0:
                continue

            # -- checkpoint: save from CPU in eval mode, then restore --
            ckpt_path = ("models/pynet_level_" + str(level) + "_epoch_" +
                         str(epoch) + ".pth")
            generator.eval().cpu()
            torch.save(generator.state_dict(), ckpt_path)
            generator.to(device).train()

            # -- visual results for a few test images --
            generator.eval()
            with torch.no_grad():

                visual_iter = iter(visual_loader)
                for j in range(len(visual_loader)):

                    torch.cuda.empty_cache()

                    raw_image = next(visual_iter)
                    raw_image = raw_image.to(device, non_blocking=True)

                    enhanced = generator(raw_image.detach())
                    enhanced = np.asarray(
                        to_image(torch.squeeze(enhanced.detach().cpu())))

                    image_path = ("results/pynet_img_" + str(j) + "_level_" +
                                  str(level) + "_epoch_" + str(epoch) + ".jpg")
                    imageio.imwrite(image_path, enhanced)

            # -- quantitative evaluation on the test set --
            mse_total = 0
            psnr_total = 0
            vgg_total = 0
            ssim_total = 0

            generator.eval()
            with torch.no_grad():

                for x, y in test_loader:
                    x = x.to(device, non_blocking=True)
                    y = y.to(device, non_blocking=True)
                    enhanced = generator(x)

                    sample_mse = mse_criterion(enhanced, y).item()

                    mse_total += sample_mse
                    psnr_total += 20 * math.log10(
                        1.0 / math.sqrt(sample_mse))

                    if level < 2:
                        ssim_total += ms_ssim(y, enhanced)

                    if level < 5:
                        enhanced_vgg_eval = vgg_features(
                            normalize_batch(enhanced)).detach()
                        target_vgg_eval = vgg_features(
                            normalize_batch(y)).detach()

                        vgg_total += mse_criterion(enhanced_vgg_eval,
                                                   target_vgg_eval).item()

            # averages are taken over TEST_SIZE, as in the sums above
            mse_mean = mse_total / TEST_SIZE
            psnr_mean = psnr_total / TEST_SIZE
            vgg_mean = vgg_total / TEST_SIZE
            ssim_mean = ssim_total / TEST_SIZE

            if level < 2:
                print(
                    "Epoch %d, mse: %.4f, psnr: %.4f, vgg: %.4f, ms-ssim: %.4f"
                    % (epoch, mse_mean, psnr_mean, vgg_mean, ssim_mean))
            elif level < 5:
                print(
                    "Epoch %d, mse: %.4f, psnr: %.4f, vgg: %.4f" %
                    (epoch, mse_mean, psnr_mean, vgg_mean))
            else:
                print("Epoch %d, mse: %.4f, psnr: %.4f" %
                      (epoch, mse_mean, psnr_mean))

            generator.train()
Example #19
0
def train(args):
    """Train a multi-style TransformerNet.

    Every file found in the directory ``args.style_image`` is loaded as one
    style. During training, sample ``k`` (counted over the epoch) is paired
    with style ``k % style_num``, and the style loss compares each output's
    Gram matrices with those of its assigned style.

    Args:
        args: namespace providing ``cuda``, ``seed``, ``image_size``,
            ``dataset``, ``batch_size``, ``style_image`` (a directory),
            ``style_size``, ``lr``, ``epochs``, ``content_weight``,
            ``style_weight``, ``log_interval``, ``checkpoint_model_dir``,
            ``checkpoint_interval`` and ``save_model_dir``.

    Side effects: prints progress, writes periodic checkpoints and the final
    model with ``torch.save``.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(
            args.image_size),  # the shorter side is resize to match image_size
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),  # to tensor [0,1]
        transforms.Lambda(lambda x: x.mul(255))  # convert back to [0, 255]
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)  # to provide a batch loader

    # NOTE(review): os.listdir order is arbitrary, so the style index assigned
    # to each file may differ between machines -- confirm whether the
    # inference code relies on a particular order.
    style_image = os.listdir(args.style_image)
    style_num = len(style_image)
    print(style_num)

    transformer = TransformerNet(style_num=style_num).to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.Resize(args.style_size),
        transforms.CenterCrop(args.style_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    style_batch = []

    for style_file in style_image:
        # os.path.join makes this work whether or not args.style_image ends
        # with a path separator
        style = utils.load_image(os.path.join(args.style_image, style_file),
                                 size=args.style_size)
        style = style_transform(style)
        style_batch.append(style)

    style = torch.stack(style_batch).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)

            if n_batch < args.batch_size:
                break  # skip to next epoch when no enough images left in the last batch of current epoch

            count += n_batch
            optimizer.zero_grad()  # initialize with zero gradients

            # cycle through the styles by running sample index
            batch_style_id = [
                i % style_num for i in range(count - n_batch, count)
            ]
            y = transformer(x.to(device), style_id=batch_style_id)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y.to(device))
            features_x = vgg(x.to(device))
            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                # pick the Gram matrix of the style assigned to each sample
                style_loss += mse_loss(gm_y, gm_s[batch_style_id, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                # save from CPU in eval mode, then restore the training state
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(
        args.epochs) + "_" + str(time.ctime()).replace(' ', '_').replace(
            ':', '') + "_" + str(int(args.content_weight)) + "_" + str(
                int(args.style_weight)) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #20
0
def train(style_image, dataset_path):
    """Train a TransformerNet for a single style with fixed hyper-parameters.

    Args:
        style_image: passed to ``load_image`` to obtain the style image.
        dataset_path: root directory handed to ``datasets.ImageFolder``.

    The network is trained for two epochs with batch size 3 and the result
    is saved to ``models/outpost.pth``. Progress is printed every 2000
    batches.
    """
    print('Training function started...')
    torch.cuda.empty_cache()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # fixed hyper-parameters
    image_size = 256
    style_weight = 1e10
    content_weight = 1e5
    lr = 1e-3
    batch_size = 3
    epochs = 2
    log_interval = 2000

    # content images: resize, crop, then scale from [0, 1] to [0, 255]
    content_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    dataset = datasets.ImageFolder(dataset_path, content_transform)
    loader = DataLoader(dataset, batch_size=batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), lr=lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16().to(device)

    # the style image is used at its native size, tiled to a full batch
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = style_transform(load_image(style_image))
    style = style.repeat(batch_size, 1, 1, 1).to(device)

    gram_style = [gram_matrix(feat) for feat in vgg(normalize_batch(style))]

    print('Starting epochs...')
    for epoch_idx in range(epochs):
        transformer.train()
        content_running = 0.
        style_running = 0.
        seen = 0
        for batch_id, (content, _) in enumerate(loader):
            n_batch = len(content)
            seen += n_batch
            optimizer.zero_grad()

            content = content.to(device)
            stylized = transformer(content)

            stylized = normalize_batch(stylized)
            content = normalize_batch(content)

            feats_stylized = vgg(stylized)
            feats_content = vgg(content)

            content_loss = content_weight * mse_loss(feats_stylized.relu2_2,
                                                     feats_content.relu2_2)

            style_loss = 0.
            for feat, target_gram in zip(feats_stylized, gram_style):
                # slice so a smaller final batch still matches
                style_loss += mse_loss(gram_matrix(feat),
                                       target_gram[:n_batch, :, :])
            style_loss *= style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            content_running += content_loss.item()
            style_running += style_loss.item()

            batches_done = batch_id + 1
            if batches_done % log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), epoch_idx + 1, seen, len(dataset),
                    content_running / batches_done,
                    style_running / batches_done,
                    (content_running + style_running) / batches_done)
                print(mesg)

    # save the trained weights from CPU in eval mode
    transformer.eval().cpu()
    save_model_path = 'models/outpost.pth'
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #21
0
def train_model():
    """Train the MWRCAN generator and report test PSNR after every epoch.

    Configuration is read from the module-level ``opt`` (``dataroot``,
    ``batch_size``, ``lr``, ``epochs``, ``save_model_path``) and from
    ``TRAIN_SIZE`` / ``TEST_SIZE``.

    The loss is ``L1 + VGG-feature L1 + 0.15 * (1 - MS-SSIM)``. The learning
    rate is halved at epochs 50, 100, 150 and 200. After each epoch the
    generator is saved and the PSNR on the test set is printed, computed on
    outputs and targets rounded to 8-bit levels. Always runs on CUDA.
    """
    torch.backends.cudnn.deterministic = True
    device = torch.device("cuda")

    print("CUDA visible devices: " + str(torch.cuda.device_count()))
    print("CUDA Device Name: " + str(torch.cuda.get_device_name(device)))

    # Creating dataset loaders
    train_dataset = LoadTrainData(opt.dataroot, TRAIN_SIZE, test=False)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              num_workers=1,
                              pin_memory=True,
                              drop_last=True)
    test_dataset = LoadTrainData(opt.dataroot, TEST_SIZE, test=True)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=1,
                             pin_memory=True,
                             drop_last=False)

    # Creating image processing network and optimizer
    generator = MWRCAN().to(device)
    generator = torch.nn.DataParallel(generator)
    #generator.load_state_dict(torch.load('./ckpt/Track1/mwcnnvggssim4_epoch_60.pth'))

    optimizer = Adam(params=generator.parameters(), lr=opt.lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     [50, 100, 150, 200],
                                                     gamma=0.5)

    # Losses
    VGG_19 = vgg_19(device)
    MSE_loss = torch.nn.MSELoss()
    MS_SSIM = MSSSIM()
    L1_loss = torch.nn.L1Loss()

    # Train the network
    for epoch in range(opt.epochs):
        # Read the rate from the optimizer itself: scheduler.get_lr() is not
        # meant to be called outside scheduler.step() and can report an
        # extra decay at milestone epochs on recent PyTorch versions.
        print("lr =  %.8f" % (optimizer.param_groups[0]['lr']))
        torch.cuda.empty_cache()
        generator.to(device).train()
        for i, (x, y) in enumerate(train_loader):
            optimizer.zero_grad()
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            enhanced = generator(x)

            loss_l1 = L1_loss(enhanced, y)

            enhanced_vgg = VGG_19(normalize_batch(enhanced))
            target_vgg = VGG_19(normalize_batch(y))
            loss_content = L1_loss(enhanced_vgg, target_vgg)

            loss_ssim = MS_SSIM(enhanced, y)

            total_loss = loss_l1 + loss_content + (1 - loss_ssim) * 0.15
            if i % 100 == 0:
                print(
                    "Epoch %d_%d, L1: %.4f, vgg: %.4f, SSIM: %.4f, total: %.4f"
                    % (epoch, i, loss_l1, loss_content,
                       (1 - loss_ssim) * 0.15, total_loss))
            total_loss.backward()
            optimizer.step()
        scheduler.step()

        # Save the model that corresponds to the current epoch
        # (from CPU in eval mode; it is moved back to the device below)
        generator.eval().cpu()
        torch.save(
            generator.state_dict(),
            os.path.join(opt.save_model_path,
                         "mwrcan_epoch_" + str(epoch) + ".pth"))

        # Evaluate the model
        loss_psnr_eval = 0
        generator.to(device)
        generator.eval()
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                enhanced = generator(x)
                # quantise both images to 8-bit levels before measuring
                enhanced = torch.clamp(
                    torch.round(enhanced * 255), min=0, max=255) / 255
                y = torch.clamp(torch.round(y * 255), min=0, max=255) / 255
                loss_mse_temp = MSE_loss(enhanced, y).item()
                loss_psnr_eval += 20 * math.log10(
                    1.0 / math.sqrt(loss_mse_temp))
        loss_psnr_eval = loss_psnr_eval / TEST_SIZE
        print("Epoch %d, psnr: %.4f" % (epoch, loss_psnr_eval))
Example #22
0
def train(**kwargs):
    """Train a TransformerNet for one style, with visdom monitoring.

    Args:
        **kwargs: overrides applied as attributes on a fresh ``Config()``
            before training starts.

    The loss is a weighted sum of a content loss on the VGG ``relu2_2``
    features and a style loss on Gram matrices. Every ``opt.plot_every``
    batches the running losses and one input/output pair are plotted; if
    the file ``opt.debug_file`` exists at that point, an ``ipdb`` session is
    opened. After every epoch the visdom environment and the model weights
    are saved.
    """
    opt = Config()
    for key, value in kwargs.items():
        setattr(opt, key, value)

    device = t.device('cuda' if opt.use_gpu else 'cpu')
    vis = utils.Visualizer(opt.env)

    # Data loading
    preprocess = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x * 255)
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, preprocess)
    dataloader = data.DataLoader(dataset, opt.batch_size)

    # Image transformation network
    transformer = TransformerNet()
    if opt.model_path:
        # map_location keeps the loaded tensors on their storage as-is
        state = t.load(opt.model_path, map_location=lambda _s, _: _s)
        transformer.load_state_dict(state)
    transformer.to(device)

    # Loss network (Vgg16), frozen and in eval mode
    vgg = Vgg16().eval()
    vgg.to(device)
    for param in vgg.parameters():
        param.requires_grad = False

    # Optimizer
    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    # Load the style image data
    style = utils.get_style_data(opt.style_path)
    vis.img('style', (style.data[0] * 0.225 + 0.45).clamp(min=0, max=1))
    style = style.to(device)

    # Gram matrices of the style image
    with t.no_grad():
        features_style = vgg(style)
        gram_style = [utils.gram_matrix(feat) for feat in features_style]

    # Running loss statistics
    style_meter = tnt.meter.AverageValueMeter()
    content_meter = tnt.meter.AverageValueMeter()

    for epoch in range(opt.epoches):
        content_meter.reset()
        style_meter.reset()

        for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):

            # Training step
            optimizer.zero_grad()
            x = x.to(device)
            y = transformer(x)
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)
            features_y = vgg(y)
            features_x = vgg(x)

            # content loss
            content_loss = opt.content_weight * F.mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            # style loss
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gram_y = utils.gram_matrix(ft_y)
                # expand the style Gram matrix to the batch's shape
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            # Smooth the losses
            content_meter.add(content_loss.item())
            style_meter.add(style_loss.item())

            if (ii + 1) % opt.plot_every != 0:
                continue

            # drop into the debugger when the debug file is present
            if os.path.exists(opt.debug_file):
                ipdb.set_trace()

            # Visualization
            vis.plot('content_loss', content_meter.value()[0])
            vis.plot('style_loss', style_meter.value()[0])
            # x and y were standardized by utils.normalize_batch, so they
            # are mapped back before being displayed
            output_image = (y.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1)
            vis.img('output', output_image)
            input_image = (x.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1)
            vis.img('input', input_image)

        # Save the visdom environment and the model
        vis.save([opt.env])
        t.save(transformer.state_dict(), 'checkpoints/%s_style.pth' % epoch)
Example #23
0
def train(exp, args):
    """Train a ``TransformerNet`` for style transfer under an experiment harness.

    ``exp`` supplies the device, a timing helper (``exp.chrono()``) and the
    logging/reporting hooks; ``args`` supplies the dataset, style image and
    hyper-parameters. Each of the ``args.repeat`` passes stops once
    ``batch_id`` exceeds ``args.number``.
    """
    device = exp.get_device()
    chrono = exp.chrono()

    def _scale_to_255(tensor):
        # ToTensor yields [0, 1]; the pipeline below works on [0, 255].
        return tensor.mul(255)

    content_pipeline = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(_scale_to_255),
    ])
    train_loader = DataLoader(
        datasets.ImageFolder(args.dataset, content_pipeline),
        batch_size=args.batch_size,
        num_workers=args.workers,
    )

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    vgg_input_shape = (3, args.image_size, args.image_size)
    print(memory_size(vgg, batch_size=args.batch_size, input_size=vgg_input_shape) * 4)

    style_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(_scale_to_255),
    ])
    style = style_pipeline(utils.load_image(args.style_image, size=args.style_size))
    # One copy of the style image per batch slot, so the Gram targets can be
    # sliced to the size of a (possibly shorter) batch later on.
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    gram_style = []
    for style_feature in vgg(utils.normalize_batch(style)):
        gram_style.append(utils.gram_matrix(style_feature))

    for e in range(args.repeat):
        transformer.train()

        with chrono.time('train') as t:
            content_total = 0.
            style_total = 0.
            for batch_id, (x, _) in enumerate(train_loader):
                if batch_id > args.number:
                    break

                n_batch = len(x)

                x = x.to(device)
                y = transformer(x)

                y = utils.normalize_batch(y)
                x = utils.normalize_batch(x)

                optimizer.zero_grad()

                features_y = vgg(y)
                features_x = vgg(x)

                content_loss = args.content_weight * mse_loss(
                    features_y.relu2_2, features_x.relu2_2)

                # Sum of per-layer Gram-matrix MSEs against the style targets.
                style_loss = sum(
                    (mse_loss(utils.gram_matrix(ft_y), gm_s[:n_batch, :, :])
                     for ft_y, gm_s in zip(features_y, gram_style)),
                    0.,
                )
                style_loss = style_loss * args.style_weight

                total_loss = content_loss + style_loss
                total_loss.backward()

                exp.log_batch_loss(total_loss.item())
                optimizer.step()

                content_total += content_loss.item()
                style_total += style_loss.item()

            exp.log_epoch_loss(content_total + style_total)

        exp.show_eta(e, t)

    exp.report()
Example #24
0
    def train(self):
        """Train ``self.tfNet`` against a single style image.

        For every batch the content loss is the MSE between the ``relu2_2``
        features of the stylized output and of the content batch, and the
        style loss is the summed MSE between the Gram matrices of the output
        features and those of the style image. Losses are printed and plotted
        every ``self.log_interval`` steps; a sample image and a checkpoint
        are written every ``self.sample_interval`` /
        ``self.checkpoint_interval`` epochs, and a final state dict is saved
        at the end.
        """
        vis = Visualizer()

        optimizer = optim.Adam(self.tfNet.parameters(), self.lr, betas=[0.5, 0.999])
        criterion = nn.MSELoss()

        # Style targets: repeat the style image across the batch dimension so
        # the Gram matrices can be sliced to a shorter final batch.
        style = utils.load_image(self.style_path)
        style = self.style_transform(style)
        style = style.repeat(self.batch_size, 1, 1, 1).to(device)
        style = utils.normalize_batch(style)

        features_style = self.vggNet(style)
        gram_style = [utils.gram_matrix(f) for f in features_style]

        # Last network output, kept only for the per-epoch sample image.
        # Stays None if the dataloader never yields a batch.
        output_img = None

        start_time = time.time()
        print("Learning started!!!")
        for epoch in range(self.nepochs):
            for step, (content, _) in enumerate(self.dataloader):
                # Re-enter train mode every step: checkpointing below
                # switches the network to eval mode.
                self.tfNet.train()
                step_batch = content.size(0)

                optimizer.zero_grad()

                content = content.to(device)
                output = self.tfNet(content)

                content = utils.normalize_batch(content)
                output = utils.normalize_batch(output)

                # Detach so the sample does not drag the autograd graph along
                # and can be converted to an image without requiring grad.
                # NOTE(review): this is the *normalized* output, as in the
                # original code — confirm utils.save_image expects that.
                output_img = output.detach()

                features_content = self.vggNet(content)
                features_output = self.vggNet(output)

                content_loss = self.content_weight * criterion(features_output.relu2_2,
                                                               features_content.relu2_2)

                style_loss = 0.
                for ft_output, gm_style in zip(features_output, gram_style):
                    gm_output = utils.gram_matrix(ft_output)
                    # Slice the style targets to the actual batch size.
                    style_loss += criterion(gm_output,
                                            gm_style[:step_batch, :, :])
                style_loss *= self.style_weight

                total_loss = content_loss + style_loss
                total_loss.backward()
                optimizer.step()

                if (step+1) % self.log_interval == 0:
                    end_time = time.time()
                    print("[%d/%d] [%d/%d] time: %f content loss:%.4f style loss:%.4f total loss: %.4f"
                          % (epoch+1, self.nepochs, step+1, len(self.dataloader), end_time - start_time,
                             content_loss.item(), style_loss.item(), total_loss.item()))
                    vis.plot("Content loss per %d steps" % self.log_interval, content_loss.item())
                    vis.plot("Style loss per %d steps" % self.log_interval, style_loss.item())
                    vis.plot("Total loss per %d steps" % self.log_interval, total_loss.item())

            # do save sample images (skipped when no batch was processed)
            if (epoch+1) % self.sample_interval == 0 and output_img is not None:
                img = output_img.cpu()
                img = img[0]
                utils.save_image("%s/output_epoch%d.png" % (self.sample_folder, epoch + 1), img)

            # do checkpointing
            if (epoch+1) % self.checkpoint_interval == 0:
                self.tfNet.eval()
                torch.save(self.tfNet.state_dict(), "%s/model_epoch%d.pth" % (self.sample_folder, epoch + 1))

        print("Learning finished!!!")
        self.tfNet.eval().cpu()
        torch.save(self.tfNet.state_dict(), "%s/epoch%d_final.pth" % (self.sample_folder, self.nepochs))
        print("Save model complete!!!")
def vgg_style(data, vgg):
    """Return one Gram matrix per feature map that ``vgg`` yields for ``data``.

    ``data`` is first passed through ``utils.normalize_batch``.
    """
    normalized = utils.normalize_batch(data)
    gram_style = []
    for feature_map in vgg(normalized):
        gram_style.append(utils.gram_matrix(feature_map))
    return gram_style
Example #26
0
def train(args):
    """Train a ``TransformerNet`` to render images in the style of one image.

    Reads the dataset, style image and hyper-parameters from ``args``.
    The content loss compares ``relu2_2`` VGG features of the stylized and
    original batch; the style loss compares Gram matrices of every VGG
    feature map against those of the style image. When
    ``args.checkpoint_model_dir`` is set a checkpoint is written after every
    batch, and the final weights are saved into ``args.save_model_dir``.
    """
    # Device that all tensors and models are placed on.
    device = torch.device("cuda" if args.cuda else "cpu")
    # Seed NumPy and PyTorch random number generators.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Compose the preprocessing steps applied to every content image.
    transform = transforms.Compose([
        # Resize to the training resolution.
        transforms.Resize(args.image_size),
        # Crop the centre of the resized image.
        transforms.CenterCrop(args.image_size),
        # Convert the image to a tensor with values in [0, 1].
        transforms.ToTensor(),
        # Rescale to [0, 255].
        transforms.Lambda(lambda x: x.mul(255))
    ])
    # ImageFolder dataset rooted at args.dataset; `transform` is applied to
    # each loaded image.
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    # Iterate the dataset in batches of args.batch_size samples.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)
    # Image transformation network, placed on the device.
    transformer = TransformerNet().to(device)
    # Adam optimizer over the transformer's parameters.
    optimizer = Adam(transformer.parameters(), args.lr)
    # Mean squared error, used for both content and style terms.
    mse_loss = torch.nn.MSELoss()
    # Loss network, placed on the device.
    vgg = Vgg16(requires_grad=False).to(device)
    # Preprocessing for the style image.
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    # Load and preprocess the style image.
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    # Repeat the style image along the batch dimension so its Gram matrices
    # can be sliced to the size of any batch.
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)
    # VGG features of the normalized style image.
    features_style = vgg(utils.normalize_batch(style))
    # Gram matrix of every style feature map.
    gram_style = [utils.gram_matrix(y) for y in features_style]
    # Training loop.
    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            # Clear gradients accumulated by the previous step.
            optimizer.zero_grad()

            # Move the batch to the configured device once, instead of
            # normalizing on the CPU and calling .cuda() afterwards (which
            # ignored `device` and failed when args.cuda was false).
            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)
            # Content loss.
            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)
            # Style loss.
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                # Slice the style targets to the actual batch size.
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight
            # Total loss.
            total_loss = content_loss + style_loss
            # Back-propagate.
            total_loss.backward()
            # Update the transformer's parameters.
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            # Report running averages every args.log_interval batches.
            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)
            # Save a checkpoint of the transformer.
            # NOTE(review): the `(batch_id + 1) % args.checkpoint_interval == 0`
            # guard was removed from this condition, so a checkpoint is
            # written after every batch — confirm that is intended.
            if args.checkpoint_model_dir is not None:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                # Restore device placement and train mode after saving.
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #27
0
def train(args):
    """Train (or continue training) a ``TransformerNet`` for one style image.

    Uses the GPU when ``torch.cuda.is_available()``, otherwise the CPU.
    If ``args.model`` is set, its state dict is loaded into the transformer
    before training. The content loss compares ``relu3_3`` VGG features;
    the style loss compares Gram matrices of every VGG feature map against
    those of the style image. Running averages are printed after every
    batch, checkpoints are written every ``args.checkpoint_interval``
    batches when ``args.checkpoint_model_dir`` is set, and the final weights
    are saved into ``args.save_model_dir``.
    """
    if torch.cuda.is_available():
        print('CUDA available, using GPU.')
        device = torch.device('cuda')
    else:
        print('GPU training unavailable... using CPU.')
        device = torch.device('cpu')

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    # Image transformation network.
    transformer = TransformerNet()

    if args.model:
        # Load onto the CPU first: the device is auto-detected above, so a
        # checkpoint holding CUDA tensors must still load on a CPU-only host.
        # The network is moved to `device` right after.
        state_dict = torch.load(args.model, map_location='cpu')
        transformer.load_state_dict(state_dict)

    transformer.to(device)

    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    # Loss Network: VGG16
    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    # Repeat across the batch dimension so the Gram targets can be sliced
    # to the size of a shorter final batch.
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            # CUDA if available
            x = x.to(device)

            # Transform image
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            # Feature map of original image
            features_x = vgg(x)
            # Feature Map of transformed image
            features_y = vgg(y)

            # Difference between transformed image, original image.
            # Changed to pull from features_.relu3_3 vs .relu2_2
            content_loss = args.content_weight * mse_loss(features_y.relu3_3, features_x.relu3_3)

            # Compare the Gram matrix of every feature map of the output with
            # the matching style Gram matrix.
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            # Logged after every batch; the args.log_interval throttle was
            # deliberately disabled in the original code.
            mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                time.ctime(), e + 1, count, len(train_dataset),
                agg_content_loss / (batch_id + 1),
                agg_style_loss / (batch_id + 1),
                (agg_content_loss + agg_style_loss) / (batch_id + 1)
            )
            print(mesg)

            if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                # Restore device placement and train mode after saving.
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #28
0
def train(args):
    """Train a ``TransformerNet`` for style transfer using an event-driven engine.

    The per-batch update lives in the nested ``step`` function, which is
    wrapped in an ``Engine``; logging, progress display, test-image
    stylization and checkpointing are attached as event handlers.
    ``Engine``, ``Events`` and ``ModelCheckpoint`` are presumably from
    pytorch-ignite — confirm against the file's imports.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    # Repeat the style image along the batch dimension so the Gram targets
    # can be sliced to the size of any batch in `step`.
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    # NOTE(review): `test_image` is not referenced again in this function;
    # `stylize_image` below reloads the file itself.
    test_image = utils.load_image(args.test_image)
    test_image = style_transform(test_image)
    test_image = test_image.unsqueeze(0).to(device)

    # Smoothed loss values keyed by the names returned from `step`.
    running_avgs = OrderedDict()
    output_stream = sys.stdout
    # Weight given to the previous running average in `update_logs`.
    alpha = 0.98

    def step(engine, batch):
        """Run one optimization step on ``batch`` and return the loss values."""

        x, _ = batch
        x = x.to(device)

        n_batch = len(x)

        transformer.zero_grad()

        y = transformer(x)

        x = utils.normalize_batch(x)
        y = utils.normalize_batch(y)

        features_x = vgg(x)
        features_y = vgg(y)

        content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

        style_loss = 0.
        for ft_y, gm_s in zip(features_y, gram_style):
            gm_y = utils.gram_matrix(ft_y)
            # Slice the style targets to the actual batch size.
            style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
        style_loss *= args.style_weight

        total_loss = content_loss + style_loss
        total_loss.backward()
        optimizer.step()

        return {
            'content_loss': content_loss.item(),
            'style_loss': style_loss.item(),
            'total_loss': total_loss.item()
        }

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(args.checkpoint_model_dir, 'ckpt_epoch_',
                                         save_interval=args.checkpoint_interval,
                                         n_saved=10, require_empty=False, create_dir=True)

    # Registered before `print_logs` on the same event; presumably handlers
    # run in registration order, so the averages are fresh when printed —
    # confirm against the engine's documentation.
    @trainer.on(Events.ITERATION_COMPLETED)
    def update_logs(engine):
        """Fold the latest step output into the exponential running averages."""
        for k, v in engine.state.output.items():
            # The first value seen for a key seeds its own average.
            old_v = running_avgs.get(k, v)
            new_v = alpha * old_v + (1 - alpha) * v
            running_avgs[k] = new_v

    @trainer.on(Events.ITERATION_COMPLETED)
    def print_logs(engine):
        """Write a one-line progress bar plus the running averages."""

        # Iterations completed within the current epoch (assumes
        # engine.state.iteration counts across epochs and epoch is 1-based).
        num_seen = engine.state.iteration - len(train_loader) * (engine.state.epoch - 1)

        percent_seen = 100 * (float(num_seen / len(train_loader)))
        percentages = list(range(0, 110, 10))

        if int(percent_seen) == 100:
            # Full bar: ten '=' and no '>' head.
            progress = 0
            equal_to = 10
            sub = 0
        else:
            sub = 1
            progress = 1
            # Index of the largest 10% mark strictly below percent_seen,
            # i.e. the number of '=' segments to draw.
            equal_to = np.max(np.where([percent < percent_seen for percent in percentages])[0])

        bar = '[' + '=' * equal_to + '>' * progress + ' ' * (10 - equal_to - sub) + ']'

        message = 'Epoch {epoch} | {percent_seen:.2f}% | {bar}'.format(epoch=engine.state.epoch,
                                                                       percent_seen=percent_seen,
                                                                       bar=bar)
        for key, value in running_avgs.items():
            message += ' | {name}: {value:.2e}'.format(name=key, value=value)

        # Carriage return without a newline; the newline is written once per
        # epoch by `complete_progress`.
        message += '\r'

        output_stream.write(message)
        output_stream.flush()

    @trainer.on(Events.EPOCH_COMPLETED)
    def complete_progress(engine):
        """Terminate the progress line at the end of each epoch."""
        output_stream.write('\n')

    @trainer.on(Events.EPOCH_COMPLETED)
    def stylize_image(engine):
        """Stylize ``args.test_image`` with the current weights and save it."""
        path = os.path.join(args.stylized_test_dir, STYLIZED_IMG_FNAME.format(engine.state.epoch))
        content_image = utils.load_image(args.test_image, scale=None)
        content_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))
        ])
        content_image = content_transform(content_image)
        content_image = content_image.unsqueeze(0).to(device)

        # Inference only: no gradients are recorded for the sample.
        with torch.no_grad():
            output = transformer(content_image).cpu()

        utils.save_image(path, output[0])

    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler,
                              to_save={'net': transformer})
    trainer.run(train_loader, max_epochs=args.epochs)
Example #29
0
def train(start_epoch=0):
    """Train a multi-style ``TransformerNet`` configured through ``enums``.

    Gram-matrix targets are pre-computed for every image found in
    ``enums.style_image_dir``; each batch picks one style index at random
    and passes it to the transformer together with the content batch.
    When ``enums.subcommand == 'resume'`` the model, optimizer and starting
    epoch are restored from ``enums.checkpoint_model``.

    Args:
        start_epoch: Epoch index to start from; overridden by the checkpoint
            when resuming.
    """
    np.random.seed(enums.seed)
    torch.manual_seed(enums.seed)

    if enums.cuda:
        torch.cuda.manual_seed(enums.seed)

    transform = transforms.Compose([
        transforms.Resize(enums.image_size),
        transforms.CenterCrop(enums.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(enums.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=enums.batch_size)

    transformer = TransformerNet()
    # Move the model to the GPU *before* building the optimizer and loading
    # its state, so optimizer state and parameters end up on the same device
    # when resuming.
    if enums.cuda:
        transformer.cuda()
    optimizer = Adam(transformer.parameters(), enums.lr)
    if enums.subcommand == 'resume':
        ckpt_state = torch.load(enums.checkpoint_model)
        transformer.load_state_dict(ckpt_state['state_dict'])
        start_epoch = ckpt_state['epoch']
        optimizer.load_state_dict(ckpt_state['optimizer'])

    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])

    if enums.cuda:
        vgg.cuda()

    all_style_img_paths = [
        os.path.join(enums.style_image_dir, f)
        for f in os.listdir(enums.style_image_dir)
    ]
    # Style index -> list of Gram matrices (one per VGG feature map).
    all_style_grams = {}
    for i, style_img in enumerate(all_style_img_paths):
        style = utils.load_image(style_img, size=enums.style_size)
        style = style_transform(style)
        style = style.repeat(
            enums.batch_size, 1, 1,
            1)  # can try with expand but unsure of backprop effects
        if enums.cuda:
            style = style.cuda()
        style_v = Variable(style)
        style_v = utils.normalize_batch(style_v)
        features_style = vgg(style_v)
        gram_style = [utils.gram_matrix(y) for y in features_style]
        all_style_grams[i] = gram_style

    # Number of epochs completed so far. Stays at start_epoch when the loop
    # below does not run (e.g. resuming a finished run), so the final save
    # never depends on an unbound loop variable.
    completed_epochs = start_epoch

    for e in range(start_epoch, enums.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            # NOTE(review): assumes enums.num_styles does not exceed the
            # number of images in enums.style_image_dir — confirm.
            idx = random.randint(0, enums.num_styles - 1)  # 0 to num_styles-1

            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(x)
            if enums.cuda:
                x = x.cuda()

            y = transformer(x, idx)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)
            gram_style = all_style_grams[idx]

            content_loss = enums.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                # Slice the style targets to the actual batch size.
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= enums.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            # NOTE(review): `.data[0]` is the legacy (Variable-era) way of
            # reading a scalar loss — confirm against the targeted PyTorch
            # version before changing it.
            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % enums.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

        completed_epochs = e + 1

        if enums.checkpoint_model_dir is not None and (
                e + 1) % enums.checkpoint_interval == 0:
            # Save from the CPU, then move back to the GPU to keep training.
            if enums.cuda:
                transformer.cpu()
            ckpt_model_filename = "ckpt_epoch_" + str(e + 1) + ".pth"
            ckpt_model_path = os.path.join(enums.checkpoint_model_dir,
                                           ckpt_model_filename)
            save_checkpoint(
                {
                    'epoch': e + 1,
                    'state_dict': transformer.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, ckpt_model_path)
            if enums.cuda:
                transformer.cuda()

    # save model
    if enums.cuda:
        transformer.cpu()
    save_model_filename = "epoch_" + str(enums.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            enums.content_weight) + "_" + str(enums.style_weight) + ".model"
    save_model_path = os.path.join(enums.save_model_dir, save_model_filename)
    save_checkpoint(
        {
            'epoch': completed_epochs,
            'state_dict': transformer.state_dict(),
            'optimizer': optimizer.state_dict()
        }, save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #30
0
def train(**kwargs):
    """Train a ``TransformerNet`` for style transfer with visdom monitoring.

    Keyword arguments override attributes of a fresh ``Config``. After every
    epoch the visdom environment and the transformer weights are saved.
    """
    opt = Config()
    for name, value in kwargs.items():
        setattr(opt, name, value)

    vis = utils.Visualizer(opt.env)

    def _displayable(image):
        # Images were standardized (utils.normalize_batch), so map them back
        # and clip to [0, 1] before plotting.
        return (image * 0.225 + 0.45).clamp(min=0, max=1)

    # Data loading
    preprocess = tv.transforms.Compose([
        tv.transforms.Scale(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x * 255)
    ])
    dataloader = data.DataLoader(
        tv.datasets.ImageFolder(opt.data_root, preprocess), opt.batch_size)

    # Transformation network
    transformer = TransformerNet()
    if opt.model_path:
        weights = t.load(opt.model_path, map_location=lambda _s, _: _s)
        transformer.load_state_dict(weights)

    # Loss network: Vgg16
    vgg = Vgg16().eval()

    # Optimizer
    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    # Style image data
    style = utils.get_style_data(opt.style_path)
    vis.img('style', _displayable(style[0]))

    if opt.use_gpu:
        transformer.cuda()
        style = style.cuda()
        vgg.cuda()

    # Gram matrices of the style image
    features_style = vgg(Variable(style, volatile=True))
    gram_style = []
    for style_feature in features_style:
        gram_style.append(Variable(utils.gram_matrix(style_feature.data)))

    # Loss statistics
    style_meter = tnt.meter.AverageValueMeter()
    content_meter = tnt.meter.AverageValueMeter()

    for epoch in range(opt.epoches):
        content_meter.reset()
        style_meter.reset()

        for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):

            # Training step
            optimizer.zero_grad()
            x = Variable(x.cuda() if opt.use_gpu else x)
            y = transformer(x)
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)
            features_y = vgg(y)
            features_x = vgg(x)

            # content loss
            content_loss = F.mse_loss(features_y.relu2_2, features_x.relu2_2)
            content_loss = opt.content_weight * content_loss

            # style loss: per-layer Gram-matrix MSE, summed over layers
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gram_y = utils.gram_matrix(ft_y)
                style_loss = style_loss + F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_loss = style_loss * opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            # Loss smoothing
            content_meter.add(content_loss.data[0])
            style_meter.add(style_loss.data[0])

            if (ii + 1) % opt.plot_every != 0:
                continue

            # Drop into the debugger when the debug file exists.
            if os.path.exists(opt.debug_file):
                ipdb.set_trace()

            # Visualization
            vis.plot('content_loss', content_meter.value()[0])
            vis.plot('style_loss', style_meter.value()[0])
            vis.img('output', _displayable(y.data.cpu()[0]))
            vis.img('input', _displayable(x.data.cpu()[0]))

        # Save the visdom environment and the model
        vis.save([opt.env])
        t.save(transformer.state_dict(), 'checkpoints/%s_style.pth' % epoch)
Example #31
0
def train(args):
    """Train a ``TransformerNet`` to stylize images after ``args.style_image``.

    Running loss averages are printed every ``args.log_interval`` batches,
    checkpoints are written every ``args.checkpoint_interval`` batches when
    ``args.checkpoint_model_dir`` is set, and the final weights are saved
    into ``args.save_model_dir``.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    content_pipeline = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, content_pipeline)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = style_pipeline(
        utils.load_image(args.style_image, size=args.style_size))
    # One copy of the style image per batch slot; the Gram targets are
    # sliced to the real batch size inside the loop.
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    gram_style = []
    for style_feature in vgg(utils.normalize_batch(style)):
        gram_style.append(utils.gram_matrix(style_feature))

    for e in range(args.epochs):
        transformer.train()
        content_running = 0.
        style_running = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            batches_done = batch_id + 1
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            # Sum of per-layer Gram-matrix MSEs against the style targets.
            style_loss = sum(
                (mse_loss(utils.gram_matrix(ft_y), gm_s[:n_batch, :, :])
                 for ft_y, gm_s in zip(features_y, gram_style)),
                0.,
            )
            style_loss = style_loss * args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            content_running += content_loss.item()
            style_running += style_loss.item()

            if batches_done % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    content_running / batches_done,
                    style_running / batches_done,
                    (content_running + style_running) / batches_done)
                print(mesg)

            checkpoint_due = (args.checkpoint_model_dir is not None
                              and batches_done % args.checkpoint_interval == 0)
            if checkpoint_due:
                # Save from the CPU in eval mode, then resume training on
                # the original device.
                transformer.eval().cpu()
                ckpt_model_filename = "".join([
                    "ckpt_epoch_", str(e), "_batch_id_", str(batches_done),
                    ".pth"
                ])
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    timestamp = str(time.ctime()).replace(' ', '_')
    save_model_filename = "".join([
        "epoch_", str(args.epochs), "_", timestamp, "_",
        str(args.content_weight), "_", str(args.style_weight), ".model"
    ])
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #32
0
def train(**kwargs):
    """Train ``TransformerNet`` for style transfer using a frozen VGG16 loss network.

    Keyword arguments are forwarded unchanged to ``opt.parse`` to override the
    module-level configuration object ``opt``. The fields read here are:
    ``env``, ``use_gpu``, ``image_size``, ``data_root``, ``batch_size``,
    ``num_workers``, ``style_path``, ``model_path``, ``lr``, ``epoches``,
    ``content_weight``, ``style_weight``, ``print_freq``, ``vis``,
    ``debug_file`` and ``save_every``.

    Side effects: plots losses and images through ``Visualizer``, and every
    ``opt.save_every`` epochs writes the transformer weights to
    ``checkpoints/<epoch>_style.pth`` and saves the visualizer environment.
    """
    # step 1: configuration
    opt.parse(**kwargs)
    vis = Visualizer(opt.env)
    device = t.device('cuda') if opt.use_gpu else t.device('cpu')

    # step 2: data
    # Content images are only scaled to the 0-255 range here (ToTensor * 255);
    # the batch is normalized later by normalize_batch right before VGG.
    # Many content images are used for training; a single style image is used
    # only inside the loss.
    transforms = T.Compose([
        T.Resize(opt.image_size),
        T.CenterCrop(opt.image_size),
        T.ToTensor(),
        T.Lambda(lambda x: x * 255),
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, transform=transforms)
    dataloader = DataLoader(dataset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            drop_last=True)

    # NOTE(review): the original author's note says get_style_data returns a
    # 1*c*H*W tensor that is already normalized -- confirm against the helper.
    style_img = get_style_data(opt.style_path)
    style_img = style_img.to(device)
    # 0.225 / 0.45 roughly undo the normalization so the image can be displayed
    # in [0, 1] (presumably approximations of the ImageNet std / mean).
    vis.img('style_image',
            (style_img.data[0] * 0.225 + 0.45).clamp(min=0, max=1))

    # step 3: models
    # Only TransformerNet is trained and saved; VGG16 is used purely to
    # evaluate the loss.
    transformer_net = TransformerNet()
    if opt.model_path:
        # map_location keeps the loaded storages where they are (CPU); the
        # model is moved to the target device right after.
        transformer_net.load_state_dict(
            t.load(opt.model_path, map_location=lambda _s, _: _s))
    transformer_net.to(device)

    # step 4: optimizer and criterion
    optimizer = t.optim.Adam(transformer_net.parameters(), opt.lr)
    vgg16 = Vgg16().eval()
    vgg16.to(device)
    # VGG parameters need no gradients themselves, but gradients still flow
    # through the network back to the transformer output.
    for param in vgg16.parameters():
        param.requires_grad = False
    # Mean-reduced MSE. The default constructor is used instead of the
    # deprecated ``reduce=True, size_average=True`` arguments; the reduction
    # is the same.
    criterion = t.nn.MSELoss()

    # step 5: running loss statistics
    style_meter = meter.AverageValueMeter()
    content_meter = meter.AverageValueMeter()
    total_meter = meter.AverageValueMeter()

    # Gram matrices of the style image are constant, so compute them once.
    # NOTE(review): per the original notes the features are
    # [relu1_2, relu2_2, relu3_3, relu4_3] and each gram matrix is b*c*c --
    # confirm against Vgg16 / gram_matrix.
    with t.no_grad():
        features = vgg16(style_img)
        gram_style = [gram_matrix(feature) for feature in features]

    # step 6: training loop
    for epoch in range(opt.epoches):
        style_meter.reset()
        content_meter.reset()
        # Reset together with the other meters so all three plotted curves
        # are averages over the same (current) epoch.
        total_meter.reset()

        # Wrapping the dataloader (not the enumerate object) lets tqdm see the
        # number of batches and show real progress.
        for ii, (data, _) in enumerate(tqdm(dataloader)):
            optimizer.zero_grad()

            data = data.to(device)
            y = transformer_net(data)
            # VGG expects normalized inputs.
            data = normalize_batch(data)
            y = normalize_batch(y)

            feature_data = vgg16(data)
            feature_y = vgg16(y)

            # Content loss is taken on relu2_2.
            content_loss = opt.content_weight * criterion(
                feature_y.relu2_2, feature_data.relu2_2)

            # Style loss: compare the gram matrix of every feature map of the
            # generated batch with the style gram matrix, broadcast over the
            # batch via expand_as.
            style_loss = 0
            for ft_y, gm_s in zip(feature_y, gram_style):
                gram_y = gram_matrix(ft_y)
                style_loss += criterion(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            # Gradients were already cleared at the top of the iteration and
            # no backward pass has run since, so no second zero_grad is needed.
            total_loss.backward()
            optimizer.step()

            content_meter.add(content_loss.item())
            style_meter.add(style_loss.item())
            total_meter.add(total_loss.item())

            # Periodic visualization.
            if (ii + 1) % opt.print_freq == 0 and opt.vis:
                # Creating the debug file drops the run into the debugger.
                if os.path.exists(opt.debug_file):
                    import ipdb
                    ipdb.set_trace()
                vis.plot('content_loss', content_meter.value()[0])
                vis.plot('style_loss', style_meter.value()[0])
                vis.plot('total_loss', total_meter.value()[0])
                # data and y are normalized at this point, so map them back
                # to [0, 1] for display.
                vis.img('input',
                        (data.data * 0.225 + 0.45)[0].clamp(min=0, max=1))
                vis.img('output',
                        (y.data * 0.225 + 0.45)[0].clamp(min=0, max=1))

        # Periodic checkpoint.
        if (epoch + 1) % opt.save_every == 0:
            # Make sure the target directory exists so the save cannot fail
            # after a full epoch of training.
            os.makedirs('checkpoints', exist_ok=True)
            t.save(transformer_net.state_dict(),
                   'checkpoints/%s_style.pth' % epoch)
            vis.save([opt.env])