Example #1
        def closure():
            nonlocal cnt
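            # guard pattern from the PyTorch optimizer docs' L-BFGS example:
            # the closure may be re-evaluated with gradients disabled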
            if torch.is_grad_enabled():
                optimizer.zero_grad()
            total_loss, content_loss, style_loss, tv_loss, temporal_loss = build_loss(
                neural_net, optimizing_img, target_representations,
                content_feature_maps_index_name[0],
                style_feature_maps_indices_names[0], config, previous_img)
            if total_loss.requires_grad:
                total_loss.backward()
            with torch.no_grad():
                # temporal_loss is a plain 0 until a previous frame exists and a
                # tensor afterwards, hence the type check before calling .item()
                temporal = temporal_loss.item() if torch.is_tensor(
                    temporal_loss) else temporal_loss
                print(
                    f'L-BFGS | iteration: {cnt:03}, total loss={total_loss.item():12.4f}, content_loss={config["content_weight"] * content_loss.item():12.4f}, style loss={config["style_weight"] * style_loss.item():12.4f}, tv loss={config["tv_weight"] * tv_loss.item():12.4f}, temporal loss={config["temporal_weight"] * temporal:12.4f}'
                )
                utils.save_and_maybe_display(
                    optimizing_img,
                    dump_path,
                    config,
                    cnt,
                    num_of_iterations[config['optimizer']],
                    should_display=False,
                    out_img_name=out_img_name)

            cnt += 1
            return total_loss
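
# `build_loss` is referenced above but not shown in this listing. Below is a
# minimal, self-contained sketch consistent with the call site in Example #1
# (seven arguments, five returned terms). The content/style/tv terms follow the
# standard neural-style-transfer formulation; the temporal term is a pure
# assumption (plain MSE against the previous stylized frame) and the real
# project may warp frames or weight pixels differently.
import torch


def gram_matrix(x, should_normalize=True):
    # (B, C, H, W) feature maps -> (B, C, C) channel-correlation (Gram) matrix
    b, ch, h, w = x.size()
    features = x.view(b, ch, h * w)
    gram = features.bmm(features.transpose(1, 2))
    if should_normalize:
        gram /= ch * h * w  # keep magnitudes comparable across layers
    return gram


def total_variation(img):
    # sum of absolute differences between horizontally/vertically adjacent pixels
    return (torch.sum(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:])) +
            torch.sum(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :])))


def build_loss(neural_net, optimizing_img, target_representations,
               content_feature_maps_index, style_feature_maps_indices,
               config, previous_img=None):
    target_content_representation, target_style_representation = target_representations
    current_set_of_feature_maps = neural_net(optimizing_img)

    # content term: MSE between feature maps at the chosen content layer
    current_content_representation = current_set_of_feature_maps[
        content_feature_maps_index].squeeze(axis=0)
    content_loss = torch.nn.MSELoss(reduction='mean')(
        target_content_representation, current_content_representation)

    # style term: MSE between Gram matrices, averaged over the style layers
    style_loss = 0.0
    current_style_representation = [
        gram_matrix(fmaps)
        for i, fmaps in enumerate(current_set_of_feature_maps)
        if i in style_feature_maps_indices
    ]
    for gram_gt, gram_hat in zip(target_style_representation,
                                 current_style_representation):
        style_loss += torch.nn.MSELoss(reduction='sum')(gram_gt[0], gram_hat[0])
    style_loss /= len(target_style_representation)

    # total variation term: penalizes high-frequency noise in the image
    tv_loss = total_variation(optimizing_img)

    # temporal term (assumption): stays a plain float 0 when there is no
    # previous frame, which is why the closure above type-checks before .item()
    temporal_loss = 0.0
    if previous_img is not None:
        temporal_loss = torch.nn.MSELoss(reduction='mean')(optimizing_img,
                                                           previous_img)

    total_loss = config['content_weight'] * content_loss + \
        config['style_weight'] * style_loss + \
        config['tv_weight'] * tv_loss + \
        config['temporal_weight'] * temporal_loss
    return total_loss, content_loss, style_loss, tv_loss, temporal_loss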

# Imports needed by the code below (`utils` stands for the project's helper
# module - adjust the import to the actual project layout; build_loss and
# make_tuning_step come from elsewhere in the same project).
import os

import numpy as np
import torch
from torch.autograd import Variable
from torch.optim import Adam, LBFGS
import matplotlib.pyplot as plt

import utils


def neural_style_transfer(config):
    content_img_path = os.path.join(config['content_images_dir'],
                                    config['content_img_name'])
    style_img_path = os.path.join(config['style_images_dir'],
                                  config['style_img_name'])

    out_dir_name = 'combined_' + os.path.split(content_img_path)[1].split(
        '.')[0] + '_' + os.path.split(style_img_path)[1].split('.')[0]
    dump_path = os.path.join(config['output_img_dir'], out_dir_name)
    os.makedirs(dump_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    content_img = utils.prepare_img(content_img_path, config['height'], device)
    style_img = utils.prepare_img(style_img_path, config['height'], device)

    if config['init_method'] == 'random':
        # white_noise_img = np.random.uniform(-90., 90., content_img.shape).astype(np.float32)
        gaussian_noise_img = np.random.normal(loc=0,
                                              scale=90.,
                                              size=content_img.shape).astype(
                                                  np.float32)
        init_img = torch.from_numpy(gaussian_noise_img).float().to(device)
    elif config['init_method'] == 'content':
        init_img = content_img
    else:
        # the init image must match the content image's dimensions - this is a hard
        # constraint: content and init feature maps need to be the same size
        style_img_resized = utils.prepare_img(
            style_img_path, np.asarray(content_img.shape[2:]), device)
        init_img = style_img_resized

    # we are tuning optimizing_img's pixels! (that's why requires_grad=True)
    # (Variable is a legacy wrapper in modern PyTorch; init_img.clone().requires_grad_(True) is the current idiom)
    optimizing_img = Variable(init_img, requires_grad=True)

    neural_net, content_feature_maps_index_name, style_feature_maps_indices_names = utils.prepare_model(
        config['model'], device)
    print(f'Using {config["model"]} in the optimization procedure.')

    content_img_set_of_feature_maps = neural_net(content_img)
    style_img_set_of_feature_maps = neural_net(style_img)

    target_content_representation = content_img_set_of_feature_maps[
        content_feature_maps_index_name[0]].squeeze(axis=0)
    target_style_representation = [
        utils.gram_matrix(x)
        for cnt, x in enumerate(style_img_set_of_feature_maps)
        if cnt in style_feature_maps_indices_names[0]
    ]
    target_representations = [
        target_content_representation, target_style_representation
    ]

    # magic numbers are generally a big no-no - some are left in this code by design to avoid clutter
    num_of_iterations = {
        "lbfgs": 1000,
        "adam": 3000,
    }

    #
    # Start of optimization procedure
    #
    if config['optimizer'] == 'adam':
        optimizer = Adam((optimizing_img, ), lr=1e1)
        tuning_step = make_tuning_step(neural_net, optimizer,
                                       target_representations,
                                       content_feature_maps_index_name[0],
                                       style_feature_maps_indices_names[0],
                                       config)
        for cnt in range(num_of_iterations[config['optimizer']]):
            total_loss, content_loss, style_loss, tv_loss = tuning_step(
                optimizing_img)
            with torch.no_grad():
                print(
                    f'Adam | iteration: {cnt:03}, total loss={total_loss.item():12.4f}, content_loss={config["content_weight"] * content_loss.item():12.4f}, style loss={config["style_weight"] * style_loss.item():12.4f}, tv loss={config["tv_weight"] * tv_loss.item():12.4f}'
                )
                utils.save_and_maybe_display(
                    optimizing_img,
                    dump_path,
                    config,
                    cnt,
                    num_of_iterations[config['optimizer']],
                    should_display=False)
    elif config['optimizer'] == 'lbfgs':
        # line_search_fn does not seem to have a significant impact on the result
        optimizer = LBFGS((optimizing_img, ),
                          max_iter=num_of_iterations['lbfgs'],
                          line_search_fn='strong_wolfe')
        cnt = 0

        def closure():
            nonlocal cnt
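            # same guard pattern as in the PyTorch docs' L-BFGS example: the
            # closure may be re-evaluated with gradients disabled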
            if torch.is_grad_enabled():
                optimizer.zero_grad()
            total_loss, content_loss, style_loss, tv_loss = build_loss(
                neural_net, optimizing_img, target_representations,
                content_feature_maps_index_name[0],
                style_feature_maps_indices_names[0], config)
            if total_loss.requires_grad:
                total_loss.backward()
            with torch.no_grad():
                print(
                    f'L-BFGS | iteration: {cnt:03}, total loss={total_loss.item():12.4f}, content_loss={config["content_weight"] * content_loss.item():12.4f}, style loss={config["style_weight"] * style_loss.item():12.4f}, tv loss={config["tv_weight"] * tv_loss.item():12.4f}'
                )
                utils.save_and_maybe_display(
                    optimizing_img,
                    dump_path,
                    config,
                    cnt,
                    num_of_iterations[config['optimizer']],
                    should_display=False)

            cnt += 1
            return total_loss

        optimizer.step(closure)

    return dump_path
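
# `make_tuning_step` is called in the Adam branch above but not defined in this
# listing. A minimal sketch consistent with that call site and with the four
# values unpacked from it - note this assumes the project's build_loss returns
# four terms (no temporal loss, unlike the fork in Example #1), and that the
# reconstruction function below calls a different variant that takes a
# should_reconstruct_content flag instead of a config.
def make_tuning_step(neural_net, optimizer, target_representations,
                     content_feature_maps_index, style_feature_maps_indices,
                     config):
    # returns a function that performs one full optimizer update on the image
    def tuning_step(optimizing_img):
        total_loss, content_loss, style_loss, tv_loss = build_loss(
            neural_net, optimizing_img, target_representations,
            content_feature_maps_index, style_feature_maps_indices, config)
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        return total_loss, content_loss, style_loss, tv_loss
    return tuning_step
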
def reconstruct_image_from_representation(config):
    should_reconstruct_content = config['should_reconstruct_content']
    should_visualize_representation = config['should_visualize_representation']
    dump_path = os.path.join(config['output_img_dir'],
                             ('c' if should_reconstruct_content else 's') +
                             '_reconstruction_' + config['optimizer'])
    dump_path = os.path.join(
        dump_path, config['content_img_name'].split('.')[0] if
        should_reconstruct_content else config['style_img_name'].split('.')[0])
    os.makedirs(dump_path, exist_ok=True)

    content_img_path = os.path.join(config['content_images_dir'],
                                    config['content_img_name'])
    style_img_path = os.path.join(config['style_images_dir'],
                                  config['style_img_name'])
    img_path = content_img_path if should_reconstruct_content else style_img_path

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    img = utils.prepare_img(img_path, config['height'], device)

    # gaussian_noise_img = np.random.normal(loc=0, scale=90., size=img.shape).astype(np.float32)
    white_noise_img = np.random.uniform(-90., 90.,
                                        img.shape).astype(np.float32)
    init_img = torch.from_numpy(white_noise_img).float().to(device)
    optimizing_img = Variable(init_img, requires_grad=True)

    # indices pick relevant feature maps (say conv4_1, relu1_1, etc.)
    neural_net, content_feature_maps_index_name, style_feature_maps_indices_names = utils.prepare_model(
        config['model'], device)

    # we don't want to expose everything that's not crucial, so some things are hardcoded
    num_of_iterations = {'adam': 3000, 'lbfgs': 350}

    set_of_feature_maps = neural_net(img)

    #
    # Visualize feature maps and Gram matrices (depending on whether you're reconstructing the content or style image)
    #
    if should_reconstruct_content:
        target_content_representation = set_of_feature_maps[
            content_feature_maps_index_name[0]].squeeze(axis=0)
        if should_visualize_representation:
            num_of_feature_maps = target_content_representation.size()[0]
            print(f'Number of feature maps: {num_of_feature_maps}')
            for i in range(num_of_feature_maps):
                feature_map = target_content_representation[i].to(
                    'cpu').numpy()
                feature_map = np.uint8(utils.get_uint8_range(feature_map))
                plt.imshow(feature_map)
                plt.title(
                    f'Feature map {i+1}/{num_of_feature_maps} from layer {content_feature_maps_index_name[1]} (model={config["model"]}) for {config["content_img_name"]} image.'
                )
                plt.show()
                filename = f'fm_{config["model"]}_{content_feature_maps_index_name[1]}_{str(i).zfill(config["img_format"][0])}{config["img_format"][1]}'
                utils.save_image(feature_map,
                                 os.path.join(dump_path, filename))
    else:
        target_style_representation = [
            utils.gram_matrix(fmaps)
            for i, fmaps in enumerate(set_of_feature_maps)
            if i in style_feature_maps_indices_names[0]
        ]
        if should_visualize_representation:
            num_of_gram_matrices = len(target_style_representation)
            print(f'Number of Gram matrices: {num_of_gram_matrices}')
            for i in range(num_of_gram_matrices):
                Gram_matrix = target_style_representation[i].squeeze(
                    axis=0).to('cpu').numpy()
                Gram_matrix = np.uint8(utils.get_uint8_range(Gram_matrix))
                plt.imshow(Gram_matrix)
                plt.title(
                    f'Gram matrix from layer {style_feature_maps_indices_names[1][i]} (model={config["model"]}) for {config["style_img_name"]} image.'
                )
                plt.show()
                filename = f'gram_{config["model"]}_{style_feature_maps_indices_names[1][i]}_{str(i).zfill(config["img_format"][0])}{config["img_format"][1]}'
                utils.save_image(Gram_matrix,
                                 os.path.join(dump_path, filename))

    #
    # Start of optimization procedure
    #
    if config['optimizer'] == 'adam':
        optimizer = Adam((optimizing_img, ))
        target_representation = target_content_representation if should_reconstruct_content else target_style_representation
        tuning_step = make_tuning_step(neural_net, optimizer,
                                       target_representation,
                                       should_reconstruct_content,
                                       content_feature_maps_index_name[0],
                                       style_feature_maps_indices_names[0])
        for it in range(num_of_iterations[config['optimizer']]):
            loss, _ = tuning_step(optimizing_img)
            with torch.no_grad():
                print(
                    f'Iteration: {it}, current {"content" if should_reconstruct_content else "style"} loss={loss:10.8f}'
                )
                utils.save_and_maybe_display(
                    optimizing_img,
                    dump_path,
                    config,
                    it,
                    num_of_iterations[config['optimizer']],
                    should_display=False)
    elif config['optimizer'] == 'lbfgs':
        cnt = 0

        # closure is the function the L-BFGS optimizer requires: it re-evaluates
        # the loss and gradients, possibly several times per optimizer step
        def closure():
            nonlocal cnt
            optimizer.zero_grad()
            loss = 0.0
            if should_reconstruct_content:
                loss = torch.nn.MSELoss(reduction='mean')(
                    target_content_representation, neural_net(optimizing_img)[
                        content_feature_maps_index_name[0]].squeeze(axis=0))
            else:
                current_set_of_feature_maps = neural_net(optimizing_img)
                current_style_representation = [
                    utils.gram_matrix(fmaps)
                    for i, fmaps in enumerate(current_set_of_feature_maps)
                    if i in style_feature_maps_indices_names[0]
                ]
                for gram_gt, gram_hat in zip(target_style_representation,
                                             current_style_representation):
                    loss += (1 / len(target_style_representation)
                             ) * torch.nn.MSELoss(reduction='sum')(gram_gt[0],
                                                                   gram_hat[0])
            loss.backward()
            with torch.no_grad():
                print(
                    f'Iteration: {cnt}, current {"content" if should_reconstruct_content else "style"} loss={loss.item()}'
                )
                utils.save_and_maybe_display(
                    optimizing_img,
                    dump_path,
                    config,
                    cnt,
                    num_of_iterations[config['optimizer']],
                    should_display=False)
                cnt += 1
            return loss

        # the closure above looks up `optimizer` at call time, so creating it
        # here, after the closure definition, is fine
        optimizer = torch.optim.LBFGS(
            (optimizing_img, ),
            max_iter=num_of_iterations[config['optimizer']],
            line_search_fn='strong_wolfe')
        optimizer.step(closure)

    return dump_path
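
# Illustrative driver for the two entry points above. Only the config key names
# are taken from the function bodies; every value here is an assumption chosen
# for demonstration, not necessarily the project's defaults.
if __name__ == '__main__':
    config = {
        'content_images_dir': 'data/content-images',
        'style_images_dir': 'data/style-images',
        'output_img_dir': 'data/output-images',
        'content_img_name': 'lion.jpg',   # hypothetical file names
        'style_img_name': 'candy.jpg',
        'height': 400,
        'init_method': 'content',         # 'random' | 'content' | anything else -> style init
        'model': 'vgg19',
        'optimizer': 'lbfgs',             # 'lbfgs' | 'adam'
        'content_weight': 1e5,            # illustrative loss weights
        'style_weight': 3e4,
        'tv_weight': 1e0,
        'img_format': (4, '.jpg'),        # (zero-padding width, extension)
        # used only by reconstruct_image_from_representation:
        'should_reconstruct_content': True,
        'should_visualize_representation': False,
    }
    results_path = neural_style_transfer(config)
    print(f'Results were saved to: {results_path}')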