Example No. 1
def stylize(**kwargs):
    opt = Config()

    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    device = t.device('cuda' if t.cuda.is_available() else 'cpu')

    # Image preprocessing
    content_image = PIL.Image.open(opt.content_path)
    content_transform = tv.transforms.Compose(
        [tv.transforms.ToTensor(),
         tv.transforms.Lambda(lambda x: x * 255)])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device).detach()

    # Load the trained transformer model
    style_model = TransformerNet().eval()
    style_model.load_state_dict(
        t.load(opt.model_path, map_location=t.device('cpu')))
    style_model.to(device)

    # Style transfer and save the result
    output = style_model(content_image)
    output_data = output.cpu().data[0]
    tv.utils.save_image((output_data / 255).clamp(min=0, max=1),
                        opt.result_path)
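A minimal invocation sketch for the keyword-argument API above. The attribute names (`content_path`, `model_path`, `result_path`) are taken from the function body; the file paths are hypothetical placeholders, and `Config` must already define any options not passed here.

# Hypothetical paths; keyword names mirror the attributes read by stylize(**kwargs).
stylize(content_path='content.jpg',
        model_path='checkpoints/style.pth',
        result_path='output.jpg')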
Example No. 2
def stylize(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    content_image = utils.load_image(args.content_image,
                                     scale=args.content_scale)
    content_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    if args.model.endswith(".onnx"):
        output = stylize_onnx_caffe2(content_image, args)
    else:
        with torch.no_grad():
            style_model = TransformerNet().eval()
            state_dict = torch.load(args.model)
            # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
            for k in list(state_dict.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del state_dict[k]
            style_model.load_state_dict(state_dict)
            style_model.to(device)
            if args.export_onnx:
                assert args.export_onnx.endswith(
                    ".onnx"), "Export model file should end with .onnx"
                output = torch.onnx._export(style_model, content_image,
                                            args.export_onnx).cpu()
            else:
                output = style_model(content_image).cpu()
    utils.save_image(args.output_image, output[0])
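For quick experiments, the `args` object can be built by hand instead of through `argparse`; this sketch sets only the attributes the function actually reads (`content_image`, `content_scale`, `model`, `export_onnx`, `output_image`, `cuda`), with hypothetical file names.

from argparse import Namespace

# Hypothetical paths; attribute names mirror those used by stylize(args).
args = Namespace(content_image='content.jpg', content_scale=None,
                 model='saved_models/style.pth', export_onnx=None,
                 output_image='output.jpg', cuda=False)  # set cuda=True for GPU
stylize(args)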
Example No. 3
def stylize(args):
    device = torch.device("cpu")

    content_image = utils.load_image(args.content_image, scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    with torch.no_grad():
        style_model = TransformerNet()
        state_dict = torch.load(args.model)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)
        output = style_model(content_image)
       
    utils.save_image(args.output_image, output[0])
Example No. 4
def train(**kwargs):
    opt = Config()
    for _k, _v in kwargs.items():
        setattr(opt, _k, _v)

    device = t.device("cuda" if t.cuda.is_available() else "cpu")
    vis = utils.Visualizer(opt.env)

    # Data loading
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x * 255)
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, transforms)
    dataloader = data.DataLoader(dataset, opt.batch_size)

    # Style transfer (transformer) network
    transformer = TransformerNet()
    if opt.model_path:
        transformer.load_state_dict(
            t.load(opt.model_path, map_location=t.device('cpu')))
    transformer.to(device)

    # Loss network: VGG16
    vgg = Vgg16().eval()
    vgg.to(device)
    for param in vgg.parameters():
        param.requires_grad = False

    # Optimizer
    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    # Load the style image
    style = utils.get_style_data(opt.style_path)
    vis.img('style', (style.data[0] * 0.225 + 0.45).clamp(min=0, max=1))
    style = style.to(device)

    # Gram matrices of the style image
    with t.no_grad():
        features_style = vgg(style)
        gram_style = [utils.gram_matrix(y) for y in features_style]

    # Running loss statistics
    style_loss_avg = 0
    content_loss_avg = 0

    for epoch in range(opt.epoches):
        for ii, (x, _) in tqdm(enumerate(dataloader)):

            # Training step
            optimizer.zero_grad()
            x = x.to(device)
            y = transformer(x)
            # print(y.size())
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)
            features_x = vgg(x)
            features_y = vgg(y)

            # content loss
            content_loss = opt.content_weight * F.mse_loss(
                features_y.relu3_3, features_x.relu3_3)

            # style loss
            style_loss = 0
            for ft_y, gm_s in zip(features_y, gram_style):
                with t.no_grad():
                    gram_y = utils.gram_matrix(ft_y)
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            content_loss_avg += content_loss.item()
            style_loss_avg += style_loss.item()

            if (ii + 1) % opt.plot_every == 0:
                vis.plot('content_loss', content_loss_avg / opt.plot_every)
                vis.plot('style_loss', style_loss_avg / opt.plot_every)
                content_loss_avg = 0
                style_loss_avg = 0
                vis.img('output',
                        (y.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))
                vis.img('input', (x.data.cpu()[0] * 0.225 + 0.45).clamp(min=0,
                                                                        max=1))

            if (ii + 1) % opt.save_every == 0:
                vis.save([opt.env])
                t.save(transformer.state_dict(),
                       'checkpoints/%s_style.pth' % (ii + 1))
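The style loss in the loop above relies on `utils.gram_matrix`, which is not shown in the snippet. A common definition, given here as a sketch rather than the repository's exact helper, flattens each feature map and normalizes the channel correlation matrix by the number of elements:

def gram_matrix(y):
    # y: (batch, channels, height, width) feature maps from the loss network
    b, ch, h, w = y.size()
    features = y.view(b, ch, h * w)
    features_t = features.transpose(1, 2)
    # Normalizing keeps the loss scale independent of the feature-map size
    return features.bmm(features_t) / (ch * h * w)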
Example No. 5
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")
    if args.backbone == "vgg":
        content_layer = ['relu_4']
        style_layer = ['relu_2', 'relu_4', 'relu_7', 'relu_10']
    elif args.backbone == "resnet":
        content_layer = ["conv_3"]
        style_layer = ["conv_1", "conv_2", "conv_3", "conv_4"]

    total_layer = list(dict.fromkeys(content_layer + style_layer))

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    if args.backbone == "vgg":
        loss_model = vgg16().eval().to(device)
    elif args.backbone == "resnet":
        loss_model = resnet().eval().to(device)

    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)
    feature_style = loss_model(utils.normalize_batch(style), style_layer)
    gram_style = {
        key: utils.gram_matrix(val)
        for key, val in feature_style.items()
    }

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            feature_y = loss_model(y, total_layer)
            feature_x = loss_model(x, content_layer)

            content_loss = 0.
            for layer in content_layer:
                content_loss += args.content_weight * mse_loss(
                    feature_y[layer], feature_x[layer])

            style_loss = 0.
            for name in style_layer:
                gm_y = utils.gram_matrix(feature_y[name])
                style_loss += mse_loss(gm_y, gram_style[name][:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()
            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)