def run_style_transfer(image_size, content_image_path, style_image_path, content_layers_weights, style_layers_weights, variation_weight, n_steps, shifting_activation_value, device_name, preserve_colors):
    """Run neural style transfer, optimizing a copy of the content image.

    Builds a truncated VGG19 feature extractor, precomputes content/style
    targets, then optimizes the combination image with L-BFGS to minimize the
    weighted sum of content, style, and total-variation losses.

    Args:
        image_size: Side length images are resized to by ImageHandler.
        content_image_path: Path to the content image file.
        style_image_path: Path to the style image file.
        content_layers_weights: Mapping of VGG layer name -> content-loss weight.
        style_layers_weights: Mapping of VGG layer name -> style-loss weight.
        variation_weight: Weight for the total-variation (smoothness) loss.
        n_steps: Number of optimizer closure evaluations to run.
        shifting_activation_value: Activation shift passed to StyleLoss.
        device_name: 'cuda' to request GPU (used only if CUDA is available).
        preserve_colors: Whether ImageHandler keeps the content image's colors.

    Returns:
        The final combination image, converted by ImageHandler.image_unloader.
    """
    print('Transfer style to content image')
    print('Number of iterations: %s' % n_steps)
    print('Preserve colors: %s' % preserve_colors)
    print('--------------------------------')
    print('Content image path: %s' % content_image_path)
    print('Style image path: %s' % style_image_path)
    print('--------------------------------')
    # BUGFIX: the original banner printed style keys under "Content weight"
    # and content values under "Style layers"; each label now reports its
    # own dict's keys/values.
    print('Content layers: %s' % content_layers_weights.keys())
    print('Content weight: %s' % content_layers_weights.values())
    print('Style layers: %s' % style_layers_weights.keys())
    print('Style weight: %s' % style_layers_weights.values())
    print('Variation weight: %s' % variation_weight)
    print('--------------------------------')
    print('Shifting activation value: %s' % shifting_activation_value)
    print('--------------------------------\n\n')

    # Fall back to CPU unless CUDA was both requested and is available.
    device = torch.device("cuda" if (torch.cuda.is_available() and device_name == 'cuda') else "cpu")

    image_handler = ImageHandler(image_size=image_size,
                                 content_image_path=content_image_path,
                                 style_image_path=style_image_path,
                                 device=device,
                                 preserve_colors=preserve_colors)

    content_layer_names = list(content_layers_weights.keys())
    style_layer_names = list(style_layers_weights.keys())
    layer_names = content_layer_names + style_layer_names

    # Truncate VGG19 after the deepest layer we actually need features from.
    last_layer = get_last_used_conv_layer(layer_names)
    model = transfer_vgg19(last_layer, device)
    print('--------------------------------')
    print('Model:')
    print(model)
    print('--------------------------------')

    # Precompute the optimization targets once; they are constant during
    # optimization (and additionally .detach()-ed inside the closure).
    content_features = model(image_handler.content_image, content_layer_names)
    content_losses = {layer_name: ContentLoss(weight=weight)
                      for layer_name, weight in content_layers_weights.items()}
    style_features = model(image_handler.style_image, style_layer_names)
    style_losses = {layer_name: StyleLoss(weight=weight,
                                          shifting_activation_value=shifting_activation_value)
                    for layer_name, weight in style_layers_weights.items()}
    variation_loss = VariationLoss(weight=variation_weight)

    # Start from the content image; the clone is the tensor being optimized.
    combination_image = image_handler.content_image.clone()
    optimizer = optim.LBFGS([combination_image.requires_grad_()])

    # Mutable counter in a list so the nested closure can increment it
    # (pre-`nonlocal` idiom). L-BFGS may call the closure several times per
    # optimizer.step(), so the counter is advanced inside the closure.
    run = [0]
    while run[0] <= n_steps:

        def closure():
            # Keep the image in valid [0, 1] range before each evaluation.
            combination_image.data.clamp_(0, 1)

            optimizer.zero_grad()
            out = model(combination_image, layer_names)

            variation_score = variation_loss(combination_image)
            content_score = torch.sum(torch.stack(
                [loss(out[layer_name], content_features[layer_name].detach())
                 for layer_name, loss in content_losses.items()]))
            style_score = torch.sum(torch.stack(
                [loss(out[layer_name], style_features[layer_name].detach())
                 for layer_name, loss in style_losses.items()]))

            loss = style_score + content_score + variation_score
            loss.backward()

            run[0] += 1
            if run[0] % 50 == 0:
                # BUGFIX: format the counter value, not the list wrapper
                # (previously printed e.g. "run [50]:").
                print("run {}:".format(run[0]))
                print('Style Loss : {:4f} Content Loss: {:4f} Variation Loss: {:4f}'.format(
                    style_score.item(), content_score.item(), variation_score.item()))

            return loss

        optimizer.step(closure)

    # Final clamp: the last optimizer step may have left values outside [0, 1].
    combination_image.data.clamp_(0, 1)

    plt.figure()
    image_handler.imshow(combination_image, title='Output Image')
    plt.show()

    return image_handler.image_unloader(combination_image)