Example #1
def stylize(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    content_image = utils.load_image(args.content_image,
                                     scale=args.content_scale)
    content_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    if args.model.endswith(".onnx"):
        output = stylize_onnx_caffe2(content_image, args)
    else:
        with torch.no_grad():
            style_model = TransformerNet()
            state_dict = torch.load(args.model)
            # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
            for k in list(state_dict.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del state_dict[k]
            style_model.load_state_dict(state_dict)
            style_model.to(device)
            if args.export_onnx:
                assert args.export_onnx.endswith(
                    ".onnx"), "Export model file should end with .onnx"
                # note: torch.onnx._export is a private API that returns the
                # traced output (the public torch.onnx.export returns None)
                output = torch.onnx._export(style_model, content_image,
                                            args.export_onnx).cpu()
            else:
                output = style_model(content_image).cpu()
    utils.save_image(args.output_image, output[0])
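For reference, stylize expects an argparse-style namespace. A minimal, hypothetical call site, assuming only the attributes the function body actually reads (all paths are placeholders):

from argparse import Namespace

# hypothetical invocation; every path here is a placeholder
args = Namespace(
    content_image="inputs/photo.jpg",    # image to stylize
    content_scale=None,                  # optional downscale factor
    model="saved_models/mosaic.pth",     # trained TransformerNet weights
    cuda=0,                              # set to 1 to run on the GPU
    export_onnx=None,                    # or a path ending in .onnx to export
    output_image="outputs/photo_mosaic.jpg",
)
stylize(args)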
Example #2
def stylize(input_filepath, output_filepath, checkpoint, content_scale=None):
    content_image = utils.load_image(input_filepath, scale=content_scale)
    content_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0)
    if CUDA:
        content_image = content_image.cuda()

    style_model = TransformerNet()
    style_model.load_state_dict(torch.load(checkpoint))
    if CUDA:
        style_model.cuda()
    # Variable(..., volatile=True) is deprecated; torch.no_grad() now
    # disables autograd for inference
    with torch.no_grad():
        output = style_model(content_image)
    if CUDA:
        output = output.cpu()
    utils.save_image(output_filepath, output[0])
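Since this variant takes plain file paths rather than a namespace, a call is just (paths are placeholders; CUDA is assumed to be a module-level flag):

stylize("inputs/photo.jpg", "outputs/photo_styled.jpg",
        checkpoint="saved_models/mosaic.pth")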
Example #3
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                # slice gm_s in case the last batch is smaller than batch_size
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_{}_{}_{}_{}.model".format(
        args.epochs,
        time.ctime().replace(' ', '_'), args.content_weight,
        args.style_weight)
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
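The loop leans on two utils helpers that the snippet does not show. A plausible sketch of both, consistent with how they are called above (an assumption about utils, not code from this example):

def gram_matrix(y):
    # Gram matrix of the feature maps: bmm of (b, ch, h*w) with its
    # transpose, normalized by the feature map size
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    return features.bmm(features_t) / (ch * h * w)

def normalize_batch(batch):
    # scale a 0-255 batch to 0-1, then apply the ImageNet mean/std
    # that the pretrained VGG expects
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    batch = batch.div(255.0)
    return (batch - mean) / std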
Example #4
def train(args):
    log(json.dumps({"type": "status_update", "status": "Setting up training"}))
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])

    log(json.dumps({"type": "status_update", "status": "Loading dataset"}))

    train_dataset = datasets.ImageFolder(args.dataset, transform)
    log(json.dumps({
        "type": "dataset_info",
        "dataset_length": len(train_dataset) * args.epochs
    }))
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])

    log(json.dumps({"type": "status_update", "status": "Dataset loaded"}))

    log(json.dumps({"type": "status_update", "status": "Loading image"}))

    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    log(json.dumps({"type": "status_update", "status": "Image loaded"}))

    log(json.dumps({"type": "status_update", "status": "Training setup done"}))

    progress_count = 0

    log(json.dumps({"type": "status_update", "status": "Starting training"}))

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            content_loss = args.content_weight * mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            log(json.dumps({
                "type": "training_progress",
                "progress": str(progress_count),
                "percent": str(round(
                    progress_count / (len(train_dataset) * args.epochs) * 100,
                    2))
            }))

            # count actual samples so the tally stays correct on a short last batch
            progress_count += n_batch

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                # derive the checkpoint name from the style image unless
                # a name was given
                if args.name is None:
                    style_name = os.path.splitext(
                        os.path.basename(args.style_image))[0]
                else:
                    style_name = str(args.name)
                ckpt_model_filename = "{}_{}.pth".format(style_name,
                                                         batch_id + 1)
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    log(json.dumps({"type": "status_update", "status": "training done"}))

    # save model
    log(json.dumps({"type": "status_update", "status": "saving model"}))
    transformer.eval().cpu()
    if args.name is None:
        save_model_filename = os.path.splitext(
            os.path.basename(args.style_image))[0] + ".pth"
    else:
        save_model_filename = str(args.name) + ".pth"

    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    log(json.dumps({"type": "status_update", "status": "model saved"}))
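The log helper is not shown. Given that every message is a JSON string, it is presumably a line-buffered print so a supervising process can stream progress; a guess at its shape:

import sys

def log(message):
    # hypothetical helper: emit one JSON line and flush immediately so a
    # parent process reading stdout sees progress in real time
    sys.stdout.write(message + "\n")
    sys.stdout.flush()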
Example #5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## put your original images in this directory
input_dir = Path('./inputs')

## styled images will be saved to this directory (created if missing)
output_dir = Path('./outputs')
output_dir.mkdir(exist_ok=True)

with torch.no_grad():
    style_model = TransformerNet()
    state_dict = torch.load('nausicaa3.model')
    # remove deprecated InstanceNorm running_* keys saved in older checkpoints
    for k in list(state_dict.keys()):
        if re.search(r'in\d+\.running_(mean|var)$', k):
            del state_dict[k]
    style_model.load_state_dict(state_dict)
    style_model.to(device)

    for imagefile in input_dir.glob('*.jpg'):

        content_image = utils.load_image(imagefile, scale=None)
        content_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))
        ])
        content_image = content_transform(content_image)
        content_image = content_image.unsqueeze(0).to(device)

        output = style_model(content_image).cpu()

        utils.save_image(str(output_dir / imagefile.name), output[0])
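All five examples depend on utils.load_image and utils.save_image. A minimal PIL-based sketch consistent with the CHW, 0-255 float tensors used above (assumed, not the snippets' actual code):

from PIL import Image

def load_image(filename, size=None, scale=None):
    img = Image.open(filename).convert('RGB')
    if size is not None:
        img = img.resize((size, size), Image.LANCZOS)
    elif scale is not None:
        img = img.resize((int(img.size[0] / scale),
                          int(img.size[1] / scale)), Image.LANCZOS)
    return img

def save_image(filename, data):
    # data is a CHW float tensor in [0, 255]: clamp, convert to HWC uint8, save
    img = data.clone().clamp(0, 255).numpy()
    img = img.transpose(1, 2, 0).astype("uint8")
    Image.fromarray(img).save(filename)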