Example #1
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    tensor_normalizer()
])
img_tensor = transform(img).unsqueeze(0)
if torch.cuda.is_available():
    img_tensor = img_tensor.cuda()

img_output = transformer(Variable(img_tensor, volatile=True))
plt.imshow(recover_image(img_tensor.cpu().numpy())[0])

Image.fromarray(recover_image(img_output.data.cpu().numpy())[0])

save_model_path = "model_udnie_imagenet_resnet2.pth"
torch.save(transformer.state_dict(), save_model_path)

transformer.load_state_dict(torch.load(save_model_path))

img = Image.open("content_images/amber.jpg").convert('RGB')
transform = transforms.Compose([transforms.ToTensor(), tensor_normalizer()])
img_tensor = transform(img).unsqueeze(0)
print(img_tensor.size())
if torch.cuda.is_available():
    img_tensor = img_tensor.cuda()

img_output = transformer(Variable(img_tensor, volatile=True))
plt.imshow(recover_image(img_tensor.cpu().numpy())[0])

plt.imshow(recover_image(img_output.data.cpu().numpy())[0])
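
Example #1 relies on two helpers that are not shown, tensor_normalizer() and recover_image(). A minimal sketch of what they might look like, assuming standard ImageNet normalization (the actual repository may define them differently):

import numpy as np
from torchvision import transforms

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def tensor_normalizer():
    # Normalize a [0, 1] tensor with ImageNet statistics (assumed helper).
    return transforms.Normalize(mean=IMAGENET_MEAN.tolist(), std=IMAGENET_STD.tolist())

def recover_image(img):
    # Undo the normalization on an NCHW numpy batch and return uint8 HWC images (assumed helper).
    img = img.transpose(0, 2, 3, 1) * IMAGENET_STD + IMAGENET_MEAN
    return (img.clip(0, 1) * 255).astype(np.uint8)
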
Example #2
def train(**kwargs):
    opt = Config()
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    # Visualization helper
    vis = utils.Visualizer(opt.env)

    # Data loading
    transforms = tv.transforms.Compose([
        # Resize the input `PIL.Image` so that its shorter side equals `size`
        tv.transforms.Scale(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        # Convert to a tensor with values in [0, 1]
        tv.transforms.ToTensor(),
        # Scale back to [0, 255]
        tv.transforms.Lambda(lambda x: x * 255)
    ])
    # Wrap the dataset and apply the transform
    dataset = tv.datasets.ImageFolder(opt.data_root, transforms)
    # Data loader
    dataloader = data.DataLoader(dataset, opt.batch_size)

    # Image transformation network
    transformer = TransformerNet()
    if opt.model_path:
        transformer.load_state_dict(
            t.load(opt.model_path, map_location=lambda _s, _: _s))

    # Loss network: Vgg16, set to inference mode
    vgg = Vgg16().eval()

    # Optimizer (only the style transformation network's parameters are trained)
    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    # Load the style image: shape 1*c*h*w, values roughly in -2~2 (preset normalization)
    style = utils.get_style_data(opt.style_path)
    # Visualize the style image: map -2~2 back to 0~1
    vis.img('style', (style[0] * 0.225 + 0.45).clamp(min=0, max=1))

    if opt.use_gpu:
        transformer.cuda()
        style = style.cuda()
        vgg.cuda()

    # Gram matrices of the style image
    style_v = Variable(style, volatile=True)
    # Outputs of four intermediate VGG layers (compared later against the input
    # image's layer outputs to compute the loss)
    features_style = vgg(style_v)
    # gram_matrix: input b,c,h,w, output b,c,c (one Gram matrix per layer)
    gram_style = [Variable(utils.gram_matrix(y.data)) for y in features_style]

    # Loss meters for visualization (average over all batches in each epoch)
    # Style loss
    style_meter = tnt.meter.AverageValueMeter()
    # Content loss
    content_meter = tnt.meter.AverageValueMeter()

    for epoch in range(opt.epoches):
        # Reset the meters
        content_meter.reset()
        style_meter.reset()

        for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):

            # Training step
            optimizer.zero_grad()
            if opt.use_gpu:
                x = x.cuda()
            # x is the real input image
            x = Variable(x)
            # y is the style-transferred prediction
            y = transformer(x)
            # input:  b, ch, h, w   0~255
            # output: b, ch, h, w   -2~2
            # Rescale x and y from 0~255 to roughly -2~2
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)
            # Return the feature maps of four intermediate layers
            features_y = vgg(y)
            features_x = vgg(x)

            # Content loss: compared only at relu2_2, between the prediction and the
            # original image; content_weight is the content weight and mse_loss is the
            # mean squared error
            content_loss = opt.content_weight * F.mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            # style loss
            style_loss = 0.
            # The style loss is the sum of MSE losses over the four layers
            # features_y: the prediction's four layer outputs
            # gram_style: Gram matrices of the style image's four layer outputs
            # zip pairs up corresponding elements of its iterable arguments into tuples
            for ft_y, gm_s in zip(features_y, gram_style):
                # Gram matrix of the prediction's layer output
                gram_y = utils.gram_matrix(ft_y)
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight
            # Total loss = content loss + style loss
            total_loss = content_loss + style_loss
            # Backpropagation
            total_loss.backward()
            # Update the parameters
            optimizer.step()

            # Add the losses to the meters so the loss curves can be visualized
            content_meter.add(content_loss.data[0])
            style_meter.add(style_loss.data[0])
            # Visualize every plot_every forward passes
            if (ii + 1) % opt.plot_every == 0:
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                # Visualization
                vis.plot('content_loss', content_meter.value()[0])
                vis.plot('style_loss', style_meter.value()[0])
                # x and y were standardized by utils.normalize_batch, so map them
                # back from [-2, 2] to [0, 1] before display
                vis.img('output',
                        (y.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))
                vis.img('input', (x.data.cpu()[0] * 0.225 + 0.45).clamp(min=0,
                                                                        max=1))

        # Save the visdom state and the model after every epoch
        vis.save([opt.env])
        t.save(transformer.state_dict(), 'checkpoints/%s_style.pth' % epoch)
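
Example #2 (and most of the examples below) depends on utils.gram_matrix and utils.normalize_batch. The comments describe their contracts: a b*c*h*w feature map maps to a b*c*c Gram matrix, and a 0-255 batch is rescaled to roughly -2~2. A sketch consistent with that description, assuming ImageNet statistics (the utility module may differ in detail):

import torch

def gram_matrix(y):
    # (b, c, h, w) feature map -> (b, c, c) Gram matrix, normalized by c*h*w.
    b, c, h, w = y.size()
    features = y.view(b, c, h * w)
    return features.bmm(features.transpose(1, 2)) / (c * h * w)

def normalize_batch(batch):
    # Map a 0-255 RGB batch to the range expected by torchvision's pretrained VGG.
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    return (batch / 255.0 - mean) / std
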
Example #3
def train(args):
    if torch.cuda.is_available():
        print('CUDA available, using GPU.')
        device = torch.device('cuda')
    else:
        print('GPU training unavailable... using CPU.')
        device = torch.device('cpu')

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)



    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])


    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    # Image transformation network.
    transformer = TransformerNet()

    if args.model:
        state_dict = torch.load(args.model)
        transformer.load_state_dict(state_dict)

    transformer.to(device)

    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    # Loss Network: VGG16
    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            # CUDA if available
            x = x.to(device)

            # Transform image
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            # Feature map of original image
            features_x = vgg(x)
            # Feature Map of transformed image
            features_y = vgg(y)

            # Difference between transformed image, original image.
            # Changed to pull from features_.relu3_3 vs .relu2_2
            content_loss = args.content_weight * mse_loss(features_y.relu3_3, features_x.relu3_3)

            # Compute gram matrix (dot product across each dimension G(4,3) = F4 * F3)
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if True: #(batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                                  agg_content_loss / (batch_id + 1),
                                  agg_style_loss / (batch_id + 1),
                                  (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(mesg)

            if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
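
The Vgg16(requires_grad=False) loss network used in Example #3 returns an object that exposes fields such as relu2_2 and relu3_3 and can also be iterated over. A sketch of such a wrapper built on torchvision's pretrained VGG-16, assuming the usual relu1_2/relu2_2/relu3_3/relu4_3 split (the original class may differ):

from collections import namedtuple
import torch.nn as nn
from torchvision import models

VggOutputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3"])

class Vgg16(nn.Module):
    def __init__(self, requires_grad=False):
        super().__init__()
        features = models.vgg16(pretrained=True).features
        # Split the convolutional stack right after relu1_2, relu2_2, relu3_3, relu4_3.
        self.slice1 = nn.Sequential(*[features[i] for i in range(4)])
        self.slice2 = nn.Sequential(*[features[i] for i in range(4, 9)])
        self.slice3 = nn.Sequential(*[features[i] for i in range(9, 16)])
        self.slice4 = nn.Sequential(*[features[i] for i in range(16, 23)])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        h1 = self.slice1(x)
        h2 = self.slice2(h1)
        h3 = self.slice3(h2)
        h4 = self.slice4(h3)
        return VggOutputs(h1, h2, h3, h4)
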
Example #4
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    class RGB2YUV(object):
        def __call__(self, img):
            import numpy as np
            import cv2

            npimg = np.array(img)
            yuvnpimg = cv2.cvtColor(npimg, cv2.COLOR_RGB2YUV)
            pilimg = Image.fromarray(yuvnpimg)

            return pilimg

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        RGB2YUV(),
        transforms.ToTensor(),
        # transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              **kwargs)

    transformer = TransformerNet(in_channels=1,
                                 out_channels=2)  # input: Y, predict: UV
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    # vgg = Vgg16()
    # utils.init_vgg16(args.vgg_model_dir)
    # vgg.load_state_dict(torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    transformer = nn.DataParallel(transformer)

    if args.cuda:
        if not torch.cuda.is_available():
            raise RuntimeError(
                "CUDA is requested, but related driver/device is not set properly."
            )
        transformer.cuda()

    for e in range(args.epochs):
        transformer.train()
        # agg_content_loss = 0.
        # agg_style_loss = 0.
        count = 0
        for batch_id, (imgs, _) in enumerate(train_loader):
            n_batch = len(imgs)
            count += n_batch
            optimizer.zero_grad()
            # First channel
            x = imgs[:, :1, :, :].clone()
            # Second and third channels
            gt = imgs[:, 1:, :, :].clone()

            if args.cuda:
                x = x.cuda()
                gt = gt.cuda()

            y = transformer(x)

            total_loss = mse_loss(y, gt)
            total_loss.backward()
            optimizer.step()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    total_loss / (batch_id + 1))
                print(mesg)

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    os.makedirs(args.save_model_dir, exist_ok=True)
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
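
At inference time, the network in Example #4 only predicts the UV channels; they have to be recombined with the input luminance and converted back to RGB. A hedged sketch of that step with OpenCV, assuming both tensors stay in the [0, 1] range produced by ToTensor() (the function name is hypothetical):

import cv2
import numpy as np
import torch

def yuv_to_rgb(y_channel, uv_channels):
    # y_channel: (1, 1, H, W) input luminance; uv_channels: (1, 2, H, W) predicted chrominance.
    yuv = torch.cat([y_channel, uv_channels], dim=1)[0].permute(1, 2, 0).cpu().numpy()
    yuv = (yuv * 255.0).clip(0, 255).astype(np.uint8)
    return cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB)
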
Example #5
        style_loss = 0.
        for m in range(len(features_y)):
            gram_s = gram_style[m]
            gram_y = gram_matrix(features_y[m])
            style_loss += args.style_weight * loss(gram_y,
                                                   gram_s.expand_as(gram_y))

        total_loss = content_loss + style_loss + reg_loss
        total_loss.backward()
        optimizer.step()

        agg_content_loss += content_loss.data[0]
        agg_style_loss += style_loss.data[0]
        agg_reg_loss += reg_loss.data[0]

        if (batch_id + 1) % args.log_interval == 0:
            mesg = "[{}/{}] content: {:.6f}  style: {:.6f}  reg: {:.6f}  total: {:.6f}".format(
                count, len(train_dataset), agg_content_loss / count,
                agg_style_loss / count, agg_reg_loss / count,
                (agg_content_loss + agg_style_loss + agg_reg_loss) / count)
            print(mesg)

# save model
transformer.eval()
if torch.cuda.is_available():
    transformer.cpu()

model_file = 'model_' + str(epoch) + '.pth'
torch.save(transformer.state_dict(), model_file)
print('\nSaved model to ' + model_file + '.')
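
The reg_loss term in Example #5 is computed before the excerpt begins. A common choice for such a term in feed-forward style transfer is total variation regularization on the generated image; a sketch under that assumption (the names y and args.reg_weight are hypothetical):

def total_variation(y):
    # y: (b, c, h, w) generated image; penalize differences between neighboring pixels.
    tv_h = (y[:, :, 1:, :] - y[:, :, :-1, :]).abs().sum()
    tv_w = (y[:, :, :, 1:] - y[:, :, :, :-1]).abs().sum()
    return tv_h + tv_w

# reg_loss = args.reg_weight * total_variation(y)  # assumed usage
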
Example #6
def fast_train(args):
    """Fast training"""

    device = torch.device("cuda" if args.cuda else "cpu")

    transformer = TransformerNet().to(device)
    if args.model:
        transformer.load_state_dict(torch.load(args.model))
    vgg = Vgg16(requires_grad=False).to(device)
    global mse_loss
    mse_loss = torch.nn.MSELoss()

    content_weight = args.content_weight
    style_weight = args.style_weight
    lr = args.lr

    content_transform = transforms.Compose([
        transforms.Resize(args.content_size),
        transforms.CenterCrop(args.content_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))])
    content_dataset = datasets.ImageFolder(args.content_dataset, content_transform)
    content_loader = DataLoader(content_dataset, 
                                batch_size=args.iter_batch_size, 
                                sampler=InfiniteSamplerWrapper(content_dataset),
                                num_workers=args.n_workers)
    content_loader = iter(content_loader)
    style_transform = transforms.Compose([
            transforms.Resize((args.style_size, args.style_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))])

    style_image = utils.load_image(args.style_image)
    style_image = style_transform(style_image)
    style_image = style_image.unsqueeze(0).to(device)
    features_style = vgg(utils.normalize_batch(style_image.repeat(args.iter_batch_size, 1, 1, 1)))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    if args.only_in:
        optimizer = Adam([param for (name, param) in transformer.named_parameters() if "in" in name], lr=lr)
    else:
        optimizer = Adam(transformer.parameters(), lr=lr)

    for i in trange(args.update_step):
        contents = next(content_loader)[0].to(device)
        features_contents = vgg(utils.normalize_batch(contents))

        transformed = transformer(contents)
        features_transformed = vgg(utils.standardize_batch(transformed))
        loss, c_loss, s_loss = loss_fn(features_transformed, features_contents, gram_style, content_weight, style_weight)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # save model
    transformer.eval().cpu()
    style_name = os.path.basename(args.style_image).split(".")[0]
    save_model_filename = style_name + ".pth"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
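
Example #6 delegates the loss computation to a loss_fn helper that is not shown. A sketch consistent with how it is called above (content loss on relu2_2, style loss summed over the returned layers, using the module-level mse_loss and utils.gram_matrix); the real helper may weight the layers differently:

def loss_fn(features_transformed, features_contents, gram_style, content_weight, style_weight):
    # Content loss between the stylized output and the content image at relu2_2.
    c_loss = content_weight * mse_loss(features_transformed.relu2_2,
                                       features_contents.relu2_2)
    # Style loss: MSE between Gram matrices, summed over all feature layers.
    s_loss = 0.
    for ft, gm_s in zip(features_transformed, gram_style):
        s_loss += mse_loss(utils.gram_matrix(ft), gm_s)
    s_loss *= style_weight
    return c_loss + s_loss, c_loss, s_loss
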
Example #7
def train(**kwargs):
    opt = Config()
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device('cuda') if opt.use_gpu else t.device('cpu')
    vis = utils.Visualizer(opt.env)

    # Data loading
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x * 255)
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, transforms)
    dataloader = data.DataLoader(dataset, opt.batch_size)

    # Image transformation network
    transformer = TransformerNet()
    if opt.model_path:
        transformer.load_state_dict(
            t.load(opt.model_path, map_location=lambda _s, _: _s))
    transformer.to(device)

    # Loss network: Vgg16
    vgg = Vgg16().eval()
    vgg.to(device)
    for param in vgg.parameters():
        param.requires_grad = False

    # Optimizer
    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    # Load the style image data
    style = utils.get_style_data(opt.style_path)
    vis.img('style', (style.data[0] * 0.225 + 0.45).clamp(min=0, max=1))
    style = style.to(device)

    # Gram matrices of the style image
    with t.no_grad():
        features_style = vgg(style)
        gram_style = [utils.gram_matrix(y) for y in features_style]

    # Loss statistics
    style_meter = tnt.meter.AverageValueMeter()
    content_meter = tnt.meter.AverageValueMeter()

    for epoch in range(opt.epoches):
        content_meter.reset()
        style_meter.reset()

        for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):

            # Training step
            optimizer.zero_grad()
            x = x.to(device)
            y = transformer(x)
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)
            features_y = vgg(y)
            features_x = vgg(x)

            # content loss
            content_loss = opt.content_weight * F.mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            # style loss
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gram_y = utils.gram_matrix(ft_y)
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            # Loss smoothing
            content_meter.add(content_loss.item())
            style_meter.add(style_loss.item())

            if (ii + 1) % opt.plot_every == 0:
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                # Visualization
                vis.plot('content_loss', content_meter.value()[0])
                vis.plot('style_loss', style_meter.value()[0])
                # x and y were standardized by utils.normalize_batch, so de-normalize them before display
                vis.img('output',
                        (y.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))
                vis.img('input', (x.data.cpu()[0] * 0.225 + 0.45).clamp(min=0,
                                                                        max=1))

        # Save the visdom state and the model
        vis.save([opt.env])
        t.save(transformer.state_dict(), 'checkpoints/%s_style.pth' % epoch)
Example #8
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    training_set = np.loadtxt(args.dataset, dtype=np.float32)
    training_set_size = training_set.shape[1]
    num_batch = int(training_set_size / args.batch_size)

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    style = np.loadtxt(args.style_image, dtype=np.float32)
    style = style.reshape((1, 1, args.style_size_x, args.style_size_y))
    style = torch.from_numpy(style)
    style = style.repeat(args.batch_size, 3, 1, 1)
    if args.cuda:
        style = style.cuda()
    style_v = Variable(style, volatile=True)
    style_v = utils.subtract_imagenet_mean_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    # Hard data
    if args.hard_data:
        hard_data = np.loadtxt(args.hard_data_file)
        # if not isinstance(hard_data[0], list):
        #     hard_data = [hard_data]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        # for batch_id, (x, _) in enumerate(train_loader):
        for batch_id in range(num_batch):
            x = training_set[:, batch_id * args.batch_size:(batch_id + 1) *
                             args.batch_size]
            n_batch = x.shape[1]
            count += n_batch
            x = x.transpose()
            x = x.reshape((n_batch, 1, args.image_size_x, args.image_size_y))

            # plt.imshow(x[0,:,:,:].squeeze(0))
            # plt.show()
            x = torch.from_numpy(x).float()

            optimizer.zero_grad()

            x = Variable(x)
            if args.cuda:
                x = x.cuda()

            y = transformer(x)

            if args.hard_data:
                hard_data_loss = 0
                num_hard_data = 0
                for hd in hard_data:
                    hard_data_loss += args.hard_data_weight * (
                        y[:, 0, hd[1], hd[0]] -
                        hd[2] * 255.0).norm()**2 / n_batch
                    num_hard_data += 1
                hard_data_loss /= num_hard_data

            y = y.repeat(1, 3, 1, 1)
            # x = Variable(utils.preprocess_batch(x))

            # xc = x.data.clone()
            # xc = xc.repeat(1, 3, 1, 1)
            # xc = Variable(xc, volatile=True)

            y = utils.subtract_imagenet_mean_batch(y)
            # xc = utils.subtract_imagenet_mean_batch(xc)

            features_y = vgg(y)
            # features_xc = vgg(xc)

            # f_xc_c = Variable(features_xc[1].data, requires_grad=False)

            # content_loss = args.content_weight * mse_loss(features_y[1], f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = Variable(gram_style[m].data, requires_grad=False)
                gram_y = utils.gram_matrix(features_y[m])
                style_loss += args.style_weight * mse_loss(
                    gram_y, gram_s[:n_batch, :, :])

            # total_loss = content_loss + style_loss

            total_loss = style_loss

            if args.hard_data:
                total_loss += hard_data_loss

            total_loss.backward()
            optimizer.step()

            # agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                if args.hard_data:
                    mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\thard_data: {:.6f}\ttotal: {:.6f}".format(
                        time.ctime(), e + 1, count, num_batch,
                        agg_content_loss / (batch_id + 1),
                        agg_style_loss / (batch_id + 1),
                        hard_data_loss.data[0],
                        (agg_content_loss + agg_style_loss) / (batch_id + 1))
                else:
                    mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                        time.ctime(), e + 1, count, num_batch,
                        agg_content_loss / (batch_id + 1),
                        agg_style_loss / (batch_id + 1),
                        (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #9
def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # DATA
    # Transform and Dataloader for COCO dataset
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),  # / 255.
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    # MODEL
    # Define Image Transformation Network with MSE loss and Adam optimizer
    transformer = TransformerNet().to(device)
    mse_loss = nn.MSELoss()
    optimizer = optim.Adam(transformer.parameters(), args.learning_rate)

    # Pretrained VGG
    vgg = VGG16(requires_grad=False).to(device)

    # FEATURES
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])

    # Load the style image
    style = Image.open(args.style)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    # Compute the style features
    features_style = vgg(normalize_batch(style))

    # Loop through VGG style layers to calculate Gram Matrix
    gram_style = [gram_matrix(y) for y in features_style]

    # TRAIN
    for epoch in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.

        for batch_id, (x, _) in tqdm(enumerate(train_loader), unit='batch'):
            x = x.to(device)
            n_batch = len(x)

            optimizer.zero_grad()

            # Pass through the image transformation network
            y = transformer(x)
            y = normalize_batch(y)
            x = normalize_batch(x)

            # Pass through the VGG layers
            features_y = vgg(y)
            features_x = vgg(x)

            # Calculate content loss
            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            # Calculate style loss
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            # Monitor
            if (batch_id + 1) % args.log_interval == 0:
                tqdm.write('[{}] ({})\t'
                           'content: {:.6f}\t'
                           'style: {:.6f}\t'
                           'total: {:.6f}'.format(
                               epoch + 1, batch_id + 1,
                               agg_content_loss / (batch_id + 1),
                               agg_style_loss / (batch_id + 1),
                               (agg_content_loss + agg_style_loss) /
                               (batch_id + 1)))

            # Checkpoint
            if (batch_id + 1) % args.save_interval == 0:
                # eval mode
                transformer.eval().cpu()
                style_name = args.style.split('/')[-1].split('.')[0]
                checkpoint_file = os.path.join(args.checkpoint_dir,
                                               '{}.pth'.format(style_name))

                tqdm.write('Checkpoint {}'.format(checkpoint_file))
                torch.save(transformer.state_dict(), checkpoint_file)

                # back to train mode
                transformer.to(device).train()
Example #10
def train(args):
    # With the same args.seed, each training run draws the same sequence of
    # random numbers, so results are reproducible across runs.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))  # 0-1 to 0-255
    ])
    # note the order: give where the images at; load the images and transform; give the batch size
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    # TODO: in transformernet
    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    # TODO: relus in vgg16
    vgg = Vgg16(requires_grad=False)

    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    # style2 = utils.load_image(args.style_image2, size=args.style_size)
    style = style_transform(style)
    # style2 = style_transform(style2)

    # repeat the style tensor 4 times
    style = style.repeat(args.batch_size, 1, 1, 1)
    # style2 = style2.repeat(args.batch_size, 1, 1, 1)

    if args.cuda:
        transformer.cuda()
        vgg.cuda()
        style = style.cuda()
        # style2 = style2.cuda()


    style_v = Variable(style)
    style_v = utils.normalize_batch(style_v)
    features_style = vgg(style_v)
    # style_v2 = Variable(style2)
    # style_v2 = utils.normalize_batch(style_v2)
    # features_style2 = vgg(style_v2)
    # to determine style loss, make use of gram matrix
    gram_style = [utils.gram_matrix(y) for y in features_style]
    # gram_style2 = [utils.gram_matrix(y) for y in features_style2]


    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()  # pytorch accumulates gradients, making them zero for each minibatch
            x = Variable(x)
            if args.cuda:
                x = x.cuda()

            # forward pass
            y = transformer(x)  # after transformer - y

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            # TODO: mse_loss of which relu could be modified
            content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
                # style_loss += mse_loss(gm_y, gm_s2[:n_batch, :, :])

            style_loss *= args.style_weight

            total_loss = content_loss + style_loss

            # backward pass
            total_loss.backward()  # this simply computes the gradients for each learnable parameters

            # update weights
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                msg = "Epoch "+str(e + 1)+" "+str(count)+"/"+str(len(train_dataset))
                msg += " content loss : "+str(agg_content_loss / (batch_id + 1))
                msg += " style loss : " +str(agg_style_loss / (batch_id + 1))
                msg += " total loss : " +str((agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(msg)

            if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval()
                if args.cuda:
                    transformer.cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                if args.cuda:
                    transformer.cuda()
                transformer.train()

    # save model
    transformer.eval()
    if args.cuda:
        transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #11
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    kwargs = {'num_workers': 0, 'pin_memory': False}

    transform = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = dataset.CustomImageDataset(args.dataset,
                                               transform=transform,
                                               img_size=args.image_size)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              **kwargs)

    transformer = TransformerNet(args.pad_type)
    transformer = transformer.train()
    optimizer = torch.optim.Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()
    #print(transformer)
    vgg = Vgg16()
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))
    vgg.eval()

    transformer = transformer.cuda()
    vgg = vgg.cuda()

    style = utils.tensor_load_resize(args.style_image, args.style_size)
    style = style.unsqueeze(0)
    print("=> Style image size: " + str(style.size()))

    #(1, H, W, C)
    style = utils.preprocess_batch(style).cuda()
    utils.tensor_save_bgrimage(
        style[0].detach(), os.path.join(args.save_model_dir,
                                        'train_style.jpg'), True)
    style = utils.subtract_imagenet_mean_batch(style)
    features_style = vgg(style)
    gram_style = [utils.gram_matrix(y).detach() for y in features_style]

    for e in range(args.epochs):
        train_loader.dataset.reset()
        agg_content_loss = 0.
        agg_style_loss = 0.
        iters = 0
        for batch_id, (x, _) in enumerate(train_loader):
            if x.size(0) != args.batch_size:
                print("=> Skip incomplete batch")
                continue
            iters += 1

            optimizer.zero_grad()
            x = utils.preprocess_batch(x).cuda()
            y = transformer(x)

            if (batch_id + 1) % 1000 == 0:
                idx = (batch_id + 1) // 1000
                utils.tensor_save_bgrimage(
                    y.data[0],
                    os.path.join(args.save_model_dir, "out_%d.png" % idx),
                    True)
                utils.tensor_save_bgrimage(
                    x.data[0],
                    os.path.join(args.save_model_dir, "in_%d.png" % idx), True)

            y = utils.subtract_imagenet_mean_batch(y)
            x = utils.subtract_imagenet_mean_batch(x)

            features_y = vgg(y)
            features_x = vgg(center_crop(x, y.size(2), y.size(3)))

            #content target
            f_x = features_x[2].detach()
            # content
            f_y = features_y[2]

            content_loss = args.content_weight * mse_loss(f_y, f_x)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = gram_style[m]
                gram_y = utils.gram_matrix(features_y[m])
                batch_style_loss = 0
                for n in range(gram_y.shape[0]):
                    batch_style_loss += args.style_weight * mse_loss(
                        gram_y[n], gram_s[0])
                style_loss += batch_style_loss / gram_y.shape[0]

            total_loss = content_loss + style_loss

            total_loss.backward()
            optimizer.step()
            agg_content_loss += content_loss.data
            agg_style_loss += style_loss.data

            mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                time.ctime(), e + 1, batch_id + 1, len(train_loader),
                agg_content_loss / iters, agg_style_loss / iters,
                (agg_content_loss + agg_style_loss) / iters)
            print(mesg)
            agg_content_loss = agg_style_loss = 0.0
            iters = 0

        # save model
        save_model_filename = "epoch_" + str(e) + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
        save_model_path = os.path.join(args.save_model_dir,
                                       save_model_filename)
        torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
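
Example #11 follows the older BGR convention of the Caffe-style VGG weights: utils.preprocess_batch reorders the channels and utils.subtract_imagenet_mean_batch removes the ImageNet mean on the 0-255 scale. A sketch of those helpers under that assumption (the real utilities may differ):

import torch

def preprocess_batch(batch):
    # Swap an RGB 0-255 batch (N, C, H, W) to BGR channel order.
    return batch[:, [2, 1, 0], :, :]

def subtract_imagenet_mean_batch(batch):
    # Subtract the ImageNet mean (BGR order, 0-255 scale) channel-wise.
    mean = batch.new_tensor([103.939, 116.779, 123.680]).view(1, 3, 1, 1)
    return batch - mean
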
Example #12
def train(args):
    if not os.path.exists(args.save_model_dir):
        os.makedirs(args.save_model_dir)


    device = torch.device("cuda" if args.is_cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    # print(style.size)
    # ss('yo')
    style = style_transform(style)  # it's not transform
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)
    # style = style.repeat(2,1,1,1).to(device)
    # print(style.shape)
    # print()
    # ss('ho')
    features_style = vgg(utils.normalize_batch(style))
    # print(features_style.relu4_3.shape)
    # for i in features_style:
    #     print(i.shape)
    # ss('normalize')
    gram_style = [utils.gram_matrix(y) for y in features_style]
    # for i in gram_style:
        # print(i.shape)
    # ss('main: gram style')
    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            # print(n_batch)
            # ss('hi')
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            # print(x.shape)
            # print(x[0,0,0,:])
            # ss('in epoch, batch')
            y = transformer(x)
            # ss('in epoch, batch')
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            # if (batch_id + 1) % args.log_interval == 0:
            if True:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                                  agg_content_loss / (batch_id + 1),
                                  agg_style_loss / (batch_id + 1),
                                  (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(mesg)
            if args.is_quickrun:
                if count > 10:
                    break
            # if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
            #     transformer.eval().cpu()
            #     ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
            #     ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
            #     torch.save(transformer.state_dict(), ckpt_model_path)
            #     transformer.to(device).train()
            if (e % 50 == 0) or (e>400 and e % 10 ==0):
                # utils.save_image(args.save_model_dir+'/imgs/npepoch_{}.png'.format(e), y[0].detach().cpu())
                # torchvision.utils.save_image(y, './imgs/epoch_{}.png'.format(e), normalize=True)
                torchvision.utils.save_image(y, './imgs/before/epoch_{}.png'.format(e), normalize=True)
                y = y.clamp(0, 255)
                torchvision.utils.save_image(y, './imgs/non/epoch_{}.png'.format(e))
                torchvision.utils.save_image(y, './imgs/after/epoch_{}.png'.format(e), normalize=True)
            # ss('yo')
    # save model
    transformer.eval().cpu()

    save_model_filename = "style_"+args.style_name+"_epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #13
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print("Loading data")
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    print("Building the model")
    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1)

    if args.cuda:
        transformer.cuda()
        vgg.cuda()
        style = style.cuda()

    style_v = Variable(style)
    style_v = utils.normalize_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    def multiply(loss, weight):
        return loss * weight

    def add(loss1, loss2):
        return loss1 + loss2

    metrics_names = ['Content Loss', 'Style Loss', 'Total Loss']
    with missinglink_project.create_experiment(
        transformer,
        display_name='Style Transfer PyTorch',
        optimizer=optimizer,
        train_data_object=train_loader,
        metrics={metrics_names[0]: multiply, metrics_names[1]: multiply, metrics_names[2]: add}
    ) as experiment:
        (wrapped_content_loss,
         wrapped_style_loss,
         wrapped_total_loss) = [experiment.metrics[metric_name] for metric_name in metrics_names]

        print("Starting to train")
        for e in experiment.epoch_loop(args.epochs):
            transformer.train()
            agg_content_loss = 0.
            agg_style_loss = 0.
            count = 0
            for batch_id, (x, _) in experiment.batch_loop(iterable=train_loader):
                n_batch = len(x)
                count += n_batch
                optimizer.zero_grad()
                x = Variable(x)
                if args.cuda:
                    x = x.cuda()

                y = transformer(x)

                y = utils.normalize_batch(y)
                x = utils.normalize_batch(x)

                features_y = vgg(y)
                features_x = vgg(x)

                content_loss = mse_loss(features_y.relu2_2, features_x.relu2_2)
                content_loss = wrapped_content_loss(content_loss, args.content_weight)

                style_loss = 0.
                for ft_y, gm_s in zip(features_y, gram_style):
                    gm_y = utils.gram_matrix(ft_y)
                    style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
                style_loss = wrapped_style_loss(style_loss, args.style_weight)

                total_loss = wrapped_total_loss(content_loss, style_loss)
                total_loss.backward()
                optimizer.step()

                agg_content_loss += content_loss.data[0]
                agg_style_loss += style_loss.data[0]

                if (batch_id + 1) % args.log_interval == 0:
                    mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                        time.ctime(), e + 1, count, len(train_dataset),
                                      agg_content_loss / (batch_id + 1),
                                      agg_style_loss / (batch_id + 1),
                                      (agg_content_loss + agg_style_loss) / (batch_id + 1)
                    )
                    print(mesg)

                if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
                    transformer.eval()
                    if args.cuda:
                        transformer.cpu()
                    ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                    ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
                    torch.save(transformer.state_dict(), ckpt_model_path)
                    if args.cuda:
                        transformer.cuda()
                    transformer.train()

        # save model
        transformer.eval()
        if args.cuda:
            transformer.cpu()
        save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
        save_model_path = os.path.join(args.save_model_dir, save_model_filename)
        torch.save(transformer.state_dict(), save_model_path)

        print("\nDone, trained model saved at", save_model_path)
Example #14
def train(**kwargs):

    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    if opt.vis is True:
        from visualize import Visualizer
        vis = Visualizer(opt.env)

    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),  #change value to (0,1)
        tv.transforms.Lambda(lambda x: x * 255)
    ])  #change value to (0,255)
    dataset = tv.datasets.ImageFolder(opt.data_root, transforms)

    dataloader = data.DataLoader(dataset, opt.batch_size)  #value is (0,255)

    transformer = TransformerNet()

    if opt.model_path:
        transformer.load_state_dict(
            t.load(opt.model_path, map_location=lambda _s, _: _s))

    vgg = VGG16().eval()
    for param in vgg.parameters():
        param.requires_grad = False

    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    style = utils.get_style_data(opt.style_path)
    vis.img('style', (style[0] * 0.225 + 0.45).clamp(min=0, max=1))

    if opt.use_gpu:

        transformer.cuda()
        style = style.cuda()
        vgg.cuda()

    style_v = Variable(style.unsqueeze(0), volatile=True)
    features_style = vgg(style_v)
    gram_style = [Variable(utils.gram_matrix(y.data)) for y in features_style]

    style_meter = tnt.meter.AverageValueMeter()
    content_meter = tnt.meter.AverageValueMeter()

    for epoch in range(opt.epoches):
        content_meter.reset()
        style_meter.reset()

        for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):

            optimizer.zero_grad()
            if opt.use_gpu:
                x = x.cuda()  #(0,255)
            x = Variable(x)
            y = transformer(x)  #(0,255)
            y = utils.normalize_batch(y)  #(-2,2)
            x = utils.normalize_batch(x)  #(-2,2)

            features_y = vgg(y)
            features_x = vgg(x)

            # Calculate the content loss: only relu2_2 is used
            # It might be better to combine several layers, e.g. w1*relu2_2 + w2*relu3_2 + w3*relu3_3 + w4*relu4_3
            content_loss = opt.content_weight * F.mse_loss(
                features_y.relu2_2, features_x.relu2_2)
            content_meter.add(content_loss.data)

            style_loss = 0
            for ft_y, gm_s in zip(features_y, gram_style):

                gram_y = utils.gram_matrix(ft_y)
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_meter.add(style_loss.data)

            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            if (ii + 1) % (opt.plot_every) == 0:

                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                vis.plot('content_loss', content_meter.value()[0])
                vis.plot('style_loss', style_meter.value()[0])

                vis.img('output',
                        (y.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))
                vis.img('input', (x.data.cpu()[0] * 0.225 + 0.45).clamp(min=0,
                                                                        max=1))

        vis.save([opt.env])
        t.save(transformer.state_dict(), 'checkpoints/%s_style.pth' % epoch)
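
The comment in Example #14 suggests replacing the single relu2_2 content loss with a weighted sum over several VGG layers. A sketch of that idea, assuming the loss network exposes relu2_2, relu3_3 and relu4_3 and that the weights are tunable hyperparameters (the helper name is hypothetical; F is torch.nn.functional as in the example):

def multi_layer_content_loss(features_y, features_x, w1=1.0, w2=0.5, w3=0.25):
    # Weighted combination of per-layer MSE losses between prediction and content features.
    loss = w1 * F.mse_loss(features_y.relu2_2, features_x.relu2_2)
    loss += w2 * F.mse_loss(features_y.relu3_3, features_x.relu3_3)
    loss += w3 * F.mse_loss(features_y.relu4_3, features_x.relu4_3)
    return loss
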
Example #15
def train(DC):
    train_gpu_id = DC.train_gpu_id
    device = t.device('cuda', train_gpu_id) if DC.use_gpu else t.device('cpu')

    input_size = DC.input_size
    super_resol_factor = DC.super_resol_factor

    high_transforms = T.Compose([
        T.Resize(input_size),
        T.CenterCrop(input_size),
        T.ToTensor(),
        T.Lambda(lambda x: x * 255)
    ])
    low_transforms = T.Compose([
        T.Resize(int(input_size / super_resol_factor)),
        T.CenterCrop(int(input_size / super_resol_factor)),
        T.ToTensor(),
        T.Lambda(lambda x: x * 255)
    ])

    HighResol_dir = DC.HighResol_dir
    LowResol_dir = DC.LowResol_dir
    batch_size = DC.train_batch_size

    HighResol_data = ImageFolder(HighResol_dir, transform=high_transforms)
    LowResol_data = ImageFolder(LowResol_dir, transform=low_transforms)

    num_train_data = len(HighResol_data)

    HighResol_dataloader = t.utils.data.DataLoader(HighResol_data,
                                                   batch_size=batch_size,
                                                   shuffle=False,
                                                   num_workers=DC.num_workers,
                                                   drop_last=True)

    LowResol_dataloader = t.utils.data.DataLoader(LowResol_data,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  num_workers=DC.num_workers,
                                                  drop_last=True)
    # transform net
    transformer = TransformerNet()
    if DC.load_model:
        transformer.load_state_dict(
            t.load(DC.load_model, map_location=lambda storage, loc: storage))

    transformer.to(device)

    # Loss net (vgg16)
    vgg = Vgg16().eval()
    vgg.to(device)

    for param in vgg.parameters():
        param.requires_grad = False

    optimizer = t.optim.Adam(transformer.parameters(), DC.base_lr)

    # Start training
    train_imgs = 0
    iteration = 0
    for epoch in range(DC.max_epoch):
        for i, ((high_data, _), (low_data, _)) in tqdm.tqdm(
                enumerate(zip(HighResol_dataloader, LowResol_dataloader))):

            train_imgs += batch_size
            iteration += 1

            optimizer.zero_grad()

            # Transformer net
            x = low_data.to(device)
            y = transformer(x)
            y = utils.normalize_batch(y)

            yc = high_data.to(device)
            yc = utils.normalize_batch(yc)

            features_y = vgg(y)
            features_yc = vgg(yc)

            # Content loss
            content_loss = DC.content_weight * \
                             nn.functional.mse_loss(features_y.relu2_2,
                                                    features_yc.relu2_2)
            #            content_loss = DC.content_weight * \
            #                            nn.functional.mse_loss(features_y.relu3_3,
            #                                                   features_yc.relu3_3)

            content_loss.backward()
            optimizer.step()

            if iteration % DC.show_iter == 0:
                print('\nepoch: ', epoch)
                print('content loss: ', content_loss.data)
                print()

        if (epoch + 1) % 10 == 0:
            t.save(transformer.state_dict(), '{}_style.pth'.format(epoch))
Example #16
def train(args):
    # Device (CPU or GPU) that torch.Tensor objects will be allocated on
    device = torch.device("cuda" if args.cuda else "cpu")
    # Initialize the random seed
    np.random.seed(args.seed)
    # Seed the CPU random number generator
    torch.manual_seed(args.seed)
    # Chain several transforms together
    transform = transforms.Compose([
        # Resize
        transforms.Resize(args.image_size),
        # Center-crop the given image
        transforms.CenterCrop(args.image_size),
        # Convert the image to a tensor with values in [0, 1]
        transforms.ToTensor(),
        # Apply a lambda as the final transform
        transforms.Lambda(lambda x: x.mul(255))
    ])
    # ImageFolder data loader: pass in the dataset path
    # transform: a function that takes the raw image and returns a transformed image
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    # Put the dataset into a DataLoader, which yields an iterator over batches
    # batch_size: int, how many samples to load per batch
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)
    # Move the TransformerNet model to the device
    transformer = TransformerNet().to(device)
    # Use Adam as the optimizer
    optimizer = Adam(transformer.parameters(), args.lr)
    # Mean squared error loss
    mse_loss = torch.nn.MSELoss()
    # Move the Vgg16 model to the device
    vgg = Vgg16(requires_grad=False).to(device)
    # Preprocessing for the style image
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    # Load the style image
    style = utils.load_image(args.style_image, size=args.style_size)
    # Preprocess the style image
    style = style_transform(style)
    # repeat(*sizes) tiles the tensor along the given dimensions
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)
    # Normalize the style batch and extract its VGG features
    features_style = vgg(utils.normalize_batch(style))
    # Compute the Gram matrix of each style feature map
    gram_style = [utils.gram_matrix(y) for y in features_style]
    # Training loop
    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            # 把梯度置零,也就是把loss关于weight的导数变成0
            optimizer.zero_grad()

            y = transformer(x.to(device))

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x.to(device))
            # Content loss
            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)
            # Style loss
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight
            # Total loss
            total_loss = content_loss + style_loss
            # Backpropagate
            total_loss.backward()
            # Update the parameters
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            # Log progress every args.log_interval batches
            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)
            # Save a checkpoint of the partially trained style model
            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
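Example #16 addresses VGG activations by name (features_y.relu2_2) and freezes the network with requires_grad=False. The Vgg16 class is not part of the snippet; a plausible sketch of such a feature extractor, assuming the standard torchvision VGG-16 layer ordering and a named-tuple return value, is:

from collections import namedtuple

import torch
from torchvision import models


class Vgg16(torch.nn.Module):
    """Frozen VGG-16 returning the relu1_2 ... relu4_3 activations."""

    def __init__(self, requires_grad=False):
        super(Vgg16, self).__init__()
        features = models.vgg16(pretrained=True).features
        self.slice1 = torch.nn.Sequential(*[features[i] for i in range(4)])      # up to relu1_2
        self.slice2 = torch.nn.Sequential(*[features[i] for i in range(4, 9)])   # up to relu2_2
        self.slice3 = torch.nn.Sequential(*[features[i] for i in range(9, 16)])  # up to relu3_3
        self.slice4 = torch.nn.Sequential(*[features[i] for i in range(16, 23)]) # up to relu4_3
        if not requires_grad:
            for p in self.parameters():
                p.requires_grad = False

    def forward(self, x):
        h1 = self.slice1(x)
        h2 = self.slice2(h1)
        h3 = self.slice3(h2)
        h4 = self.slice4(h3)
        VggOutputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3"])
        return VggOutputs(h1, h2, h3, h4)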
Example #17
0
def train(args):
    """Meta train the model"""

    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # first move parameters to GPU
    transformer = TransformerNet().to(device)
    vgg = Vgg16(requires_grad=False).to(device)
    global optimizer
    optimizer = Adam(transformer.parameters(), args.meta_lr)
    global mse_loss
    mse_loss = torch.nn.MSELoss()

    content_loader, style_loader, query_loader = get_data_loader(args)

    content_weight = args.content_weight
    style_weight = args.style_weight
    lr = args.lr

    writer = SummaryWriter(args.log_dir)

    for iteration in trange(args.max_iter):
        transformer.train()
        
        # bookkeeping
        # using state_dict causes problems, use named_parameters instead
        all_meta_grads = []
        avg_train_c_loss = 0.0
        avg_train_s_loss = 0.0
        avg_train_loss = 0.0
        avg_eval_c_loss = 0.0
        avg_eval_s_loss = 0.0
        avg_eval_loss = 0.0

        contents = content_loader.next()[0].to(device)
        features_contents = vgg(utils.normalize_batch(contents))
        querys = query_loader.next()[0].to(device)
        features_querys = vgg(utils.normalize_batch(querys))

        # learning rate scheduling
        lr = args.lr / (1.0 + iteration * 2.5e-5)
        meta_lr = args.meta_lr / (1.0 + iteration * 2.5e-5)
        for param_group in optimizer.param_groups:
            param_group['lr'] = meta_lr

        for i in range(args.meta_batch_size):
            # sample a style
            style = style_loader.next()[0].to(device)
            style = style.repeat(args.iter_batch_size, 1, 1, 1)
            features_style = vgg(utils.normalize_batch(style))
            gram_style = [utils.gram_matrix(y) for y in features_style]

            fast_weights = OrderedDict((name, param) for (name, param) in transformer.named_parameters() if re.search(r'in\d+\.', name))
            for j in range(args.meta_step):
                # run forward transformation on contents
                transformed = transformer(contents, fast_weights)

                # compute loss
                features_transformed = vgg(utils.standardize_batch(transformed))
                loss, c_loss, s_loss = loss_fn(features_transformed, features_contents, gram_style, content_weight, style_weight)

                # compute grad
                grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)

                # update fast weights
                fast_weights = OrderedDict((name, param - lr * grad) for ((name, param), grad) in zip(fast_weights.items(), grads))
            
            avg_train_c_loss += c_loss.item()
            avg_train_s_loss += s_loss.item()
            avg_train_loss += loss.item()

            # run forward transformation on querys
            transformed = transformer(querys, fast_weights)
            
            # compute loss
            features_transformed = vgg(utils.standardize_batch(transformed))
            loss, c_loss, s_loss = loss_fn(features_transformed, features_querys, gram_style, content_weight, style_weight)
            
            grads = torch.autograd.grad(loss / args.meta_batch_size, transformer.parameters())
            all_meta_grads.append({name: g for ((name, _), g) in zip(transformer.named_parameters(), grads)})

            avg_eval_c_loss += c_loss.item()
            avg_eval_s_loss += s_loss.item()
            avg_eval_loss += loss.item()
        
        writer.add_scalar("Avg_Train_C_Loss", avg_train_c_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Train_S_Loss", avg_train_s_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Train_Loss", avg_train_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Eval_C_Loss", avg_eval_c_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Eval_S_Loss", avg_eval_s_loss / args.meta_batch_size, iteration + 1)
        writer.add_scalar("Avg_Eval_Loss", avg_eval_loss / args.meta_batch_size, iteration + 1)

        # compute dummy loss to refresh buffer
        transformed = transformer(querys)
        features_transformed = vgg(utils.standardize_batch(transformed))
        dummy_loss, _, _ = loss_fn(features_transformed, features_querys, gram_style, content_weight, style_weight)

        meta_updates(transformer, dummy_loss, all_meta_grads)

        if args.checkpoint_model_dir is not None and (iteration + 1) % args.checkpoint_interval == 0:
            transformer.eval().cpu()
            ckpt_model_filename = "iter_" + str(iteration + 1) + ".pth"
            ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
            torch.save(transformer.state_dict(), ckpt_model_path)
            transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "Final_iter_" + str(args.max_iter) + "_" + \
                          str(args.content_weight) + "_" + \
                          str(args.style_weight) + "_" + \
                          str(args.lr) + "_" + \
                          str(args.meta_lr) + "_" + \
                          str(args.meta_batch_size) + "_" + \
                          str(args.meta_step) + "_" + \
                          time.ctime() + ".pth"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("Done, trained model saved at {}".format(save_model_path))
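The meta-learning loop in Example #17 defers the perceptual objective to a loss_fn helper that is not included above. Given how it is called, a hedged sketch with the usual relu2_2 content term and Gram-matrix style terms (the Gram normalization is an assumption, not taken from the original code) looks like:

import torch.nn.functional as F


def gram_matrix(y):
    # Gram matrix of a feature map, normalized by the number of elements per channel.
    b, ch, h, w = y.size()
    feats = y.view(b, ch, h * w)
    return feats.bmm(feats.transpose(1, 2)) / (ch * h * w)


def loss_fn(features_transformed, features_content, gram_style,
            content_weight, style_weight):
    # Hedged sketch of the helper used above; the signature mirrors the call sites.
    c_loss = content_weight * F.mse_loss(features_transformed.relu2_2,
                                         features_content.relu2_2)
    s_loss = 0.0
    for ft_y, gm_s in zip(features_transformed, gram_style):
        s_loss = s_loss + F.mse_loss(gram_matrix(ft_y), gm_s)
    s_loss = style_weight * s_loss
    return c_loss + s_loss, c_loss, s_loss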
Example #18
0
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        # utils.RGB2LAB(),
        transforms.ToTensor(),
        # utils.LAB2Tensor(),
    ])
    pert_transform = transforms.Compose([utils.ColorPerturb()])
    trainset = utils.FlatImageFolder(args.dataset, transform, pert_transform)
    trainloader = DataLoader(trainset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=4)
    model = TransformerNet()
    if args.gpus is not None:
        model = nn.DataParallel(model, device_ids=args.gpus)
    else:
        model = nn.DataParallel(model)
    if args.resume:
        state_dict = torch.load(args.resume)
        model.load_state_dict(state_dict)

    if args.cuda:
        model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    criterion = nn.MSELoss()

    start_time = datetime.now()

    for e in range(args.epochs):
        model.train()
        count = 0
        acc_loss = 0.0
        for batchi, (pert_img, ori_img) in enumerate(trainloader):
            count += len(pert_img)
            if args.cuda:
                pert_img = pert_img.cuda(non_blocking=True)
                ori_img = ori_img.cuda(non_blocking=True)

            optimizer.zero_grad()

            rec_img = model(pert_img)
            loss = criterion(rec_img, ori_img)
            loss.backward()
            optimizer.step()

            acc_loss += loss.item()
            if (batchi + 1) % args.log_interval == 0:
                mesg = '{}\tEpoch {}: [{}/{}]\ttotal loss: {:.6f}'.format(
                    time.ctime(), e + 1, count, len(trainset),
                    acc_loss / (args.log_interval))
                print(mesg)
                acc_loss = 0.0

        if args.checkpoint_dir and e + 1 != args.epochs:
            model.eval().cpu()
            ckpt_filename = 'ckpt_epoch_' + str(e + 1) + '.pth'
            ckpt_path = osp.join(args.checkpoint_dir, ckpt_filename)
            torch.save(model.state_dict(), ckpt_path)
            model.cuda().train()
            print('Checkpoint model at epoch %d saved' % (e + 1))

    model.eval().cpu()
    if args.save_model_name:
        model_filename = args.save_model_name
    else:
        model_filename = "epoch_" + str(args.epochs) + "_" + str(
            time.ctime()).replace(' ', '_') + ".model"
    model_path = osp.join(args.save_model_dir, model_filename)
    torch.save(model.state_dict(), model_path)

    end_time = datetime.now()

    print('Finished training after %s, trained model saved at %s' %
          (end_time - start_time, model_path))
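Example #18 reuses the TransformerNet as a color-restoration model: utils.FlatImageFolder yields (perturbed, original) pairs and the objective is a plain MSE between the reconstruction and the original image. Neither FlatImageFolder nor ColorPerturb is shown; a minimal, purely illustrative stand-in (the names and the perturbation itself are assumptions) could look like:

import os

import torch
from PIL import Image
from torch.utils.data import Dataset


class PairedColorFolder(Dataset):
    """Hypothetical stand-in for utils.FlatImageFolder: yields (perturbed, original)."""

    def __init__(self, root, transform, pert_transform):
        self.paths = [os.path.join(root, f) for f in sorted(os.listdir(root))]
        self.transform = transform            # resize / crop / ToTensor
        self.pert_transform = pert_transform  # e.g. a random color perturbation

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        ori = self.transform(img)
        pert = self.pert_transform(ori)       # the model sees the perturbed copy
        return pert, ori


def random_channel_scale(t, low=0.6, high=1.4):
    # Toy perturbation usable as pert_transform: rescale each color channel at random.
    scale = torch.empty(3, 1, 1).uniform_(low, high)
    return (t * scale).clamp(0.0, 1.0)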
Example #19
0
def run_train(args):

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print('running training process...')
    if args.semantic == 1:
        print(
            'multilabels semantic feedforward neural style transfer training...'
        )
    elif args.semantic == 0:
        print('normal feedforward neural style transfer training...')

    if args.semantic == 1:
        loss_net, content_losses, style_losses, content_masks, n_channels = train_preparation_mask(
            args)
    elif args.semantic == 0:
        loss_net, content_losses, style_losses, n_channels = train_preparation(
            args)

    if args.backend == 'cudnn':
        torch.backends.cudnn.enabled = True

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transform_net = TransformerNet(n_channels).to(device)
    mse_loss = nn.MSELoss()

    optimizer = optim.Adam(transform_net.parameters(), lr=args.learning_rate)

    iteration = [0]
    while iteration[0] <= args.epochs - 1:
        transform_net.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            stloss = 0.
            ctloss = 0.
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            #stack color_content_masks into x as input
            x, x_ori = x.to(device), x.to(device).clone()
            x = preprocess(x)
            x_ori = preprocess(x_ori)
            if args.semantic == 1:
                x = torch.cat((x, content_masks), 1)

            y = transform_net(x)

            #compute pixel loss
            if args.semantic == 1:
                y_pix = torch.cat((y, content_masks), 1)
            elif args.semantic == 0:
                y_pix = y
            pixloss = 0.
            if args.pixel_weight > 0:
                pixloss = mse_loss(x, y_pix) * args.pixel_weight

            #compute content loss and style loss

            for ctl in content_losses:
                ctl.mode = 'capture'
            loss_net(x_ori)
            for ctl in content_losses:
                ctl.mode = 'loss'
            for stl in style_losses:
                stl.mode = 'loss'
            loss_net(y)
            for ctl in content_losses:
                ctloss += mse_loss(ctl.input, ctl.target) * args.content_weight
            if args.semantic == 1:
                for stl in style_losses:
                    for u in range(len(stl.color_codes)):
                        input_msk = stl.input_masks[u].expand_as(stl.input)
                        input_masked = torch.mul(stl.input, input_msk)
                        input_msk_mean = torch.mean(stl.input_masks[u])
                        input_local_G = gram_matrix(input_masked)
                        if input_msk_mean > 0:
                            input_local_G = input_local_G.div(
                                stl.input.nelement() * input_msk_mean)
                        loss_local = mse_loss(input_local_G, stl.target[u])
                        loss_local *= input_msk_mean
                        # larger target areas get the smaller style weight
                        if input_msk_mean > 0.2:
                            stloss += loss_local * args.style_weights[0]
                        # smaller target areas get the larger style weight
                        elif input_msk_mean <= 0.2:
                            stloss += loss_local * args.style_weights[1]
            elif args.semantic == 0:
                for stl in style_losses:
                    gram = gram_matrix(stl.input)
                    stloss += mse_loss(gram,
                                       stl.target) * args.style_weights[0]

            loss = ctloss + stloss + pixloss

            loss.backward()
            optimizer.step()

            agg_content_loss += ctloss.item()
            agg_style_loss += stloss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}, Epoch {}:\t[{}/{}], content: {:.6f}, style: {:.6f}, total: {:.6f}".format(
                    time.ctime(), iteration[0], count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)
            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transform_net.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    iteration[0] +
                    1) + "_batch_id_" + str(batch_id + 1) + "_semantic_" + str(
                        args.semantic) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transform_net.state_dict(), ckpt_model_path)
                transform_net.to(device).train()

        iteration[0] += 1

    #save final model
    transform_net.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_content_" + str(
            args.content_weight) + "_style_" + str(
                args.style_weights[0]) + "_semantic_" + str(
                    args.semantic) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transform_net.state_dict(), save_model_path)

    print("\nTraining is done, trained model saved at", save_model_path)
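Example #19 drives its loss network through 'capture' and 'loss' modes: the content and style modules first record target activations from the original image, then store the stylized image's activations so the training loop can compare the two. Those module classes are not shown; a hedged sketch of the content-side hook (the style hook would additionally keep per-mask Gram targets) is:

import torch.nn as nn


class ContentLossHook(nn.Module):
    # Hypothetical sketch of the capture/loss pattern set up by train_preparation above.
    def __init__(self):
        super(ContentLossHook, self).__init__()
        self.mode = 'none'
        self.target = None
        self.input = None

    def forward(self, x):
        if self.mode == 'capture':
            self.target = x.detach()   # remember the content image's activation
        elif self.mode == 'loss':
            self.input = x             # remember the stylized image's activation
        return x                       # pass activations through unchanged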
Example #20
0
def train(start_epoch=0):
    np.random.seed(enums.seed)
    torch.manual_seed(enums.seed)

    if enums.cuda:
        torch.cuda.manual_seed(enums.seed)

    transform = transforms.Compose([
        transforms.Resize(enums.image_size),
        transforms.CenterCrop(enums.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(enums.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=enums.batch_size)

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), enums.lr)
    if enums.subcommand == 'resume':
        ckpt_state = torch.load(enums.checkpoint_model)
        transformer.load_state_dict(ckpt_state['state_dict'])
        start_epoch = ckpt_state['epoch']
        optimizer.load_state_dict(ckpt_state['optimizer'])

    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(enums.style_image, size=enums.style_size)
    style = style_transform(style)
    style = style.expand(enums.batch_size, *style.size())  # N,C,H,W

    if enums.cuda:
        transformer.cuda()
        vgg.cuda()
        style = style.cuda()

    style_v = Variable(style)
    style_v = utils.normalize_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(start_epoch, enums.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(x)
            if enums.cuda:
                x = x.cuda()

            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = enums.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= enums.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % enums.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

        if enums.checkpoint_model_dir is not None and (
                e + 1) % enums.checkpoint_interval == 0:
            # transformer.eval()
            if enums.cuda:
                transformer.cpu()
            ckpt_model_filename = "ckpt_epoch_" + str(e + 1) + ".pth"
            ckpt_model_path = os.path.join(enums.checkpoint_model_dir,
                                           ckpt_model_filename)
            save_checkpoint(
                {
                    'epoch': e + 1,
                    'state_dict': transformer.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, ckpt_model_path)
            if enums.cuda:
                transformer.cuda()
            # transformer.train()

    # save model
    # transformer.eval()
    if enums.cuda:
        transformer.cpu()
    save_model_filename = "epoch_" + str(enums.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            enums.content_weight) + "_" + str(enums.style_weight) + ".model"
    save_model_path = os.path.join(enums.save_model_dir, save_model_filename)
    save_checkpoint(
        {
            'epoch': e + 1,
            'state_dict': transformer.state_dict(),
            'optimizer': optimizer.state_dict()
        }, save_model_path)

    print("\nDone, trained model saved at", save_model_path)
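Example #20 checkpoints optimizer state alongside the weights so its 'resume' subcommand can pick training back up. The save_checkpoint helper is not shown; given what the resume branch loads back with torch.load, it plausibly just serializes the dict:

import torch


def save_checkpoint(state, path):
    # Assumed sketch: `state` holds 'epoch', 'state_dict' and 'optimizer'.
    torch.save(state, path)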
Example #21
0
def train(**kwargs):
    opt = Config()
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    vis = utils.Visualizer(opt.env)

    # Data loading
    transform = tv.transforms.Compose([
        tv.transforms.Scale(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x * 255)
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, transform)
    dataloader = data.DataLoader(dataset, opt.batch_size)

    # Transformation network
    transformer = TransformerNet()
    if opt.model_path:
        transformer.load_state_dict(t.load(opt.model_path, map_location=lambda _s, _: _s))

    # Loss network: Vgg16 (eval mode)
    vgg = Vgg16().eval()

    # Optimizer
    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    # Load the style image data
    style = utils.get_style_data(opt.style_path)
    vis.img('style', (style[0] * 0.225 + 0.45).clamp(min=0, max=1))

    if opt.use_gpu:
        transformer.cuda()
        style = style.cuda()
        vgg.cuda()

    # Gram matrices of the style image
    style_v = Variable(style, volatile=True)
    features_style = vgg(style_v)
    gram_style = [Variable(utils.gram_matrix(y.data)) for y in features_style]

    # Loss meters
    style_meter = tnt.meter.AverageValueMeter()
    content_meter = tnt.meter.AverageValueMeter()

    for epoch in range(opt.epoches):
        content_meter.reset()
        style_meter.reset()

        for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):

            # Training step
            optimizer.zero_grad()
            if opt.use_gpu:
                x = x.cuda()
            x = Variable(x)
            y = transformer(x)
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)
            features_y = vgg(y)
            features_x = vgg(x)

            # content loss
            content_loss = opt.content_weight * F.mse_loss(features_y.relu2_2, features_x.relu2_2)

            # style loss
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gram_y = utils.gram_matrix(ft_y)
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            # Smooth the loss statistics
            content_meter.add(content_loss.data[0])
            style_meter.add(style_loss.data[0])

            if (ii + 1) % opt.plot_every == 0:
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                # Visualization
                vis.plot('content_loss', content_meter.value()[0])
                vis.plot('style_loss', style_meter.value()[0])
                # x and y were standardized by utils.normalize_batch, so un-normalize them for display
                vis.img('output', (y.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))
                vis.img('input', (x.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))

        # Save the visdom state and the model checkpoint
        vis.save([opt.env])
        t.save(transformer.state_dict(), 'checkpoints/%s_style.pth' % epoch)
Example #22
0
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    transform = transforms.Compose([transforms.Scale(args.image_size),
                                    transforms.CenterCrop(args.image_size),
                                    transforms.ToTensor(),
                                    transforms.Lambda(lambda x: x.mul(255))])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, **kwargs)

    transformer = TransformerNet()
    if (args.premodel != ""):
        transformer.load_state_dict(torch.load(args.premodel))
        print("load pretrain model:"+args.premodel)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    style = utils.tensor_load_rgbimage(args.style_image, size=args.style_size)
    style = style.repeat(args.batch_size, 1, 1, 1)
    style = utils.preprocess_batch(style)
    if args.cuda:
        style = style.cuda()
    style_v = Variable(style, volatile=True)
    style_v = utils.subtract_imagenet_mean_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]


    hori=0 
    writer = SummaryWriter(args.logdir,comment=args.logdir)
    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        agg_cate_loss = 0.
        agg_cam_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(utils.preprocess_batch(x))
            if args.cuda:
                x = x.cuda()
            y = transformer(x)  
            xc = Variable(x.data.clone(), volatile=True)
            #print(y.size()) #(4L, 3L, 224L, 224L)

            
            # Calculate focus loss and category loss
            y_cam = utils.depreprocess_batch(y)
            y_cam = utils.subtract_mean_std_batch(y_cam) 
            
            xc_cam = utils.depreprocess_batch(xc)
            xc_cam = utils.subtract_mean_std_batch(xc_cam)
            

            del features_blobs[:]
            logit_x = net(xc_cam)
            logit_y = net(y_cam)
            
            label=[]
            cam_loss = 0
            for i in range(len(xc_cam)):
                h_x = F.softmax(logit_x[i], dim=0)
                probs_x, idx_x = h_x.data.sort(0, True)
                label.append(idx_x[0])

                h_y = F.softmax(logit_y[i], dim=0)
                probs_y, idx_y = h_y.data.sort(0, True)
                
                x_cam = returnCAM(features_blobs[0][i], weight_softmax, idx_x[0])
                x_cam = Variable(x_cam.data,requires_grad = False)
 
                y_cam = returnCAM(features_blobs[1][i], weight_softmax, idx_y[0])
                
                cam_loss += mse_loss(y_cam, x_cam)
            
            # the focus loss
            cam_loss *= 80
            # the category loss
            label = Variable(torch.LongTensor(label), requires_grad=False).cuda()
            cate_loss = 10000 * torch.nn.CrossEntropyLoss()(logit_y, label)
         
         

           
            y = utils.subtract_imagenet_mean_batch(y)
            xc = utils.subtract_imagenet_mean_batch(xc)

            features_y = vgg(y)
            features_xc = vgg(xc)

            #f_xc_c = Variable(features_xc[1].data, requires_grad=False)
            #content_loss = args.content_weight * mse_loss(features_y[1], f_xc_c)


            f_xc_c = Variable(features_xc[2].data, requires_grad=False)
            content_loss = args.content_weight * mse_loss(features_y[2], f_xc_c)
            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = Variable(gram_style[m].data, requires_grad=False)
                gram_y = utils.gram_matrix(features_y[m])
                style_loss += args.style_weight * mse_loss(gram_y, gram_s[:n_batch, :, :])
            # add up the four losses and backpropagate
            total_loss = style_loss + content_loss + cam_loss + cate_loss
            total_loss.backward()
            optimizer.step()

            # accumulate losses for logging
            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]
            agg_cate_loss += cate_loss.data[0]
            agg_cam_loss += cam_loss.data[0]
            
            writer.add_scalar("Loss_Cont", agg_content_loss / (batch_id + 1), hori)
            writer.add_scalar("Loss_Style", agg_style_loss / (batch_id + 1), hori)
            writer.add_scalar("Loss_CAM", agg_cam_loss / (batch_id + 1), hori)
            writer.add_scalar("Loss_Cate", agg_cate_loss / (batch_id + 1), hori)
            hori += 1
            
            if (batch_id + 1) % args.log_interval == 0:
               mesg = "{}Epoch{}:[{}/{}] content:{:.2f} style:{:.2f} cate:{:.2f} cam:{:.2f}  total:{:.2f}".format(
                    time.strftime("%a %H:%M:%S"),e + 1, count, len(train_dataset),
                                 agg_content_loss / (batch_id + 1),
                                 agg_style_loss / (batch_id + 1),
                                 agg_cate_loss / (batch_id + 1),
                                 agg_cam_loss / (batch_id + 1),
                                 (agg_content_loss + agg_style_loss + agg_cate_loss + agg_cam_loss ) / (batch_id + 1)
               )
               print(mesg)
               
            if (batch_id + 1) % 2500 == 0:    
                transformer.eval()
                transformer.cpu()
                save_model_filename = "epoch_" + str(e+1) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
                    args.content_weight) + "_" + str(args.style_weight) + ".model"
                save_model_path = os.path.join(args.save_model_dir, save_model_filename)
                torch.save(transformer.state_dict(), save_model_path)
                transformer.cuda()
                transformer.train()
                print("saved at ",count)
    
    
    
    
    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)
    
    writer.close()
    print("\nDone, trained model saved at", save_model_path)
Example #23
0
def train():
    train_gpu_id = DC.train_gpu_id
    device = t.device('cuda', train_gpu_id) if DC.use_gpu else t.device('cpu')

    transforms = T.Compose([
      T.Resize(DC.input_size),
      T.CenterCrop(DC.input_size),
      T.ToTensor(),
      T.Lambda(lambda x: x*255)
    ])

    train_dir = DC.train_content_dir
    batch_size = DC.train_batch_size

    train_data = ImageFolder(train_dir, transform=transforms)

    num_train_data = len(train_data)

    train_dataloader = t.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=DC.num_workers,
                                               drop_last=True)
    # transform net
    transformer = TransformerNet()
    if DC.load_model:
        transformer.load_state_dict(
          t.load(DC.load_model, 
                 map_location=lambda storage, loc: storage))

    transformer.to(device)

    # Loss net (vgg16)
    vgg = Vgg16().eval()
    vgg.to(device)

    for param in vgg.parameters():
        param.requires_grad = False

    optimizer = t.optim.Adam(transformer.parameters(), DC.base_lr)

    # Get the data from style image
    ys = utils.get_style_data(DC.style_img)
    ys = ys.to(device)

    # The Gram matrix of the style image
    with t.no_grad():
        features_ys = vgg(ys)

        gram_ys = [utils.gram_matrix(y) for y in features_ys]

    # Start training
    train_imgs = 0
    iteration = 0
    for epoch in range(DC.max_epoch):
        for i, (data, label) in tqdm.tqdm(enumerate(train_dataloader)):
            train_imgs += batch_size
            iteration += 1

            optimizer.zero_grad()
         
            # Transformer net
            x = data.to(device)
            y = transformer(x)

            x = utils.normalize_batch(x)
            yc = x
            y = utils.normalize_batch(y)

            features_y = vgg(y)
            features_yc = vgg(yc)

            # Content loss
            content_loss = DC.content_weight * \
                             nn.functional.mse_loss(features_y.relu2_2, 
                                                    features_yc.relu2_2)
#            content_loss = DC.content_weight * \
#                             nn.functional.mse_loss(features_y.relu3_3, 
#                                                    features_yc.relu3_3)

            # Style loss
            style_loss = 0.0
            for ft_y, gm_ys in zip(features_y, gram_ys):
                gm_y = utils.gram_matrix(ft_y)
                
                style_loss += nn.functional.mse_loss(gm_y, 
                                                     gm_ys.expand_as(gm_y))


            style_loss *= DC.style_weight

            # Total loss
            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            if iteration%DC.show_iter == 0: 
                print('\ncontent loss: ', content_loss.data)
                print('style loss: ', style_loss.data)
                print('total loss: ', total_loss.data)
                print()

        t.save(transformer.state_dict(), '{}_style.pth'.format(epoch))
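Example #23 (like the partial snippet that opens this section) pulls every hyperparameter from a DC config object whose definition is not included. The attributes it must expose can be read off the code; a hypothetical placeholder with made-up values would be:

class DC(object):
    # Hypothetical config mirroring the attributes referenced above; values are placeholders.
    use_gpu = True
    train_gpu_id = 0
    input_size = 256
    train_content_dir = 'data/content'   # ImageFolder-style directory
    train_batch_size = 8
    num_workers = 4
    load_model = None                    # path to a .pth checkpoint, or None
    base_lr = 1e-3
    style_img = 'style_images/udnie.jpg'
    content_weight = 1e5
    style_weight = 1e10
    max_epoch = 2
    show_iter = 100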
Example #24
0
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 12, 'pin_memory': False}
    else:
        kwargs = {}
    from transform.color_op import Linearize, SRGB2XYZ, XYZ2CIE

    RGB2YUV = transforms.Compose([
        Linearize(),
        SRGB2XYZ(),
        XYZ2CIE()
    ])

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        RGB2YUV,  # already a Compose instance, so pass it directly instead of calling it
        transforms.ToTensor(),
        # transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, **kwargs)

    transformer = TransformerNet(in_channels=2, out_channels=1)  # input: LS, predict: M
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    transformer = nn.DataParallel(transformer)

    if args.cuda:
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is requested, but related driver/device is not set properly.")
        transformer.cuda()

    for e in range(args.epochs):
        transformer.train()
        # agg_content_loss = 0.
        # agg_style_loss = 0.
        count = 0
        for batch_id, (imgs, _) in enumerate(train_loader):
            n_batch = len(imgs)
            count += n_batch
            optimizer.zero_grad()
            # Input: the first and the last channel
            x = torch.cat([imgs[:, :1, :, :].clone(), imgs[:, -1:, :, :].clone()], dim=1)
            # Ground truth: the middle channel
            gt = imgs[:, 1:2, :, :].clone()

            if args.cuda:
                x = x.cuda()
                gt = gt.cuda()

            y = transformer(x)

            total_loss = mse_loss(y, gt)
            total_loss.backward()
            optimizer.step()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    total_loss.item())
                print(mesg)

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    os.makedirs(args.save_model_dir, exist_ok=True)
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #25
0
def train(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        kwargs = {'num_workers': 0, 'pin_memory': False}
    else:
        kwargs = {}

    transform = transforms.Compose([
        transforms.Scale(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              **kwargs)

    transformer = TransformerNet()
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16()
    utils.init_vgg16(args.vgg_model_dir)
    vgg.load_state_dict(
        torch.load(os.path.join(args.vgg_model_dir, "vgg16.weight")))

    if args.cuda:
        transformer.cuda()
        vgg.cuda()

    style = utils.tensor_load_rgbimage(args.style_image, size=args.style_size)
    style = style.repeat(args.batch_size, 1, 1, 1)
    style = utils.preprocess_batch(style)
    if args.cuda:
        style = style.cuda()
    style_v = Variable(style, volatile=True)
    utils.subtract_imagenet_mean_batch(style_v)
    features_style = vgg(style_v)
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            x = Variable(utils.preprocess_batch(x))
            if args.cuda:
                x = x.cuda()

            y = transformer(x)

            xc = Variable(x.data.clone(), volatile=True)

            utils.subtract_imagenet_mean_batch(y)
            utils.subtract_imagenet_mean_batch(xc)

            features_y = vgg(y)
            features_xc = vgg(xc)

            f_xc_c = Variable(features_xc[1].data, requires_grad=False)

            content_loss = args.content_weight * mse_loss(
                features_y[1], f_xc_c)

            style_loss = 0.
            for m in range(len(features_y)):
                gram_s = Variable(gram_style[m].data, requires_grad=False)
                gram_y = utils.gram_matrix(features_y[m])
                style_loss += args.style_weight * mse_loss(
                    gram_y, gram_s[:n_batch, :, :])

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.data[0]
            agg_style_loss += style_loss.data[0]

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

    # save model
    transformer.eval()
    transformer.cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
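Examples #20, #22 and #25 are written against the pre-0.4 PyTorch API (Variable, volatile=True, loss.data[0]). On current PyTorch the same intent is expressed with torch.no_grad() and Tensor.item(); a small self-contained illustration (the conv layer is only a stand-in for the frozen loss network):

import torch
import torch.nn as nn

loss_net = nn.Conv2d(3, 8, 3, padding=1)      # stand-in for the frozen VGG loss network
style = torch.rand(4, 3, 256, 256)

with torch.no_grad():                         # replaces Variable(style, volatile=True)
    features_style = loss_net(style)

loss = nn.functional.mse_loss(features_style, torch.zeros_like(features_style))
print(loss.item())                            # replaces loss.data[0]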
Example #26
0
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(
            args.image_size),  # the shorter side is resized to match image_size
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),  # to tensor [0,1]
        transforms.Lambda(lambda x: x.mul(255))  # convert back to [0, 255]
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)  # to provide a batch loader

    style_image = [f for f in os.listdir(args.style_image)]
    style_num = len(style_image)
    print(style_num)

    transformer = TransformerNet(style_num=style_num).to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.Resize(args.style_size),
        transforms.CenterCrop(args.style_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    style_batch = []

    for i in range(style_num):
        style = utils.load_image(os.path.join(args.style_image, style_image[i]),
                                 size=args.style_size)
        style = style_transform(style)
        style_batch.append(style)

    style = torch.stack(style_batch).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)

            if n_batch < args.batch_size:
                break  # skip to the next epoch when there are not enough images left for a full batch

            count += n_batch
            optimizer.zero_grad()  # initialize with zero gradients

            batch_style_id = [
                i % style_num for i in range(count - n_batch, count)
            ]
            y = transformer(x.to(device), style_id=batch_style_id)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y.to(device))
            features_x = vgg(x.to(device))
            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[batch_style_id, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(
        args.epochs) + "_" + str(time.ctime()).replace(' ', '_').replace(
            ':', '') + "_" + str(int(args.content_weight)) + "_" + str(
                int(args.style_weight)) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #27
0
def train(**kwargs):

    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    vis = utils.Visualizer(opt.env)

    # Data loading
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    loader = get_loader(batch_size=1,
                        data_path=opt.data_path,
                        img_shape=opt.img_shape,
                        transform=transform)

    # Transformation network
    transformer = TransformerNet().cuda()
    # transformer.load_state_dict(t.load(opt.model_path, ))

    #if opt.model_path:
    #    transformer.load_state_dict(t.load(opt.model_path,map_location=lambda _s, _: _s))

    # Loss networks: Vgg19 and a depth estimation network
    vgg = Vgg19().eval()
    depthnet = HourGlass().eval()
    depthnet.load_state_dict(t.load(opt.depth_path))
    # print(vgg)
    # BASNET
    net = BASNet(3, 1).cuda()
    net.load_state_dict(torch.load('./basnet.pth'))
    net.eval()

    # Optimizer
    optimizer = t.optim.Adam(transformer.parameters(), lr=opt.lr)

    # Load the style image data

    img = Image.open(opt.style_path)
    img = img.resize(opt.img_shape)
    img = transform(img).float()
    style = Variable(img, requires_grad=True).unsqueeze(0)
    vis.img('style', (style[0] * 0.225 + 0.45).clamp(min=0, max=1))

    if opt.use_gpu:
        transformer.cuda()
        style = style.cuda()
        vgg.cuda()
        depthnet.cuda()

    # Gram matrices of the style image
    style_v = Variable(style, volatile=True)
    features_style = vgg(style_v)
    gram_style = [Variable(utils.gram_matrix(y.data)) for y in features_style]

    # Loss meters
    style_meter = tnt.meter.AverageValueMeter()
    content_meter = tnt.meter.AverageValueMeter()
    temporal_meter = tnt.meter.AverageValueMeter()
    long_temporal_meter = tnt.meter.AverageValueMeter()
    depth_meter = tnt.meter.AverageValueMeter()
    # tv_meter = tnt.meter.AverageValueMeter()
    kk = 0
    for count in range(opt.epoch):
        print('Training Start!!')
        content_meter.reset()
        style_meter.reset()
        temporal_meter.reset()
        long_temporal_meter.reset()
        depth_meter.reset()
        # tv_meter.reset()
        for step, frames in enumerate(loader):
            for i in tqdm.tqdm(range(1, len(frames))):
                kk += 1
                if (kk + 1) % 3000 == 0:
                    print('Decaying the learning rate')
                    for param in optimizer.param_groups:
                        param['lr'] = max(param['lr'] / 1.2, 1e-4)

                optimizer.zero_grad()
                x_t = frames[i].cuda()

                x_t1 = frames[i - 1].cuda()

                h_xt = transformer(x_t)

                h_xt1 = transformer(x_t1)
                depth_x_t = depthnet(x_t)
                depth_x_t1 = depthnet(x_t1)
                depth_h_xt = depthnet(h_xt)
                depth_h_xt1 = depthnet(h_xt1)

                img1 = h_xt1.data.cpu().squeeze(0).numpy().transpose(1, 2, 0)
                img2 = h_xt.data.cpu().squeeze(0).numpy().transpose(1, 2, 0)

                flow, mask = opticalflow(img1, img2)

                d1, d2, d3, d4, d5, d6, d7, d8 = net(x_t)
                a1pha1 = PROCESS(d1, x_t)
                del d1, d2, d3, d4, d5, d6, d7, d8

                d1, d2, d3, d4, d5, d6, d7, d8 = net(x_t1)
                a1pha2 = PROCESS(d1, x_t1)
                del d1, d2, d3, d4, d5, d6, d7, d8

                h_xt_features = vgg(h_xt)
                h_xt1_features = vgg(h_xt1)
                x_xt_features = vgg(a1pha1)
                x_xt1_features = vgg(a1pha2)

                # ContentLoss, conv3_2
                content_t = F.mse_loss(x_xt_features[2], h_xt_features[2])
                content_t1 = F.mse_loss(x_xt1_features[2], h_xt1_features[2])
                content_loss = opt.content_weight * (content_t1 + content_t)
                # StyleLoss
                style_t = 0
                style_t1 = 0
                for ft_y, gm_s in zip(h_xt_features, gram_style):
                    gram_y = gram_matrix(ft_y)
                    style_t += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
                for ft_y, gm_s in zip(h_xt1_features, gram_style):
                    gram_y = gram_matrix(ft_y)
                    style_t1 += F.mse_loss(gram_y, gm_s.expand_as(gram_y))

                style_loss = opt.style_weight * (style_t1 + style_t)

                # # depth loss
                depth_loss1 = F.mse_loss(depth_h_xt, depth_x_t)
                depth_loss2 = F.mse_loss(depth_h_xt1, depth_x_t1)
                depth_loss = opt.depth_weight * (depth_loss1 + depth_loss2)
                # # TVLoss
                # print(type(s_hxt[layer]),s_hxt[layer].size())
                # tv_loss = TVLoss(h_xt)

                # Long-term temporal loss
                if (i - 1) % opt.sample_frames == 0:
                    frames0 = h_xt1.cpu()
                    long_img1 = frames0.data.cpu().squeeze(
                        0).numpy().transpose(1, 2, 0)
                # long_img2 = h_xt.data.cpu().squeeze(0).numpy().transpose(1,2,0)
                long_flow, long_mask = opticalflow(long_img1, img2)

                # Optical flow

                flow = torch.from_numpy(flow).permute(2, 0, 1).unsqueeze(0).to(
                    torch.float32)
                long_flow = torch.from_numpy(long_flow).permute(
                    2, 0, 1).unsqueeze(0).to(torch.float32)

                # print(flow.size())
                # print(h_xt1.size())
                warped = warp(h_xt1.cpu().permute(0, 2, 3, 1), flow,
                              opt.img_shape[1], opt.img_shape[0]).cuda()
                long_warped = warp(frames0.cpu().permute(0, 2, 3,
                                                         1), long_flow,
                                   opt.img_shape[1], opt.img_shape[0]).cuda()
                long_temporal_loss = F.mse_loss(
                    h_xt, long_mask * long_warped.permute(0, 3, 1, 2))
                # print(warped.size())
                # tv.utils.save_image((warped.permute(0,3,1,2).data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1),
                #                     './warped.jpg')
                mask = mask.transpose(2, 0, 1)
                mask = torch.from_numpy(mask).cuda().to(torch.float32)
                # print(mask.shape)
                temporal_loss = F.mse_loss(h_xt,
                                           mask * warped.permute(0, 3, 1, 2))

                temporal_loss = opt.temporal_weight * temporal_loss
                long_temporal_loss = opt.long_temporal_weight * long_temporal_loss

                # Spatial Loss
                spatial_loss = content_loss + style_loss

                Loss = spatial_loss + depth_loss + temporal_loss + long_temporal_loss

                Loss.backward(retain_graph=True)
                optimizer.step()
                content_meter.add(float(content_loss.data))
                style_meter.add(float(style_loss.data))
                temporal_meter.add(float(temporal_loss.data))
                long_temporal_meter.add(float(long_temporal_loss.data))
                depth_meter.add(float(depth_loss.data))
                # tv_meter.add(float(tv_loss.data))

                vis.plot('temporal_loss', temporal_meter.value()[0])
                vis.plot('long_temporal_loss', long_temporal_meter.value()[0])
                vis.plot('content_loss', content_meter.value()[0])
                vis.plot('style_loss', style_meter.value()[0])
                vis.plot('depth_loss', depth_meter.value()[0])
                # vis.plot('tv_loss', tv_meter.value()[0])

                if i % 10 == 0:
                    vis.img('input(t)',
                            (x_t.data.cpu()[0] * 0.225 + 0.45).clamp(min=0,
                                                                     max=1))
                    vis.img('output(t)',
                            (h_xt.data.cpu()[0] * 0.225 + 0.45).clamp(min=0,
                                                                      max=1))
                    vis.img('output(t-1)',
                            (h_xt1.data.cpu()[0] * 0.225 + 0.45).clamp(min=0,
                                                                       max=1))
                    print(
                        'epoch{},content loss:{},style loss:{},temporal loss:{},long temporal loss:{},depth loss:{},total loss:{}'
                        .format(count, content_loss, style_loss, temporal_loss,
                                long_temporal_loss, depth_loss, Loss))
                    # print('epoch{},content loss:{},style loss:{},depth loss:{},total loss{}'
                    #       .format(count,content_loss, style_loss,depth_loss,Loss))

            vis.save([opt.env])
            torch.save(transformer.state_dict(), opt.model_path)
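Example #27 suppresses temporal flicker by warping the previous stylized frame along the optical flow before comparing it with the current frame. The warp and opticalflow helpers are external to this snippet; one plausible backward-warping implementation built on grid_sample (the real helper may use a different tensor layout) is:

import torch
import torch.nn.functional as F


def warp_nchw(img, flow):
    """Backward-warp img (N, C, H, W) along a dense pixel-space flow field (N, 2, H, W)."""
    n, _, h, w = img.shape
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    base = torch.stack((xs, ys), dim=0).float().unsqueeze(0)   # (1, 2, H, W) pixel grid
    coords = base + flow                                       # positions to sample from
    # grid_sample expects an (N, H, W, 2) grid normalized to [-1, 1]
    gx = 2.0 * coords[:, 0] / max(w - 1, 1) - 1.0
    gy = 2.0 * coords[:, 1] / max(h - 1, 1) - 1.0
    grid = torch.stack((gx, gy), dim=-1)
    return F.grid_sample(img, grid, align_corners=True)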
Example #28
0
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    target_transform = transforms.ToTensor()

    train_dataset = VFDataset(args.dataset, transform, target_transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    if args.load_model is not None:
        transformer.load_state_dict(torch.load(args.load_model))
    optimizer = Adam(transformer.parameters(), args.lr)
    # mse_loss = torch.nn.MSELoss()
    cosine_loss = torch.nn.CosineEmbeddingLoss()
    label = torch.ones(args.batch_size, 1, args.image_size,
                       args.image_size).to(device)

    # log_file = open(args.log_file, "w")

    for e in range(args.epochs):
        transformer.train()
        agg_loss = 0.
        count = 0
        for batch_id, (x, vf) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = utils.subtract_imagenet_mean_batch(x)
            x = x.to(device)
            y = transformer(x)
            vf = vf.to(device)

            # loss = mse_loss(y, vf)
            loss = cosine_loss(y, vf, label)
            loss.backward()
            optimizer.step()

            agg_loss += loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_loss / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #29
0
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]
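    # The style image's Gram matrices are computed once here and reused as the
    # fixed style targets inside the training loop.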

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)
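            # features_y / features_x hold the VGG activations reused by both the
            # content term (relu2_2) and the Gram-matrix style term below.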

            content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (batch_id + 1) % args.checkpoint_interval == 0:
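                # Checkpoint on the CPU in eval mode, then return the model to the
                # training device and train mode.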
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir, ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
        args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #30
def train():
    device = torch.device("cuda")

    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ])

    train_dataset = datasets.ImageFolder(dataset_path, transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), lr)
    mse_loss = torch.nn.MSELoss()

    if resume_TransformerNet_from_file:
        if os.path.isfile(TransformerNet_path):
            print("=> loading checkpoint '{}'".format(TransformerNet_path))
            TransformerNet_par = torch.load(TransformerNet_path)
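            # Drop any InstanceNorm running_mean/running_var entries from the
            # checkpoint before loading the state dict.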
            for k in list(TransformerNet_par.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del TransformerNet_par[k]
            transformer.load_state_dict(TransformerNet_par)
            print("=> loaded checkpoint '{}'".format(TransformerNet_path))
        else:
            print("=> no checkpoint found at '{}'".format(TransformerNet_path))

    vgg = Vgg16(requires_grad=False).to(device)
    style = Image.open(style_image_path)
    style = transform(style)
    style = style.repeat(batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    model_fcrn = FCRN_for_transfer(batch_size=batch_size,
                                   requires_grad=False).to(device)
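    # FCRN serves here as a frozen depth estimator; its checkpoint is loaded below
    # and its predictions feed the depth loss in the training loop.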
    model_fcrn_par = torch.load(FCRN_path)
    #start_epoch = model_fcrn_par['epoch']
    model_fcrn.load_state_dict(model_fcrn_par['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})".format(
        FCRN_path, model_fcrn_par['epoch']))

    for e in range(epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_depth_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            depth_y = model_fcrn(y)
            depth_x = model_fcrn(x)
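            # depth_y / depth_x are the depth maps predicted for the stylized output
            # and the original input; the depth loss penalizes their MSE.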

            content_loss = content_weight * mse_loss(features_y.relu2_2,
                                                     features_x.relu2_2)
            depth_loss = depth_weight * mse_loss(depth_y, depth_x)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= style_weight

            total_loss = content_loss + depth_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_depth_loss += depth_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tdepth: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_depth_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_depth_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if checkpoint_model_dir is not None and (
                    batch_id + 1) % checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            content_weight) + "_" + str(style_weight) + ".model"
    save_model_path = os.path.join(save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #31
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)
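            # Both batches are now in the ImageNet-normalized range that the VGG
            # feature extractor expects.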

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
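                # gm_s is sliced to n_batch so a final batch smaller than
                # batch_size still matches gm_y's leading dimension.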
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #32
def train(**kwargs):
    opt = Config()
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    device = t.device('cuda') if opt.use_gpu else t.device("cpu")
    vis = util.Visualizer(opt.env)

    data_transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x * 255)
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, data_transforms)
    dataloader = data.DataLoader(dataset, opt.batch_size)

    transform = TransformerNet()

    if opt.model_path:
        transform.load_state_dict(
            t.load(opt.model_path, map_location=lambda _s, _: _s))
    transform = transform.to(device)

    vgg = Vgg16().eval()
    vgg.to(device)
    for param in vgg.parameters():
        param.requires_grad = False

    optimizer = t.optim.Adam(transform.parameters(), opt.lr)

    style = util.get_style_data(opt.style_path)
    vis.img("style", (style.data[0] * 0.225 + 0.45).clamp(min=0, max=1))
    style = style.to(device)

    with t.no_grad():
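        # Style targets are computed once without building a graph; they stay
        # fixed for the whole training run.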
        features_style = vgg(style)
        gram_style = [util.gram_matrix(y) for y in features_style]

    style_meter = tnt.meter.AverageValueMeter()
    content_meter = tnt.meter.AverageValueMeter()

    for epoch in range(opt.epoches):
        content_meter.reset()
        style_meter.reset()

        for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):

            # training step
            optimizer.zero_grad()
            x = x.to(device)
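            # data_parallel splits the batch across GPUs 0 and 1; this assumes at
            # least two visible CUDA devices.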
            y = t.nn.parallel.data_parallel(transform, x, [0, 1])
            y = util.normalize_batch(y)
            x = util.normalize_batch(x)
            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = opt.content_weight * F.mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gram_y = util.gram_matrix(ft_y)
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            content_meter.add(content_loss.item())
            style_meter.add(style_loss.item())

            if (ii + 1) % opt.plot_every == 0:
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                vis.plot("content_loss", content_meter.value()[0])
                vis.plot("style_loss", style_meter.value()[0])
                vis.img("output",
                        (y.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))
                vis.img("input", (x.data.cpu()[0] * 0.225 + 0.45).clamp(min=0,
                                                                        max=1))

        vis.save([opt.env])
        t.save(transform.state_dict(),
               'checkpoints/' + time.ctime() + '%s_style.pth' % epoch)