def build_loss(neural_net, optimizing_img, target_representations,
               content_feature_maps_index, style_feature_maps_indices, config):
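    """Compute the total NST loss (content + style + total variation) for the image being optimized.

    `target_representations` holds the precomputed (content, style) targets; the two index
    arguments select which of the net's feature maps feed the content and style terms.
    """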
    target_content_representation = target_representations[0]
    target_style_representation = target_representations[1]

    current_set_of_feature_maps = neural_net(optimizing_img)

    current_content_representation = current_set_of_feature_maps[
        content_feature_maps_index].squeeze(axis=0)
    content_loss = torch.nn.MSELoss(reduction='mean')(
        target_content_representation, current_content_representation)

    style_loss = 0.0
    current_style_representation = [
        utils.gram_matrix(x)
        for cnt, x in enumerate(current_set_of_feature_maps)
        if cnt in style_feature_maps_indices
    ]
    for gram_gt, gram_hat in zip(target_style_representation,
                                 current_style_representation):
        style_loss += torch.nn.MSELoss(reduction='sum')(gram_gt[0],
                                                        gram_hat[0])
    style_loss /= len(target_style_representation)  # average over the style layers so the loss doesn't scale with their number

    tv_loss = utils.total_variation(optimizing_img)  # total variation regularizer - encourages spatial smoothness

    total_loss = config['content_weight'] * content_loss + \
                 config['style_weight'] * style_loss + \
                 config['tv_weight'] * tv_loss

    return total_loss, content_loss, style_loss, tv_loss
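

# Illustrative sketch, not part of the original script: build_loss is meant to be called
# from inside an optimizer closure when optimizing the image directly (optimization-based
# NST). The helper below is a hypothetical example of that wiring; all of its arguments
# are assumed to be prepared exactly as build_loss expects, and the optimizer is assumed
# to own `optimizing_img` (e.g. LBFGS or Adam over the pixels).
def make_tuning_step_sketch(neural_net, optimizer, target_representations,
                            content_feature_maps_index, style_feature_maps_indices,
                            config):
    def tuning_step(optimizing_img):
        def closure():
            optimizer.zero_grad()
            total_loss, _, _, _ = build_loss(
                neural_net, optimizing_img, target_representations,
                content_feature_maps_index, style_feature_maps_indices, config)
            total_loss.backward()  # gradients flow into optimizing_img, not network weights
            return total_loss

        return optimizer.step(closure)  # LBFGS re-evaluates the closure during its line search

    return tuning_step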


def train(training_config):
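    """Train a feed-forward TransformerNet to reproduce a single style image, using VGG16-based perceptual losses."""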
    writer = SummaryWriter()  # (tensorboard) writer will output to ./runs/ directory by default
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # prepare data loader
    train_loader = utils.get_training_data_loader(training_config)

    # prepare neural networks
    transformer_net = TransformerNet().train().to(device)
    perceptual_loss_net = PerceptualLossNet(requires_grad=False).to(device)

    # Adam (first-order) is used here to train the transformer net's weights over mini-batches;
    # a closure-based optimizer such as LBFGS is only needed when optimizing the image directly
    optimizer = torch.optim.Adam(transformer_net.parameters())

    # Calculate style image's Gram matrices (style representation)
    # Built over feature maps as produced by the perceptual net - VGG16
    style_img_path = os.path.join(training_config['style_images_path'],
                                  training_config['style_img_name'])
    style_img = utils.prepare_img(style_img_path,
                                  target_shape=None,
                                  device=device,
                                  batch_size=training_config['batch_size'])
    style_img_set_of_feature_maps = perceptual_loss_net(style_img)
    target_style_representation = [
        utils.gram_matrix(x) for x in style_img_set_of_feature_maps
    ]
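    # Each Gram matrix captures channel-wise feature correlations: for a feature map of shape
    # (b, ch, h, w), flattening to F of shape (ch, h*w) gives G = F @ F.T (usually normalized by
    # ch * h * w). Matching these second-order statistics is what transfers the style.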

    utils.print_header(training_config)
    # Tracking loss metrics; NST is ill-posed, so we can only monitor the losses and the visual appearance of the stylized images
    acc_content_loss, acc_style_loss, acc_tv_loss = [0., 0., 0.]
    ts = time.time()
    for epoch in range(training_config['num_of_epochs']):
        for batch_id, (content_batch, _) in enumerate(train_loader):
            # step1: Feed content batch through transformer net
            content_batch = content_batch.to(device)
            stylized_batch = transformer_net(content_batch)

            # step2: Feed content and stylized batch through perceptual net (VGG16)
            content_batch_set_of_feature_maps = perceptual_loss_net(
                content_batch)
            stylized_batch_set_of_feature_maps = perceptual_loss_net(
                stylized_batch)

            # step3: Calculate content representations and content loss
            target_content_representation = content_batch_set_of_feature_maps.relu2_2
            current_content_representation = stylized_batch_set_of_feature_maps.relu2_2
            content_loss = training_config['content_weight'] * torch.nn.MSELoss(
                reduction='mean')(target_content_representation,
                                  current_content_representation)

            # step4: Calculate style representation and style loss
            style_loss = 0.0
            current_style_representation = [
                utils.gram_matrix(x)
                for x in stylized_batch_set_of_feature_maps
            ]
            for gram_gt, gram_hat in zip(target_style_representation,
                                         current_style_representation):
                style_loss += torch.nn.MSELoss(reduction='mean')(gram_gt,
                                                                 gram_hat)
            style_loss /= len(target_style_representation)
            style_loss *= training_config['style_weight']

            # step5: Calculate total variation loss - enforces image smoothness
            tv_loss = training_config['tv_weight'] * utils.total_variation(
                stylized_batch)
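            # (total variation is typically the summed difference between neighbouring pixel
            # values, horizontally and vertically - see utils.total_variation for the exact variant)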

            # step6: Combine losses and do a backprop
            # (clear old gradients first - zeroing them after backward would wipe out the
            #  freshly computed gradients before the optimizer could apply them)
            optimizer.zero_grad()
            total_loss = content_loss + style_loss + tv_loss
            total_loss.backward()
            optimizer.step()

            #
            # Logging and checkpoint creation
            #
            acc_content_loss += content_loss.item()
            acc_style_loss += style_loss.item()
            acc_tv_loss += tv_loss.item()

            if training_config['enable_tensorboard']:
                # log scalars
                writer.add_scalar('Loss/content-loss', content_loss.item(),
                                  len(train_loader) * epoch + batch_id + 1)
                writer.add_scalar('Loss/style-loss', style_loss.item(),
                                  len(train_loader) * epoch + batch_id + 1)
                writer.add_scalar('Loss/tv-loss', tv_loss.item(),
                                  len(train_loader) * epoch + batch_id + 1)
                writer.add_scalars(
                    'Statistics/min-max-mean-median', {
                        'min': torch.min(stylized_batch),
                        'max': torch.max(stylized_batch),
                        'mean': torch.mean(stylized_batch),
                        'median': torch.median(stylized_batch)
                    },
                    len(train_loader) * epoch + batch_id + 1)
                # log stylized image
                if batch_id % training_config['image_log_freq'] == 0:
                    stylized = utils.post_process_image(
                        stylized_batch[0].detach().to('cpu').numpy())
                    stylized = np.moveaxis(
                        stylized, 2, 0)  # writer expects channel first image
                    writer.add_image('stylized_img', stylized,
                                     len(train_loader) * epoch + batch_id + 1)

            if training_config['console_log_freq'] is not None and \
                    batch_id % training_config['console_log_freq'] == 0:
                print(
                    f'time elapsed={(time.time() - ts) / 60:.2f}[min]'
                    f'|epoch={epoch + 1}|batch=[{batch_id + 1}/{len(train_loader)}]'
                    f'|c-loss={acc_content_loss / training_config["console_log_freq"]}'
                    f'|s-loss={acc_style_loss / training_config["console_log_freq"]}'
                    f'|tv-loss={acc_tv_loss / training_config["console_log_freq"]}'
                    f'|total loss={(acc_content_loss + acc_style_loss + acc_tv_loss) / training_config["console_log_freq"]}'
                )
                acc_content_loss, acc_style_loss, acc_tv_loss = [0., 0., 0.]

            if training_config['checkpoint_freq'] is not None and (
                    batch_id + 1) % training_config['checkpoint_freq'] == 0:
                training_state = utils.get_training_metadata(training_config)
                training_state["state_dict"] = transformer_net.state_dict()
                training_state["optimizer_state"] = optimizer.state_dict()
                ckpt_model_name = f"ckpt_style_{training_config['style_img_name'].split('.')[0]}_cw_{str(training_config['content_weight'])}_sw_{str(training_config['style_weight'])}_tw_{str(training_config['tv_weight'])}_epoch_{epoch}_batch_{batch_id}.pth"
                torch.save(
                    training_state,
                    os.path.join(training_config['checkpoints_path'],
                                 ckpt_model_name))

    #
    # Save model with additional metadata - like which commit was used to train the model, style/content weights, etc.
    #
    training_state = utils.get_training_metadata(training_config)
    training_state["state_dict"] = transformer_net.state_dict()
    training_state["optimizer_state"] = optimizer.state_dict()
    model_name = f"style_{training_config['style_img_name'].split('.')[0]}_datapoints_{training_state['num_of_datapoints']}_cw_{str(training_config['content_weight'])}_sw_{str(training_config['style_weight'])}_tw_{str(training_config['tv_weight'])}.pth"
    torch.save(
        training_state,
        os.path.join(training_config['model_binaries_path'], model_name))
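

# Illustrative sketch (an assumption, not the repo's actual utils code): gram_matrix and
# total_variation as they are commonly implemented for NST. The real utils helpers may
# differ, e.g. in how they normalize or whether they use squared differences for TV.
def gram_matrix_sketch(x, should_normalize=True):
    b, ch, h, w = x.size()
    features = x.view(b, ch, h * w)
    gram = features.bmm(features.transpose(1, 2))  # (b, ch, ch) channel correlations
    if should_normalize:
        gram /= ch * h * w
    return gram


def total_variation_sketch(img_batch):
    # summed absolute differences between horizontally and vertically adjacent pixels
    return (torch.sum(torch.abs(img_batch[:, :, :, :-1] - img_batch[:, :, :, 1:])) +
            torch.sum(torch.abs(img_batch[:, :, :-1, :] - img_batch[:, :, 1:, :])))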