            style_loss *= args.lambda_style

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            epoch_metrics["content"] += [content_loss.item()]
            epoch_metrics["style"] += [style_loss.item()]
            epoch_metrics["total"] += [total_loss.item()]

            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [Content: %.2f (%.2f) Style: %.2f (%.2f) Total: %.2f (%.2f)]"
                % (
                    epoch + 1,
                    args.epochs,
                    batch_i,
                    10000,
                    content_loss.item(),
                    np.mean(epoch_metrics["content"]),
                    style_loss.item(),
                    np.mean(epoch_metrics["style"]),
                    total_loss.item(),
                    np.mean(epoch_metrics["total"]),
                ))

            batches_done = epoch * len(dataloader) + batch_i + 1
            # Save a model checkpoint at this batches_done count
            if args.checkpoint_interval > 0 and batches_done % args.checkpoint_interval == 0:
                style_name = os.path.basename(args.style_image).split(".")[0]
                torch.save(transformer.state_dict(),
                           f"models/{style_name}_{batches_done}.pth")
def train_new_style(style_img_path, style_model_path):
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    # Basic params settings
    dataset_path = "datasets"  # 此处为coco14数据集的地址
    epochs = 1
    batch_size = 4
    # max_train_batch = 20000
    image_size = 256
    style_size = None
    # The following three weight values may need tuning; settings tried so far:
    # 1. 1e3 1e6 1     ep=24000
    # 2. 1e2 1e5 0.5   ep=18000
    # 3. 5e1 5e4 0.01  ep=max lr=1e-4
    # The original paper's Lua implementation uses 1.0, 5.0, 1e-6
    # The TensorFlow port uses 7.5 (15), 100
    lambda_content = float(5e1)
    lambda_style = float(5e4)
    lambda_tv = float(0.01)
    lr = float(1e-4)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create dataloader for the training data
    train_dataset = datasets.ImageFolder(dataset_path,
                                         train_transform(image_size))
    dataloader = DataLoader(train_dataset, batch_size=batch_size)

    # Defines networks
    transformer = TransformerNet().to(device)
    vgg = VGG16(requires_grad=False).to(device)

    # Define optimizer and loss
    optimizer = Adam(transformer.parameters(), lr)
    l2_loss = torch.nn.MSELoss().to(device)

    # Load style image
    style = style_transform(style_size)(Image.open(style_img_path))
    style = style.repeat(batch_size, 1, 1, 1).to(device)

    # Extract style features
    features_style = vgg(style)
    gram_style = [gram_matrix(y) for y in features_style]

    for epoch in range(epochs):
        # epoch_metrics = {"content": [], "style": [], "total": []}
        for batch_i, (images, _) in enumerate(dataloader):
            optimizer.zero_grad()

            images_original = images.to(device)
            images_transformed = transformer(images_original)

            # Extract features
            features_original = vgg(images_original)
            features_transformed = vgg(images_transformed)

            # Compute content loss as MSE between features
            content_size = features_transformed.relu2_2.shape[0]*features_transformed.relu2_2.shape[1] * \
                features_transformed.relu2_2.shape[2] * \
                features_transformed.relu2_2.shape[3]
            content_loss = lambda_content*2 * \
                l2_loss(features_transformed.relu2_2,
                        features_original.relu2_2)
            content_loss /= content_size

            # Compute style loss as MSE between gram matrices
            style_loss = 0
            for ft_y, gm_s in zip(features_transformed, gram_style):
                gm_y = gram_matrix(ft_y)
                gm_size = gm_y.shape[0] * gm_y.shape[1] * gm_y.shape[2]
                style_loss += l2_loss(gm_y,
                                      gm_s[:images.size(0), :, :]) / gm_size
            style_loss *= lambda_style * 2

            # Compute tv loss
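            # (total variation regularization: penalizes differences between
            # neighbouring pixels, encouraging spatially smooth outputs)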
            y_tv = l2_loss(images_transformed[:, :, 1:, :],
                           images_transformed[:, :, :image_size - 1, :])
            x_tv = l2_loss(images_transformed[:, :, :, 1:],
                           images_transformed[:, :, :, :image_size - 1])
            tv_loss = lambda_tv*2 * \
                (x_tv/image_size + y_tv/image_size)/batch_size

            total_loss = content_loss + style_loss + tv_loss
            total_loss.backward()
            optimizer.step()

    # Save trained model
    torch.save(transformer.state_dict(), style_model_path)
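
# The helpers gram_matrix, train_transform and style_transform are assumed to be
# defined elsewhere. A minimal sketch of gram_matrix, assuming a batched
# b*c*h*w feature tensor and the usual normalization by c*h*w:
def gram_matrix(features):
    b, c, h, w = features.size()
    # Flatten the spatial dimensions so each channel becomes a row vector
    flat = features.view(b, c, h * w)
    # Batched channel-by-channel inner products, normalized by c*h*w
    return flat.bmm(flat.transpose(1, 2)) / (c * h * w)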
Example #3
def train(**kwargs):
    # step1:config
    opt.parse(**kwargs)
    vis = Visualizer(opt.env)
    device = t.device('cuda') if opt.use_gpu else t.device('cpu')
    
    # step2:data
    # dataloader, style_img
    # Unlike earlier chapters, the images are not normalized here; instead a
    # lambda transform multiplies them by 255, a conversion that needs a proper justification.
    # There are two kinds of images: content images (many, used for training) and
    # a single style image, which is only used in the loss function.
    
    transforms = T.Compose([
        T.Resize(opt.image_size),
        T.CenterCrop(opt.image_size),
        T.ToTensor(),
        T.Lambda(lambda x: x*255)    
    ])
    # As in Chapter 7, the images are loaded with ImageFolder rather than a custom Dataset.
    dataset = tv.datasets.ImageFolder(opt.data_root,transform=transforms)
    dataloader = DataLoader(dataset,batch_size=opt.batch_size,shuffle=True,num_workers=opt.num_workers,drop_last=True)
    
    style_img = get_style_data(opt.style_path) # 1*c*H*W
    style_img = style_img.to(device)
    vis.img('style_image',(style_img.data[0]*0.225+0.45).clamp(min=0,max=1)) # arguably unnecessary; worth experimenting with next time
    
    # step3: model: TransformerNet and the loss network Vgg16
    # The model has two parts: TransformerNet, which transforms the content images,
    # and Vgg16, which is only used to evaluate the loss function. Because Vgg16
    # only scores the loss, its parameters take no part in backpropagation; only
    # the TransformerNet parameters are trained and saved. The Vgg16 weights are
    # loaded when the network is constructed.
    # Vgg16 runs in model.eval() mode, which changes the behaviour of layers such
    # as dropout and batch normalization. Previously eval() was only set for
    # val/test, so what is the reason for setting it on Vgg16 here?
    # When loading a checkpoint the author uses a simple map_location lambda,
    # which is a lighter-weight approach; the author seems to lean more and more
    # towards convenience.
    # Note the difference in CUDA usage: the model is moved to the GPU right away,
    # while the data is moved only inside the training loop.
    # In Chapter 7 the author separated the two networks in two ways; one was to
    # call fake_img = netg(noises).detach() on the generator netg, turning the
    # non-leaf output into something like a leaf node that needs no gradient.
    # Chapter 4 still needs to be re-read.
    
    transformer_net = TransformerNet()
    
    if opt.model_path:
        transformer_net.load_state_dict(t.load(opt.model_path,map_location= lambda _s, _: _s))    
    transformer_net.to(device)
    

    
    # step3: criterion and optimizer
    optimizer = t.optim.Adam(transformer_net.parameters(),opt.lr)
    # The loss is computed through vgg16: it involves Gram matrices and a mean
    # squared error, so we also need a gram_matrix helper and an MSE criterion.
    vgg16 = Vgg16().eval() # to be verified
    vgg16.to(device)
    # The vgg parameters need no gradients of their own, but gradients must still
    # flow back through vgg16.
    # Revisit the difference between detach() and requires_grad later.
    for param in vgg16.parameters():
        param.requires_grad = False
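    # A minimal illustration of the difference mentioned above (comment only):
    # fake_img = netg(noises).detach() removes that single output from the autograd
    # graph, whereas setting requires_grad = False on vgg16's parameters keeps the
    # graph intact (gradients still flow through vgg16 back to transformer_net)
    # but no gradients are accumulated for the frozen weights themselves.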
    criterion = t.nn.MSELoss(reduction='mean')
    
    
    # step4: meters for loss statistics
    style_meter = meter.AverageValueMeter()
    content_meter = meter.AverageValueMeter()
    total_meter = meter.AverageValueMeter()
    
    # step5.2: loss preparation
    # Compute the Gram matrices of the style image.
    # gram_style: list [relu1_2, relu2_2, relu3_3, relu4_3], each a b*c*c tensor
    with t.no_grad():
        features = vgg16(style_img)
        gram_style = [gram_matrix(feature) for feature in features]
    # Loss network: Vgg16
    # step5: train
    for epoch in range(opt.epoches):
        style_meter.reset()
        content_meter.reset()
        total_meter.reset()
        
        # step5.1: train
        for ii,(data,_) in tqdm(enumerate(dataloader)):
            optimizer.zero_grad()
            # Unlike earlier chapters, the author no longer wraps tensors in Variable().
            # Since PyTorch 0.4, Tensor and Variable are no longer strictly separated;
            # a newly created tensor already behaves as a Variable.
            # https://mp.weixin.qq.com/s?__biz=MzI0ODcxODk5OA==&mid=2247494701&idx=2&sn=ea8411d66038f172a2f553770adccbec&chksm=e99edfd4dee956c23c47c7bb97a31ee816eb3a0404466c1a57c12948d807c975053e38b18097&scene=21#wechat_redirect
            data = data.to(device)
            y = transformer_net(data)
            # vgg expects normalized input images
            data = normalize_batch(data)
            y = normalize_batch(y)

           
            feature_data = vgg16(data)
            feature_y = vgg16(y) 
            # Question: what does this feature object look like now?
            
            # step5.2: loss: content loss and style loss
            # content_loss
            # Note: this differs from the book, which uses relu3_3; the code uses relu2_2.
            # https://blog.csdn.net/zhangxb35/article/details/72464152?utm_source=itdadao&utm_medium=referral
            # The mean squared error here is per-element: sum over all N*b*h*w
            # elements, then divide by N*b*h*w.
            # SGD itself averages the loss over the batch before backpropagating.
            content_loss = opt.content_weight*criterion(feature_y.relu2_2,feature_data.relu2_2)
            # style loss
            # style loss layers: relu1_2, relu2_2, relu3_3, relu4_3
            # Here we need the Gram matrix of every image in the batch.
            
            style_loss = 0
            # A tensor can also be iterated with "for i in tensor:", which only
            # unpacks the outermost dimension.
            # ft_y: b*c*h*w, gm_s: 1*c*c
            for ft_y, gm_s in zip(feature_y, gram_style):
                gram_y = gram_matrix(ft_y)
                style_loss += criterion(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight
            
            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()
            #import ipdb
            #ipdb.set_trace()
            # Get a tensor's value with tensor.item() or tensor.tolist()
            content_meter.add(content_loss.item())
            style_meter.add(style_loss.item())
            total_meter.add(total_loss.item())
            
            # step5.3: visualize
            if (ii+1)%opt.print_freq == 0 and opt.vis:
                # Why is debugging always done via this debug-file pattern?
                if os.path.exists(opt.debug_file):
                    import ipdb
                    ipdb.set_trace()
                vis.plot('content_loss',content_meter.value()[0])
                vis.plot('style_loss',style_meter.value()[0])
                vis.plot('total_loss',total_meter.value()[0])
                # data and y have been normalized to roughly -2..2, so scale them
                # back to the 0-1 range for display
                vis.img('input',(data.data*0.225+0.45)[0].clamp(min=0,max=1))
                vis.img('output',(y.data*0.225+0.45)[0].clamp(min=0,max=1))
            
        # step 5.4 save and validate and visualize
        if (epoch+1) % opt.save_every == 0:
            t.save(transformer_net.state_dict(), 'checkpoints/%s_style.pth' % epoch)
            # Several ways to save images; Chapter 7 used:
            # tv.utils.save_image(fix_fake_imgs,'%s/%s.png' % (opt.img_save_path, epoch),normalize=True, range=(-1,1))
            # Surprisingly, vis.save is nowhere to be found in the docs.
            vis.save([opt.env])
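
# normalize_batch is referenced above but not defined in this listing. A minimal
# sketch, assuming inputs scaled to 0-255 (the Lambda(x*255) transform) and the
# standard ImageNet statistics:
def normalize_batch(batch):
    # ImageNet channel mean/std, broadcast over a b*3*h*w batch in 0-255 range
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)
    return (batch / 255.0 - mean) / std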
Example #4
            total_loss.backward()
            optimizer.step()

            epoch_metrics["content"] += [content_loss.item()]
            epoch_metrics["style"] += [style_loss.item()]
            epoch_metrics["total"] += [total_loss.item()]

            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [Content: %.2f (%.2f) Style: %.2f (%.2f) Total: %.2f (%.2f)]"
                % (
                    epoch + 1,
                    args.epochs,
                    batch_i,
                    len(train_dataset),
                    content_loss.item(),
                    np.mean(epoch_metrics["content"]),
                    style_loss.item(),
                    np.mean(epoch_metrics["style"]),
                    total_loss.item(),
                    np.mean(epoch_metrics["total"]),
                )
            )

            batches_done = epoch * len(dataloader) + batch_i + 1
            if batches_done % args.sample_interval == 0:
                save_sample(batches_done)

            if args.checkpoint_interval > 0 and batches_done % args.checkpoint_interval == 0:
                style_name = os.path.basename(args.style_image).split(".")[0]
                torch.save(transformer.state_dict(), f"checkpoints/{style_name}_{batches_done}.pth")
Example #5
                    args.epochs,
                    batch_i,
                    len(train_dataset),
                    content_loss.item(),
                    np.mean(epoch_metrics["content"]),
                    style_loss.item(),
                    np.mean(epoch_metrics["style"]),
                    total_loss.item(),
                    np.mean(epoch_metrics["total"]),
                ))
            # If the loss explodes or vanishes, stop here.
            abort = False
            for x in (style_loss.item(), total_loss.item(),
                      content_loss.item()):
                if math.isinf(x) or math.isnan(x):
                    print("Gradient vanished or exploded. Saving model.")
                    abort = True
            batches_done = epoch * len(dataloader) + batch_i + 1

            if (args.checkpoint_interval > 0 and batches_done % args.checkpoint_interval == 0) or abort:
                checkpoint_folder = os.path.join(
                    TRAINING_DIR, "{}-training".format(style_name),
                    "checkpoints")
                os.makedirs(checkpoint_folder, exist_ok=True)
                checkpoint_path = os.path.join(
                    checkpoint_folder,
                    "{}_{}.pth".format(style_name, batches_done))
                torch.save(transformer.state_dict(), checkpoint_path)
                if abort:
                    break
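
# The finiteness check above could also be written with torch.isfinite, e.g.
# (illustrative sketch only):
#   abort = not all(torch.isfinite(v).item()
#                   for v in (content_loss, style_loss, total_loss))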
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    data_train = load_data(args)
    iterator = data_train

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(weights=args.vgg16, requires_grad=False).to(device)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])

    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        count = 0
        if args.noise_count:
            noiseimg_n = np.zeros((3, args.image_size, args.image_size),
                                  dtype=np.float32)
            # Preparing noise image.
            for n_c in range(args.noise_count):
                x_n = random.randrange(args.image_size)
                y_n = random.randrange(args.image_size)
                noiseimg_n[0][x_n][y_n] += random.randrange(
                    -args.noise, args.noise)
                noiseimg_n[1][x_n][y_n] += random.randrange(
                    -args.noise, args.noise)
                noiseimg_n[2][x_n][y_n] += random.randrange(
                    -args.noise, args.noise)
                noiseimg = torch.from_numpy(noiseimg_n)
                noiseimg = noiseimg.to(device)
        for batch_id, sample in enumerate(iterator):
            x = sample['image']
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            if args.noise_count:
                # Adding the noise image to the source image.
                noisy_x = x + noiseimg
                noisy_y = transformer(noisy_x)
                noisy_y = utils.normalize_batch(noisy_y)

            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            L_feat = args.lambda_feat * mse_loss(features_y.relu2_2,
                                                 features_x.relu2_2)

            L_style = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                L_style += mse_loss(gm_y, gm_s[:n_batch, :, :])
            L_style *= args.lambda_style

            L_tv = (torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) +
                    torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])))
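            # (anisotropic total variation: penalizes absolute differences between
            # horizontally and vertically adjacent pixels of the stylized image)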

            L_tv *= args.lambda_tv

            if args.noise_count:
                L_pop = args.lambda_noise * F.mse_loss(y, noisy_y)
                L = L_feat + L_style + L_tv + L_pop
                print(
                    'Epoch {},{}/{}. Total loss: {}. Loss distribution: feat {}, style {}, tv {}, pop {}'
                    .format(e, batch_id, len(data_train), L.data,
                            L_feat.data / L.data, L_style.data / L.data,
                            L_tv.data / L.data, L_pop.data / L.data))
            else:
                L = L_feat + L_style + L_tv
                print(
                    'Epoch {},{}/{}. Total loss: {}. Loss distribution: feat {}, style {}, tv {}'
                    .format(e, batch_id, len(data_train), L.data,
                            L_feat.data / L.data, L_style.data / L.data,
                            L_tv.data / L.data))
            # L = L_style * 1e10 + L_feat * 1e5  # (disabled: would override the weighted loss computed above)
            L.backward()
            optimizer.step()

    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #7
def train():
    parser = argparse.ArgumentParser(description='parser for style transfer')
    parser.add_argument('--dataset_path',
                        type=str,
                        default=r'C:\Users\Dewey\data\celeba',
                        help='path to training dataset')
    parser.add_argument('--style_image',
                        type=str,
                        default='mosaic.jpg',
                        help='path to style img')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--image_size',
                        type=int,
                        default=256,
                        help='training image size')
    parser.add_argument('--style_img_size',
                        type=int,
                        default=256,
                        help='style image size')
    parser.add_argument("--lambda_content",
                        type=float,
                        default=1e5,
                        help="Weight for content loss")
    parser.add_argument("--lambda_style",
                        type=float,
                        default=1e10,
                        help="Weight for style loss")
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument("--checkpoint_model",
                        type=str,
                        help="Optional path to checkpoint model")
    parser.add_argument("--checkpoint_interval",
                        type=int,
                        default=2000,
                        help="Batches between saving model")
    parser.add_argument("--sample_interval",
                        type=int,
                        default=1000,
                        help="Batches between saving image samples")
    parser.add_argument('--sample_format',
                        type=str,
                        default='jpg',
                        help='sample image format')
    args = parser.parse_args()

    style_name = args.style_image.split('/')[-1].split('.')[0]
    os.makedirs(f'images/outputs/{style_name}-training',
                exist_ok=True)  # f-string formatting
    os.makedirs('checkpoints', exist_ok=True)

    def save_sample(batch):
        transformer.eval()
        with torch.no_grad():
            output = transformer(image_samples.to(device))
            img_grid = denormalize(
                torch.cat((image_samples.cpu(), output.cpu()), 2))
            save_image(img_grid,
                       f"images/outputs/{style_name}-training/{batch}.jpg",
                       nrow=4)
        transformer.train()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_dataset = datasets.ImageFolder(args.dataset_path,
                                         train_transform(args.image_size))
    dataloader = DataLoader(train_dataset,
                            batch_size=args.batch_size,
                            shuffle=True)

    transformer = TransformerNet().to(device)
    vgg = VGG16(requires_grad=False).to(device)

    if args.checkpoint_model:
        transformer.load_state_dict(torch.load(args.checkpoint_model))

    optimizer = Adam(transformer.parameters(), lr=args.lr)
    l2_loss = nn.MSELoss().to(device)

    # load style image
    style = style_transform(args.style_img_size)(Image.open(args.style_image))
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    # style_image features
    style_features = vgg(style)
    gram_style = [gram(x) for x in style_features]

    # visualization the image
    image_samples = []
    for path in random.sample(
            glob.glob(f'{args.dataset_path}/*/*.{args.sample_format}'), 8):
        image_samples += [
            style_transform(
                (args.image_size, args.image_size))(Image.open(path))
        ]
    image_samples = torch.stack(image_samples)
    c_loss = 0
    s_loss = 0
    t_loss = 0

    for epoch in range(args.epochs):
        for i, (img, _) in enumerate(dataloader):

            optimizer.zero_grad()

            image_original = img.to(device)
            image_transformed = transformer(image_original)

            origin_features = vgg(image_original)
            transformed_features = vgg(image_transformed)

            content_loss = args.lambda_content * l2_loss(
                transformed_features.relu_2_2, origin_features.relu_2_2)

            style_loss = 0
            for ii, jj in zip(transformed_features, gram_style):
                gram_t_features = gram(ii)
                style_loss += l2_loss(gram_t_features,
                                      jj[:img.size(0), :, :])  # slice the style Gram to the actual batch size
            style_loss *= args.lambda_style

            loss = content_loss + style_loss
            loss.backward()
            optimizer.step()

            c_loss += content_loss.item()
            s_loss += style_loss.item()
            t_loss += loss.item()
            print(
                '[Epoch %d/%d] [Batch %d/%d] [Content: %.2f (%.2f) Style: %.2f (%.2f) Total: %.2f (%.2f)]'
                % (
                    epoch + 1,
                    args.epochs,
                    i,
                    len(train_dataset),
                    content_loss.item(),
                    c_loss / (epoch * len(dataloader) + i + 1),
                    style_loss.item(),
                    s_loss / (epoch * len(dataloader) + i + 1),
                    loss.item(),
                    t_loss / (epoch * len(dataloader) + i + 1),
                ))

            batches_done = epoch * len(dataloader) + i + 1
            if batches_done % args.sample_interval == 0:
                save_sample(batches_done)

            if args.checkpoint_interval > 0 and batches_done % args.checkpoint_interval == 0:
                style_name = os.path.basename(args.style_image).split(".")[0]
                torch.save(transformer.state_dict(),
                           f"checkpoints/{style_name}_{batches_done}.pth")