def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_img, content_img,
                               content_layers=content_layers_default,
                               style_layers=style_layers_default):
    cnn = copy.deepcopy(cnn)

    # normalization module
    normalization = Normalization(normalization_mean, normalization_std).to(device)

    # just in order to have iterable access to the content/style
    # losses
    content_losses = []
    style_losses = []

    # we assume cnn is an nn.Sequential, so we make a new nn.Sequential
    # to hold the modules that are supposed to be activated sequentially
    model = nn.Sequential(normalization)

    i = 0  # increment every time we see a conv
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            # The in-place version doesn't play very nicely with the ContentLoss
            # and StyleLoss we insert below. So we replace with out-of-place
            # ones here.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in content_layers:
            # add content loss:
            target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module("content_loss_{}".format(i), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            # add style loss:
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(i), style_loss)
            style_losses.append(style_loss)

    # now we trim off the layers after the last content and style losses
    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
            break

    model = model[:(i + 1)]

    return model, style_losses, content_losses
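
A minimal usage sketch for the function above, assuming the setup of the PyTorch neural style transfer tutorial that this example follows (style_img and content_img are assumed preloaded image tensors):

import torch
import torchvision.models as models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained VGG19 feature extractor and ImageNet statistics, as in the tutorial.
cnn = models.vgg19(pretrained=True).features.to(device).eval()
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

# Default layer choices from the tutorial.
content_layers_default = ['conv_4']
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']

# style_img / content_img: assumed (1, 3, H, W) tensors with values in [0, 1].
model, style_losses, content_losses = get_style_model_and_losses(
    cnn, cnn_normalization_mean, cnn_normalization_std, style_img, content_img)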
Example #2
def get_style_model_and_losses(cnn,
                               normalization_mean,
                               normalization_std,
                               style_img,
                               content_img,
                               content_layers=content_layers_default,
                               style_layers=style_layers_default):
    cnn = copy.deepcopy(cnn)

    normalization = Normalization(normalization_mean, normalization_std)

    content_losses = []
    style_losses = []

    model = nn.Sequential(normalization)

    i = 0
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(
                layer.__class__.__name__))

        model.add_module(name, layer)

        if name in content_layers:
            target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module('content_loss_{}'.format(i), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module('style_loss_{}'.format(i), style_loss)
            style_losses.append(style_loss)

    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], ContentLoss) or isinstance(
                model[i], StyleLoss):
            break

    model = model[:(i + 1)]

    return model, style_losses, content_losses
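
A sketch of the LBFGS optimization loop that typically consumes the returned model and losses, assuming the ContentLoss/StyleLoss modules record their loss during the forward pass, as in the tutorial these two examples follow (the style weight and step count are assumptions):

import torch.optim as optim

input_img = content_img.clone()
optimizer = optim.LBFGS([input_img.requires_grad_()])

def closure():
    input_img.data.clamp_(0, 1)
    optimizer.zero_grad()
    model(input_img)  # each loss module records its loss during the forward pass
    style_score = sum(sl.loss for sl in style_losses)
    content_score = sum(cl.loss for cl in content_losses)
    loss = 1_000_000 * style_score + content_score  # assumed weights
    loss.backward()
    return loss

for _ in range(300):
    optimizer.step(closure)
input_img.data.clamp_(0, 1)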
Example #3
def run_style_transfer(final_image, content_image, style_image, epochs=1):
    """ TODO """
    LOGGER.info(f"Loading the model and the losses")
    model = VGG().to(Utils.DEVICE)
    optimizer = optim.LBFGS([final_image.requires_grad_()])
    style_fns = {
        layer: StyleLoss(out).to(Utils.DEVICE)
        for layer, out in model(style_image, list(
            Utils.STYLE_LAYERS.keys())).items()
    }
    content_fns = {
        layer: ContentLoss(out).to(Utils.DEVICE)
        for layer, out in model(content_image, list(
            Utils.CONTENT_LAYERS.keys())).items()
    }
    LOGGER.debug(f"Loaded the model and the losses")

    i = 0
    progress_bar = tqdm.tqdm(total=epochs, desc='CL ? / SL ?', leave=True)

    def closure():
        final_image.data.clamp_(
            0, 1
        )  # image data may be updated with values outside 0 and 1 (boundaries)
        optimizer.zero_grad()
        outs = model(final_image, [*Utils.STYLE_LAYERS, *Utils.CONTENT_LAYERS])

        style_loss = sum([
            weight * style_fns[layer](outs[layer])
            for layer, weight in Utils.STYLE_LAYERS.items()
        ])
        content_loss = sum([
            weight * content_fns[layer](outs[layer])
            for layer, weight in Utils.CONTENT_LAYERS.items()
        ])
        loss = style_loss + content_loss
        loss.backward()
        progress_bar.set_description(
            f"CL {content_loss.item():.3f} / SL {style_loss.item():.4f}")
        progress_bar.refresh()
        return loss

    LOGGER.info('Starting the optimization loop')
    while i < epochs:
        optimizer.step(closure)
        progress_bar.update(1)
        progress_bar.refresh()
        i += 1
    progress_bar.close()
    LOGGER.debug('Ended the optimization loop')
    final_image.data.clamp_(0, 1)  # last out-of-the-loop boundary correction
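
This example assumes a VGG wrapper whose forward pass takes an image plus a list of layer names and returns a dict mapping each name to its activation. A minimal sketch of such a wrapper (the class body and the layer-index mapping are assumptions, not the original project's code):

import torch.nn as nn
import torchvision.models as models

class VGG(nn.Module):
    """Sketch: forward(image, layer_names) -> {layer_name: activation}."""

    # Assumed mapping from readable names to VGG19 feature indices
    # (first conv of each block, the classic style-transfer choice).
    LAYER_INDICES = {'conv_1': 0, 'conv_2': 5, 'conv_3': 10,
                     'conv_4': 19, 'conv_5': 28}

    def __init__(self):
        super().__init__()
        self.features = models.vgg19(pretrained=True).features.eval()
        for p in self.features.parameters():
            p.requires_grad_(False)

    def forward(self, x, layer_names):
        wanted = {self.LAYER_INDICES[n]: n for n in layer_names}
        outs = {}
        for idx, module in enumerate(self.features):
            x = module(x)
            if idx in wanted:
                outs[wanted[idx]] = x
        return outs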
Example #4
    def get_style_model_and_losses(self, model_dict, style_img, content_img):
        """
        get the losses.
        """
        self.cnn = copy.deepcopy(self.cnn)
        self.cnn = self.cnn.to(self.device)

        c_idx = 0
        r_idx = 0
        p_idx = 0
        # note: this variant does not add an input normalization module

        # list of losses in layers:
        content_losses = []
        style_losses = []
        tv_losses = []

        model = nn.Sequential()
        i = 0
        tv_mod = TVLoss(1e-3)
        model.add_module(str(len(model)), tv_mod)
        tv_losses.append(tv_mod)

        for layer in self.cnn.children():

            if isinstance(layer, nn.Conv2d):
                i += 1
                name = model_dict['conv'][c_idx]
                c_idx += 1

            elif isinstance(layer, nn.ReLU):
                name = model_dict['relu'][r_idx]
                r_idx += 1
                layer = nn.ReLU(inplace=True)

            elif isinstance(layer, nn.MaxPool2d):
                name = model_dict['pool'][p_idx]
                # layer = nn.AvgPool2d(kernel_size=2, stride=2)
                p_idx += 1

            elif isinstance(layer, nn.BatchNorm2d):
                name = f'bn_{i}'

            else:
                layer_name = layer.__class__.__name__
                raise RuntimeError(f'Unrecognized Layer: {layer_name}')

            model.add_module(name, layer)

            if name in self.content_layers_default:
                content_loss = ContentLoss()
                model.add_module(f'content_loss_{i}', content_loss)
                content_losses.append(content_loss)

            if name in self.style_layers_default:
                # get feature maps from style
                style_loss = StyleLoss()
                model.add_module(f'style_loss_{i}', style_loss)
                style_losses.append(style_loss)

        # now we trim off the layers after the last content and style losses
        # if there is extra non-needed layers.
        for i in range(len(model) - 1, -1, -1):
            if isinstance(model[i], ContentLoss) or isinstance(
                    model[i], StyleLoss):
                break

        model = model[:(i + 1)]
        # rip through the model, getting the REAL feature maps for style and content.
        for module in style_losses:
            module.mode = "capture"

        model(style_img)

        for module in style_losses:
            module.mode = "none"

        for module in content_losses:
            module.mode = "capture"

        model(content_img)

        for module in style_losses:
            module.mode = "loss"

        for module in content_losses:
            module.mode = "loss"

        return model, style_losses, content_losses, tv_losses
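
The loss modules used here are not shown; a minimal sketch of a ContentLoss compatible with the capture/none/loss mode protocol toggled above (an assumption inferred from how the modes are set):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContentLoss(nn.Module):
    """Transparent layer: in 'capture' mode it stores its input as the
    target; in 'loss' mode it computes the MSE against that target;
    in every mode it passes its input through unchanged."""

    def __init__(self):
        super().__init__()
        self.mode = "none"
        self.target = None
        self.loss = torch.zeros(1)

    def forward(self, x):
        if self.mode == "capture":
            self.target = x.detach()
        elif self.mode == "loss":
            self.loss = F.mse_loss(x, self.target)
        return x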
Example #5
def add_modules(cnn, mean, std, img, layers, device, replace=False):
    """
    Modifiy the model to integrate new modules.

    Inputs
    ------
    - cnn     : a convolutional neurak network (nn.Module)
    - mean    : the mean normalization vector
    - std     : the standard deviation normalization vector
    - img     : a dictionary with content and style image
    - layers  : a dictionary with lists of content and style layers to add
    - replace : a flag to determine if we have to replace MaxPool by AvgPool
    """

    # Copy the CNN
    cnn_copy = copy.deepcopy(cnn)

    # Create the normalization module
    norm_module = Normalization(mean, std).to(device)

    # Initializes losses lists
    content_losses = []
    style_losses = []

    # Create the new model with the normalization module
    # (in order to normalize input images)
    model = nn.Sequential(norm_module)

    # Iterate over each layer of the CNN
    i = 0

    for layer in cnn_copy.children():
        if isinstance(layer, nn.Conv2d):
            i += 1

            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            layer = nn.ReLU(inplace=False)

            name = 'relu_{}'.format(i)
        elif isinstance(layer, nn.MaxPool2d):
            # We replace 'MaxPool' layer by 'AvgPool' layer, as suggested by the author
            if replace:
                layer = nn.AvgPool2d(2, 2)

            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(
                layer.__class__.__name__))

        # Add the layer to our model
        model.add_module(name, layer)

        # Add the content layers at the right place
        if name in layers['content']:
            reference = model(img['content']).detach()
            content_loss = ContentLoss(reference)
            model.add_module('content_loss_{}'.format(i), content_loss)
            content_losses.append(content_loss)

        # Add the style layers at the right place
        if name in layers['style']:
            reference = model(img['style']).detach()
            style_loss = StyleLoss(reference)
            model.add_module('style_loss_{}'.format(i), style_loss)
            style_losses.append(style_loss)

    # Remove the layers after the last content and style ones
    for i in range(len(model) - 1, -1, -1):
        m = model[i]

        if isinstance(m, ContentLoss) or isinstance(m, StyleLoss):
            break

    model = model[:(i + 1)]

    return model, {'style': style_losses, 'content': content_losses}
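
A minimal call sketch for add_modules, assuming cnn, mean, std and the image tensors are prepared as in the first example (the layer lists are the usual tutorial defaults):

img = {'content': content_img, 'style': style_img}
layers = {
    'content': ['conv_4'],
    'style': ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'],
}

model, losses = add_modules(cnn, mean, std, img, layers, device, replace=True)
style_losses, content_losses = losses['style'], losses['content']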
Example #6
def run_style_transfer(image_size,
                       content_image_path,
                       style_image_path,
                       content_layers_weights,
                       style_layers_weights,
                       variation_weight,
                       n_steps,
                       shifting_activation_value,
                       device_name,
                       preserve_colors):
    print('Transfer style to content image')
    print('Number of iterations: %s' % n_steps)
    print('Preserve colors: %s' % preserve_colors)
    print('--------------------------------')
    print('Content image path: %s' % content_image_path)
    print('Style image path: %s' % style_image_path)
    print('--------------------------------')
    print('Content layers: %s' % content_layers_weights.keys())
    print('Content weights: %s' % content_layers_weights.values())
    print('Style layers: %s' % style_layers_weights.keys())
    print('Style weights: %s' % style_layers_weights.values())
    print('Variation weight: %s' % variation_weight)
    print('--------------------------------')
    print('Shifting activation value: %s' % shifting_activation_value)
    print('--------------------------------\n\n')

    device = torch.device("cuda" if (torch.cuda.is_available() and device_name == 'cuda') else "cpu")

    image_handler = ImageHandler(image_size=image_size,
                                 content_image_path=content_image_path,
                                 style_image_path=style_image_path,
                                 device=device,
                                 preserve_colors=preserve_colors)
    content_layer_names = list(content_layers_weights.keys())
    style_layer_names = list(style_layers_weights.keys())
    layer_names = content_layer_names + style_layer_names

    last_layer = get_last_used_conv_layer(layer_names)
    model = transfer_vgg19(last_layer, device)

    print('--------------------------------')
    print('Model:')
    print(model)
    print('--------------------------------')
    content_features = model(image_handler.content_image, content_layer_names)
    content_losses = {layer_name: ContentLoss(weight=weight)
                      for layer_name, weight in content_layers_weights.items()}

    style_features = model(image_handler.style_image, style_layer_names)
    style_losses = {layer_name: StyleLoss(weight=weight,
                                          shifting_activation_value=shifting_activation_value)
                    for layer_name, weight in style_layers_weights.items()}

    variation_loss = VariationLoss(weight=variation_weight)

    combination_image = image_handler.content_image.clone()
    optimizer = optim.LBFGS([combination_image.requires_grad_()])
    run = [0]
    while run[0] <= n_steps:
        def closure():
            # correct the values of updated input image
            combination_image.data.clamp_(0, 1)

            optimizer.zero_grad()
            out = model(combination_image, layer_names)
            variation_score = variation_loss(combination_image)
            content_score = torch.sum(torch.stack([loss(out[layer_name], content_features[layer_name].detach())
                                                   for layer_name, loss in content_losses.items()]))
            style_score = torch.sum(torch.stack([loss(out[layer_name], style_features[layer_name].detach())
                                                 for layer_name, loss in style_losses.items()]))

            loss = style_score + content_score + variation_score
            loss.backward()

            run[0] += 1
            if run[0] % 50 == 0:
                print("run {}:".format(run))
                print('Style Loss : {:4f} Content Loss: {:4f} Variation Loss: {:4f}'.format(
                    style_score.item(), content_score.item(), variation_score.item()))

            return loss
        optimizer.step(closure)

    # a last correction...
    combination_image.data.clamp_(0, 1)

    plt.figure()
    image_handler.imshow(combination_image, title='Output Image')
    plt.show()
    return image_handler.image_unloader(combination_image)
Example #7
def train(image_size, style_image_path, dataset_path, model_path,
          content_layers_weights, style_layers_weights, variation_weight,
          shifting_activation_value, batch_size, learning_rate, epochs,
          device_name):
    print('Train style model')
    print('Number of epochs: %s' % epochs)
    print('Learning rate: %s' % learning_rate)
    print('Batch size: %s' % batch_size)
    print('--------------------------------')
    print('Style image path: %s' % style_image_path)
    print('Dataset path: %s' % dataset_path)
    print('--------------------------------')
    print('Content layers: %s' % content_layers_weights.keys())
    print('Content weights: %s' % content_layers_weights.values())
    print('Style layers: %s' % style_layers_weights.keys())
    print('Style weights: %s' % style_layers_weights.values())
    print('Variation weight: %s' % variation_weight)
    print('--------------------------------')
    print('Shifting activation value: %s' % shifting_activation_value)
    print('--------------------------------\n\n')

    device = torch.device("cuda" if (
        torch.cuda.is_available() and device_name == 'cuda') else "cpu")

    content_layer_names = list(content_layers_weights.keys())
    style_layer_names = list(style_layers_weights.keys())
    layer_names = content_layer_names + style_layer_names

    last_layer = get_last_used_conv_layer(layer_names)
    vgg_model = transfer_vgg19(last_layer, device)

    image_handler = TrainStyleImageHandler(image_size, style_image_path,
                                           dataset_path, batch_size, device)

    transform_model = TransformNetwork().to(device)
    transform_model.train()
    optimizer = Adam(transform_model.parameters(), learning_rate)

    content_losses = {
        layer_name: ContentLoss(weight=weight)
        for layer_name, weight in content_layers_weights.items()
    }
    style_losses = {
        layer_name:
        StyleLoss(weight=weight,
                  shifting_activation_value=shifting_activation_value)
        for layer_name, weight in style_layers_weights.items()
    }
    variation_loss = VariationLoss(weight=variation_weight)

    print('Start training')
    for epoch in range(epochs):
        print('--------------------')
        print('Epoch %s' % epoch)
        print('--------------------')
        epoch_content_score = 0
        epoch_style_score = 0
        epoch_variation_score = 0
        for batch_id, (batch, _) in enumerate(image_handler.train_loader):
            optimizer.zero_grad()
            transform_batch = transform_model(batch.to(device))

            transform_batch_features = vgg_model(transform_batch, layer_names)
            batch_features = vgg_model(batch.to(device), layer_names)
            style_features = vgg_model(image_handler.style_image, layer_names)
            variation_score = variation_loss(transform_batch)
            content_score = torch.sum(
                torch.stack([
                    loss(transform_batch_features[layer_name],
                         batch_features[layer_name])
                    for layer_name, loss in content_losses.items()
                ]))
            style_score = torch.sum(
                torch.stack([
                    loss(transform_batch_features[layer_name],
                         style_features[layer_name])
                    for layer_name, loss in style_losses.items()
                ]))

            loss = style_score + content_score + variation_score
            loss.backward()
            optimizer.step()

            # accumulate plain floats so we don't keep autograd graphs alive
            epoch_content_score += content_score.item()
            epoch_style_score += style_score.item()
            epoch_variation_score += variation_score.item()

            if batch_id % 50 == 0:
                print("Batch {}:".format(batch_id))
                print(
                    'Style Loss: {:.4f} Content Loss: {:.4f} Variation Loss: {:.4f}'
                    .format(style_score.item(), content_score.item(),
                            variation_score.item()))

        print(
            'Epoch summary \n Style Loss: {:.4f} Content Loss: {:.4f} Variation Loss: {:.4f}'
            .format(epoch_style_score, epoch_content_score,
                    epoch_variation_score))

    transform_model.eval().cpu()
    torch.save(transform_model.state_dict(), model_path)
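
Examples 6 and 7 both rely on weighted loss modules called with the current and target features; a sketch of a StyleLoss with a shifted Gram matrix consistent with that call signature (the exact formulation in the original project may differ):

import torch
import torch.nn as nn
import torch.nn.functional as F

def gram_matrix(features, shift=0.0):
    # (B, C, H, W) -> per-sample Gram matrix of shifted activations,
    # normalized by the number of elements per feature map.
    b, c, h, w = features.size()
    flat = features.view(b, c, h * w) + shift
    gram = torch.bmm(flat, flat.transpose(1, 2))
    return gram / (c * h * w)

class StyleLoss(nn.Module):
    def __init__(self, weight=1.0, shifting_activation_value=0.0):
        super().__init__()
        self.weight = weight
        self.shift = shifting_activation_value

    def forward(self, input_features, target_features):
        return self.weight * F.mse_loss(
            gram_matrix(input_features, self.shift),
            gram_matrix(target_features.detach(), self.shift))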
Example #8
def get_style_model_and_losses(
    cnn,
    normalization_mean,
    normalization_std,
    style_img,
    content_img,
    content_layers=content_layers,
    style_layers=style_layers,
):
    cnn = copy.deepcopy(cnn)

    # normalization
    normalization = Normalization(normalization_mean, normalization_std).to(device)

    # losses container
    content_losses = []
    style_losses = []

    model = nn.Sequential(normalization)

    i = 0  # conv tracker
    for layer in cnn.children():

        # resnet accommodation
        nonseq_layer = 1
        if isinstance(layer, nn.Sequential):
            nonseq_layer = 0
            # note: only the first BasicBlock of each stage is unpacked here
            if isinstance(layer[0], torchvision.models.resnet.BasicBlock):
                for l in layer[0].children():
                    if isinstance(l, nn.Conv2d):
                        i += 1
                        name = "conv_{}".format(i)
                    elif isinstance(l, nn.ReLU):
                        name = "relu_{}".format(i)
                        l = nn.ReLU(inplace=False)
                    elif isinstance(l, nn.MaxPool2d):
                        name = "pool_{}".format(i)
                    elif isinstance(l, nn.BatchNorm2d):
                        name = "bn_{}".format(i)
                    else:
                        continue  # skip sublayers we don't track (e.g. downsample)
                    model.add_module(name, l)
        elif isinstance(layer, nn.Conv2d):
            i += 1
            name = "conv_{}".format(i)
        elif isinstance(layer, nn.ReLU):
            name = "relu_{}".format(i)
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = "pool_{}".format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = "bn_{}".format(i)


        if nonseq_layer:
            model.add_module(name, layer)

        if name in content_layers:
            # add content loss:
            target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module("content_loss_{}".format(i), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            # add style loss:
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(i), style_loss)
            style_losses.append(style_loss)


    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
            break

    model = model[: (i + 1)]

    return model, style_losses, content_losses
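
A call sketch for this ResNet-aware variant, assuming a torchvision resnet18 trunk and the same normalization statistics and image tensors as in the first example:

import torch.nn as nn
import torchvision.models as models

resnet = models.resnet18(pretrained=True).eval()
# Keep only the convolutional trunk so the loop above sees conv/bn/relu/pool
# layers and BasicBlock Sequentials (the pooling/classifier head is dropped).
cnn = nn.Sequential(*list(resnet.children())[:-2])

model, style_losses, content_losses = get_style_model_and_losses(
    cnn, normalization_mean, normalization_std, style_img, content_img)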