def neural_style_transfer(config):
    content_img_path = os.path.join(config['content_images_dir'],
                                    config['content_img_name'])
    style_img_path = os.path.join(config['style_images_dir'],
                                  config['style_img_name'])

    out_dir_name = 'combined_' + os.path.split(content_img_path)[1].split(
        '.')[0] + '_' + os.path.split(style_img_path)[1].split('.')[0]
    dump_path = os.path.join(config['output_img_dir'], out_dir_name)
    os.makedirs(dump_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    content_img = utils.prepare_img(content_img_path, config['height'], device)
    style_img = utils.prepare_img(style_img_path, config['height'], device)

    if config['init_method'] == 'random':
        # white_noise_img = np.random.uniform(-90., 90., content_img.shape).astype(np.float32)
        gaussian_noise_img = np.random.normal(loc=0,
                                              scale=90.,
                                              size=content_img.shape).astype(
                                                  np.float32)
        init_img = torch.from_numpy(gaussian_noise_img).float().to(device)
    elif config['init_method'] == 'content':
        init_img = content_img
    else:
        # the init image must have the same dimensions as the content image - this is a
        # hard constraint: the content and init feature maps need to be the same size
        style_img_resized = utils.prepare_img(
            style_img_path, np.asarray(content_img.shape[2:]), device)
        init_img = style_img_resized

    # we are tuning optimizing_img's pixels! (that's why requires_grad=True)
    optimizing_img = Variable(init_img, requires_grad=True)

    neural_net, content_feature_maps_index_name, style_feature_maps_indices_names = utils.prepare_model(
        config['model'], device)
    print(f'Using {config["model"]} in the optimization procedure.')

    content_img_set_of_feature_maps = neural_net(content_img)
    style_img_set_of_feature_maps = neural_net(style_img)

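    # The content representation is the raw feature maps of one chosen layer; the style
    # representation is a set of Gram matrices G = F @ F.T (one per style layer), which
    # capture which feature channels tend to activate together, discarding spatial layout.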
    target_content_representation = content_img_set_of_feature_maps[
        content_feature_maps_index_name[0]].squeeze(axis=0)
    target_style_representation = [
        utils.gram_matrix(x)
        for cnt, x in enumerate(style_img_set_of_feature_maps)
        if cnt in style_feature_maps_indices_names[0]
    ]
    target_representations = [
        target_content_representation, target_style_representation
    ]

    # magic numbers are generally a big no-no - some things in this code are left like this by design to avoid clutter
    num_of_iterations = {
        "lbfgs": 1000,
        "adam": 3000,
    }

    #
    # Start of optimization procedure
    #
    if config['optimizer'] == 'adam':
        optimizer = Adam((optimizing_img, ), lr=1e1)
        tuning_step = make_tuning_step(neural_net, optimizer,
                                       target_representations,
                                       content_feature_maps_index_name[0],
                                       style_feature_maps_indices_names[0],
                                       config)
        for cnt in range(num_of_iterations[config['optimizer']]):
            total_loss, content_loss, style_loss, tv_loss = tuning_step(
                optimizing_img)
            with torch.no_grad():
                print(
                    f'Adam | iteration: {cnt:03}, total loss={total_loss.item():12.4f}, content_loss={config["content_weight"] * content_loss.item():12.4f}, style loss={config["style_weight"] * style_loss.item():12.4f}, tv loss={config["tv_weight"] * tv_loss.item():12.4f}'
                )
                utils.save_and_maybe_display(
                    optimizing_img,
                    dump_path,
                    config,
                    cnt,
                    num_of_iterations[config['optimizer']],
                    should_display=False)
    elif config['optimizer'] == 'lbfgs':
        # line_search_fn does not seem to have a significant impact on the result
        optimizer = LBFGS((optimizing_img, ),
                          max_iter=num_of_iterations['lbfgs'],
                          line_search_fn='strong_wolfe')
        cnt = 0

        def closure():
            nonlocal cnt
            if torch.is_grad_enabled():
                optimizer.zero_grad()
            total_loss, content_loss, style_loss, tv_loss = build_loss(
                neural_net, optimizing_img, target_representations,
                content_feature_maps_index_name[0],
                style_feature_maps_indices_names[0], config)
            if total_loss.requires_grad:
                total_loss.backward()
            with torch.no_grad():
                print(
                    f'L-BFGS | iteration: {cnt:03}, total loss={total_loss.item():12.4f}, content_loss={config["content_weight"] * content_loss.item():12.4f}, style loss={config["style_weight"] * style_loss.item():12.4f}, tv loss={config["tv_weight"] * tv_loss.item():12.4f}'
                )
                utils.save_and_maybe_display(
                    optimizing_img,
                    dump_path,
                    config,
                    cnt,
                    num_of_iterations[config['optimizer']],
                    should_display=False)

            cnt += 1
            return total_loss

        optimizer.step(closure)

    return dump_path
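# Example usage of neural_style_transfer - a minimal sketch. The config keys are the
# ones read by the function above; the file names, sizes and weights here are
# illustrative assumptions (utils.save_and_maybe_display may read additional keys,
# e.g. a saving frequency, which are not shown).
def example_nst_run():
    nst_config = {
        'content_images_dir': os.path.join('data', 'content-images'),
        'style_images_dir': os.path.join('data', 'style-images'),
        'output_img_dir': os.path.join('data', 'output-images'),
        'content_img_name': 'lion.jpg',
        'style_img_name': 'starry_night.jpg',
        'height': 400,              # images are resized to this height
        'init_method': 'content',   # 'random' | 'content' | anything else -> style-based init
        'model': 'vgg19',
        'optimizer': 'lbfgs',       # 'lbfgs' | 'adam'
        'content_weight': 1e5,
        'style_weight': 3e4,
        'tv_weight': 1e0,
    }
    results_path = neural_style_transfer(nst_config)
    print(f'Stylized images saved to {results_path}.')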
def stylize_static_image(inference_config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prepare the model - load the weights and put the model into evaluation mode
    stylization_model = TransformerNet().to(device)
    training_state = torch.load(
        os.path.join(inference_config["model_binaries_path"],
                     inference_config["model_name"]))
    state_dict = training_state["state_dict"]
    stylization_model.load_state_dict(state_dict, strict=True)
    stylization_model.eval()

    if inference_config['verbose']:
        utils.print_model_metadata(training_state)

    with torch.no_grad():
        if os.path.isdir(inference_config['content_input']):  # do a batch stylization (every image in the directory)
            img_dataset = utils.SimpleDataset(
                inference_config['content_input'],
                inference_config['img_width'])
            img_loader = DataLoader(img_dataset,
                                    batch_size=inference_config['batch_size'])

            try:
                processed_imgs_cnt = 0
                for batch_id, img_batch in enumerate(img_loader):
                    processed_imgs_cnt += len(img_batch)
                    if inference_config['verbose']:
                        print(
                            f'Processing batch {batch_id + 1} ({processed_imgs_cnt}/{len(img_dataset)} processed images).'
                        )

                    img_batch = img_batch.to(device)
                    stylized_imgs = stylization_model(img_batch).to(
                        'cpu').numpy()
                    for stylized_img in stylized_imgs:
                        utils.save_and_maybe_display_image(
                            inference_config,
                            stylized_img,
                            should_display=False)
            except Exception as e:
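                # most likely a CUDA out-of-memory error, hence the advice printed below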
                print(e)
                print(
                    f'Consider making the batch_size (current = {inference_config["batch_size"]} images) or img_width (current = {inference_config["img_width"]} px) smaller'
                )
                exit(1)

        else:  # do stylization for a single image
            content_img_path = os.path.join(
                inference_config['content_images_path'],
                inference_config['content_input'])
            content_image = utils.prepare_img(content_img_path,
                                              inference_config['img_width'],
                                              device)
            stylized_img = stylization_model(content_image).to(
                'cpu').numpy()[0]
            utils.save_and_maybe_display_image(
                inference_config,
                stylized_img,
                should_display=not inference_config['should_not_display'])
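# Example usage of stylize_static_image - a minimal sketch. The keys mirror what the
# function above reads; model/file names and sizes are illustrative assumptions
# (utils.save_and_maybe_display_image may read additional output-path keys).
def example_stylization_run():
    inference_config = {
        'content_images_path': os.path.join('data', 'content-images'),
        'content_input': 'taj_mahal.jpg',  # a single image name, or a directory for batch mode
        'img_width': 500,
        'model_binaries_path': os.path.join('models', 'binaries'),
        'model_name': 'mosaic.pth',
        'batch_size': 4,                   # only relevant in batch (directory) mode
        'verbose': True,
        'should_not_display': False,       # i.e. the stylized result will be displayed
    }
    stylize_static_image(inference_config)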
def reconstruct_image_from_representation(config):
    should_reconstruct_content = config['should_reconstruct_content']
    should_visualize_representation = config['should_visualize_representation']
    dump_path = os.path.join(config['output_img_dir'],
                             ('c' if should_reconstruct_content else 's') +
                             '_reconstruction_' + config['optimizer'])
    dump_path = os.path.join(
        dump_path, config['content_img_name'].split('.')[0] if
        should_reconstruct_content else config['style_img_name'].split('.')[0])
    os.makedirs(dump_path, exist_ok=True)

    content_img_path = os.path.join(config['content_images_dir'],
                                    config['content_img_name'])
    style_img_path = os.path.join(config['style_images_dir'],
                                  config['style_img_name'])
    img_path = content_img_path if should_reconstruct_content else style_img_path

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    img = utils.prepare_img(img_path, config['height'], device)

    # gaussian_noise_img = np.random.normal(loc=0, scale=90., size=img.shape).astype(np.float32)
    white_noise_img = np.random.uniform(-90., 90.,
                                        img.shape).astype(np.float32)
    init_img = torch.from_numpy(white_noise_img).float().to(device)
    optimizing_img = Variable(init_img, requires_grad=True)

    # indices pick relevant feature maps (say conv4_1, relu1_1, etc.)
    neural_net, content_feature_maps_index_name, style_feature_maps_indices_names = utils.prepare_model(
        config['model'], device)

    # we don't want to expose everything that isn't crucial, so some things are hardcoded
    num_of_iterations = {'adam': 3000, 'lbfgs': 350}

    set_of_feature_maps = neural_net(img)

    #
    # Visualize feature maps and Gram matrices (depending on whether you're reconstructing the content or style img)
    #
    if should_reconstruct_content:
        target_content_representation = set_of_feature_maps[
            content_feature_maps_index_name[0]].squeeze(axis=0)
        if should_visualize_representation:
            num_of_feature_maps = target_content_representation.size()[0]
            print(f'Number of feature maps: {num_of_feature_maps}')
            for i in range(num_of_feature_maps):
                feature_map = target_content_representation[i].to(
                    'cpu').numpy()
                feature_map = np.uint8(utils.get_uint8_range(feature_map))
                plt.imshow(feature_map)
                plt.title(
                    f'Feature map {i+1}/{num_of_feature_maps} from layer {content_feature_maps_index_name[1]} (model={config["model"]}) for {config["content_img_name"]} image.'
                )
                plt.show()
                filename = f'fm_{config["model"]}_{content_feature_maps_index_name[1]}_{str(i).zfill(config["img_format"][0])}{config["img_format"][1]}'
                utils.save_image(feature_map,
                                 os.path.join(dump_path, filename))
    else:
        target_style_representation = [
            utils.gram_matrix(fmaps)
            for i, fmaps in enumerate(set_of_feature_maps)
            if i in style_feature_maps_indices_names[0]
        ]
        if should_visualize_representation:
            num_of_gram_matrices = len(target_style_representation)
            print(f'Number of Gram matrices: {num_of_gram_matrices}')
            for i in range(num_of_gram_matrices):
                Gram_matrix = target_style_representation[i].squeeze(
                    axis=0).to('cpu').numpy()
                Gram_matrix = np.uint8(utils.get_uint8_range(Gram_matrix))
                plt.imshow(Gram_matrix)
                plt.title(
                    f'Gram matrix from layer {style_feature_maps_indices_names[1][i]} (model={config["model"]}) for {config["style_img_name"]} image.'
                )
                plt.show()
                filename = f'gram_{config["model"]}_{style_feature_maps_indices_names[1][i]}_{str(i).zfill(config["img_format"][0])}{config["img_format"][1]}'
                utils.save_image(Gram_matrix,
                                 os.path.join(dump_path, filename))

    #
    # Start of optimization procedure
    #
    if config['optimizer'] == 'adam':
        optimizer = Adam((optimizing_img, ))
        target_representation = target_content_representation if should_reconstruct_content else target_style_representation
        tuning_step = make_tuning_step(neural_net, optimizer,
                                       target_representation,
                                       should_reconstruct_content,
                                       content_feature_maps_index_name[0],
                                       style_feature_maps_indices_names[0])
        for it in range(num_of_iterations[config['optimizer']]):
            loss, _ = tuning_step(optimizing_img)
            with torch.no_grad():
                print(
                    f'Iteration: {it}, current {"content" if should_reconstruct_content else "style"} loss={loss:10.8f}'
                )
                utils.save_and_maybe_display(
                    optimizing_img,
                    dump_path,
                    config,
                    it,
                    num_of_iterations[config['optimizer']],
                    should_display=False)
    elif config['optimizer'] == 'lbfgs':
        cnt = 0

        # closure is a function required by the L-BFGS optimizer - it re-evaluates the
        # objective multiple times per step, so it needs a callable that recomputes the
        # loss (and its gradients) on demand
        def closure():
            nonlocal cnt
            optimizer.zero_grad()
            loss = 0.0
            if should_reconstruct_content:
                loss = torch.nn.MSELoss(reduction='mean')(
                    target_content_representation, neural_net(optimizing_img)[
                        content_feature_maps_index_name[0]].squeeze(axis=0))
            else:
                current_set_of_feature_maps = neural_net(optimizing_img)
                current_style_representation = [
                    utils.gram_matrix(fmaps)
                    for i, fmaps in enumerate(current_set_of_feature_maps)
                    if i in style_feature_maps_indices_names[0]
                ]
                for gram_gt, gram_hat in zip(target_style_representation,
                                             current_style_representation):
                    loss += (1 / len(target_style_representation)
                             ) * torch.nn.MSELoss(reduction='sum')(gram_gt[0],
                                                                   gram_hat[0])
            loss.backward()
            with torch.no_grad():
                print(
                    f'Iteration: {cnt}, current {"content" if should_reconstruct_content else "style"} loss={loss.item()}'
                )
                utils.save_and_maybe_display(
                    optimizing_img,
                    dump_path,
                    config,
                    cnt,
                    num_of_iterations[config['optimizer']],
                    should_display=False)
                cnt += 1
            return loss

        optimizer = torch.optim.LBFGS(
            (optimizing_img, ),
            max_iter=num_of_iterations[config['optimizer']],
            line_search_fn='strong_wolfe')
        optimizer.step(closure)

    return dump_path
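# Example usage of reconstruct_image_from_representation - a minimal sketch with
# illustrative values; the keys are the ones the function above reads (utils'
# saving helper may read additional keys, e.g. a saving frequency).
def example_reconstruction_run():
    reconstruction_config = {
        'should_reconstruct_content': True,    # False reconstructs the style image instead
        'should_visualize_representation': False,
        'content_images_dir': os.path.join('data', 'content-images'),
        'style_images_dir': os.path.join('data', 'style-images'),
        'output_img_dir': os.path.join('data', 'output-images'),
        'content_img_name': 'lion.jpg',
        'style_img_name': 'starry_night.jpg',
        'height': 500,
        'model': 'vgg19',
        'optimizer': 'lbfgs',                  # 'lbfgs' | 'adam'
        'img_format': (4, '.jpg'),             # (zero-padding width, file extension)
    }
    return reconstruct_image_from_representation(reconstruction_config)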
def train(training_config):
    writer = SummaryWriter()  # (tensorboard) writer will output to ./runs/ directory by default
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # prepare data loader
    train_loader = utils.get_training_data_loader(training_config)

    # prepare neural networks
    transformer_net = TransformerNet().train().to(device)
    perceptual_loss_net = PerceptualLossNet(requires_grad=False).to(device)
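    # the perceptual net (VGG16) is frozen (requires_grad=False) - it only extracts
    # features; gradients flow through it into transformer_net, whose weights are
    # the only ones being optimized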

    optimizer = Adam(transformer_net.parameters())

    # Calculate style image's Gram matrices (style representation)
    # Built over feature maps as produced by the perceptual net - VGG16
    style_img_path = os.path.join(training_config['style_images_path'],
                                  training_config['style_img_name'])
    style_img = utils.prepare_img(style_img_path,
                                  target_shape=None,
                                  device=device,
                                  batch_size=training_config['batch_size'])
    style_img_set_of_feature_maps = perceptual_loss_net(style_img)
    target_style_representation = [
        utils.gram_matrix(x) for x in style_img_set_of_feature_maps
    ]
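    # the style image is fixed for the whole run, so its target Gram matrices are
    # computed only once, here, outside of the training loop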

    utils.print_header(training_config)
    # Tracking loss metrics - NST is ill-posed, so we can only track the loss and the visual appearance of the stylized images
    acc_content_loss, acc_style_loss, acc_tv_loss = [0., 0., 0.]
    ts = time.time()
    for epoch in range(training_config['num_of_epochs']):
        for batch_id, (content_batch, _) in enumerate(train_loader):
            # step1: Feed content batch through transformer net
            content_batch = content_batch.to(device)
            stylized_batch = transformer_net(content_batch)

            # step2: Feed content and stylized batch through perceptual net (VGG16)
            content_batch_set_of_feature_maps = perceptual_loss_net(
                content_batch)
            stylized_batch_set_of_feature_maps = perceptual_loss_net(
                stylized_batch)

            # step3: Calculate content representations and content loss
            target_content_representation = content_batch_set_of_feature_maps.relu2_2
            current_content_representation = stylized_batch_set_of_feature_maps.relu2_2
            content_loss = training_config['content_weight'] * torch.nn.MSELoss(
                reduction='mean')(target_content_representation,
                                  current_content_representation)

            # step4: Calculate style representation and style loss
            style_loss = 0.0
            current_style_representation = [
                utils.gram_matrix(x)
                for x in stylized_batch_set_of_feature_maps
            ]
            for gram_gt, gram_hat in zip(target_style_representation,
                                         current_style_representation):
                style_loss += torch.nn.MSELoss(reduction='mean')(gram_gt,
                                                                 gram_hat)
            style_loss /= len(target_style_representation)
            style_loss *= training_config['style_weight']

            # step5: Calculate total variation loss - enforces image smoothness
            tv_loss = training_config['tv_weight'] * utils.total_variation(
                stylized_batch)

            # step6: Combine losses and do a backprop
            total_loss = content_loss + style_loss + tv_loss
            total_loss.backward()
            optimizer.step()
            optimizer.zero_grad()  # clear gradients for the next batch

            #
            # Logging and checkpoint creation
            #
            acc_content_loss += content_loss.item()
            acc_style_loss += style_loss.item()
            acc_tv_loss += tv_loss.item()

            if training_config['enable_tensorboard']:
                # log scalars
                writer.add_scalar('Loss/content-loss', content_loss.item(),
                                  len(train_loader) * epoch + batch_id + 1)
                writer.add_scalar('Loss/style-loss', style_loss.item(),
                                  len(train_loader) * epoch + batch_id + 1)
                writer.add_scalar('Loss/tv-loss', tv_loss.item(),
                                  len(train_loader) * epoch + batch_id + 1)
                writer.add_scalars(
                    'Statistics/min-max-mean-median', {
                        'min': torch.min(stylized_batch),
                        'max': torch.max(stylized_batch),
                        'mean': torch.mean(stylized_batch),
                        'median': torch.median(stylized_batch)
                    },
                    len(train_loader) * epoch + batch_id + 1)
                # log stylized image
                if batch_id % training_config['image_log_freq'] == 0:
                    stylized = utils.post_process_image(
                        stylized_batch[0].detach().to('cpu').numpy())
                    stylized = np.moveaxis(
                        stylized, 2, 0)  # writer expects channel first image
                    writer.add_image('stylized_img', stylized,
                                     len(train_loader) * epoch + batch_id + 1)

            if (training_config['console_log_freq'] is not None
                    and batch_id % training_config['console_log_freq'] == 0):
                print(
                    f'time elapsed={(time.time() - ts) / 60:.2f}[min]|epoch={epoch + 1}|batch=[{batch_id + 1}/{len(train_loader)}]|c-loss={acc_content_loss / training_config["console_log_freq"]}|s-loss={acc_style_loss / training_config["console_log_freq"]}|tv-loss={acc_tv_loss / training_config["console_log_freq"]}|total loss={(acc_content_loss + acc_style_loss + acc_tv_loss) / training_config["console_log_freq"]}'
                )
                acc_content_loss, acc_style_loss, acc_tv_loss = [0., 0., 0.]

            if (training_config['checkpoint_freq'] is not None
                    and (batch_id + 1) % training_config['checkpoint_freq'] == 0):
                training_state = utils.get_training_metadata(training_config)
                training_state["state_dict"] = transformer_net.state_dict()
                training_state["optimizer_state"] = optimizer.state_dict()
                ckpt_model_name = f"ckpt_style_{training_config['style_img_name'].split('.')[0]}_cw_{str(training_config['content_weight'])}_sw_{str(training_config['style_weight'])}_tw_{str(training_config['tv_weight'])}_epoch_{epoch}_batch_{batch_id}.pth"
                torch.save(
                    training_state,
                    os.path.join(training_config['checkpoints_path'],
                                 ckpt_model_name))

    #
    # Save model with additional metadata - like which commit was used to train the model, style/content weights, etc.
    #
    training_state = utils.get_training_metadata(training_config)
    training_state["state_dict"] = transformer_net.state_dict()
    training_state["optimizer_state"] = optimizer.state_dict()
    model_name = f"style_{training_config['style_img_name'].split('.')[0]}_datapoints_{training_state['num_of_datapoints']}_cw_{str(training_config['content_weight'])}_sw_{str(training_config['style_weight'])}_tw_{str(training_config['tv_weight'])}.pth"
    torch.save(
        training_state,
        os.path.join(training_config['model_binaries_path'], model_name))
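# Example usage of train - a minimal sketch. The keys are the ones the function above
# reads; paths, weights and frequencies are illustrative assumptions, and
# utils.get_training_data_loader presumably needs a dataset-path key as well (not shown).
def example_training_run():
    training_config = {
        'style_images_path': os.path.join('data', 'style-images'),
        'style_img_name': 'mosaic.jpg',
        'model_binaries_path': os.path.join('models', 'binaries'),
        'checkpoints_path': os.path.join('models', 'checkpoints'),
        'batch_size': 4,
        'num_of_epochs': 2,
        'content_weight': 1e0,
        'style_weight': 4e5,
        'tv_weight': 0.,
        'enable_tensorboard': True,
        'image_log_freq': 100,     # log a stylized image every 100 batches
        'console_log_freq': 50,    # print losses every 50 batches
        'checkpoint_freq': 2000,   # save a checkpoint every 2000 batches
    }
    train(training_config)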