def build_loss(neural_net, optimizing_img, target_representations,
               content_feature_maps_index, style_feature_maps_indices, config):
    target_content_representation = target_representations[0]
    target_style_representation = target_representations[1]

    current_set_of_feature_maps = neural_net(optimizing_img)

    current_content_representation = current_set_of_feature_maps[
        content_feature_maps_index].squeeze(axis=0)
    content_loss = torch.nn.MSELoss(reduction='mean')(
        target_content_representation, current_content_representation)

    style_loss = 0.0
    current_style_representation = [
        utils.gram_matrix(x)
        for cnt, x in enumerate(current_set_of_feature_maps)
        if cnt in style_feature_maps_indices
    ]
    for gram_gt, gram_hat in zip(target_style_representation,
                                 current_style_representation):
        style_loss += torch.nn.MSELoss(reduction='sum')(gram_gt[0],
                                                        gram_hat[0])
    style_loss /= len(target_style_representation)

    tv_loss = utils.total_variation(optimizing_img)

    total_loss = config['content_weight'] * content_loss + config[
        'style_weight'] * style_loss + config['tv_weight'] * tv_loss

    return total_loss, content_loss, style_loss, tv_loss
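
build_loss leans on two helpers from utils that this snippet does not show. A minimal sketch of both, assuming feature maps arrive as batched NCHW tensors (the exact normalization in the author's utils may differ):

import torch

def gram_matrix(x, should_normalize=True):
    # hypothetical helper: channel-by-channel Gram matrix of an NCHW feature map
    (b, ch, h, w) = x.size()
    features = x.view(b, ch, w * h)
    gram = features.bmm(features.transpose(1, 2))
    if should_normalize:
        gram /= ch * h * w
    return gram

def total_variation(y):
    # hypothetical helper: anisotropic total variation of an image batch
    return (torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) +
            torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])))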
Example #2
    def call(self, inputs):
        inputs = inputs * 255.0
        preprocessed_input = tf.keras.applications.vgg19.preprocess_input(
            inputs)
        outputs = self.vgg(preprocessed_input)
        style_outputs, content_outputs = (
            outputs[:self.num_style_layers],
            outputs[self.num_style_layers:],
        )

        style_outputs = [
            gram_matrix(style_output) for style_output in style_outputs
        ]

        content_dict = {
            content_name: value
            for content_name, value in zip(self.content_layers,
                                           content_outputs)
        }

        style_dict = {
            style_name: value
            for style_name, value in zip(self.style_layers, style_outputs)
        }

        return {"content": content_dict, "style": style_dict}
def closure():
    nonlocal cnt
    optimizer.zero_grad()
    loss = 0.0
    if should_reconstruct_content:
        loss = torch.nn.MSELoss(reduction='mean')(
            target_content_representation, neural_net(optimizing_img)[
                content_feature_maps_index_name[0]].squeeze(axis=0))
    else:
        current_set_of_feature_maps = neural_net(optimizing_img)
        current_style_representation = [
            utils.gram_matrix(fmaps)
            for i, fmaps in enumerate(current_set_of_feature_maps)
            if i in style_feature_maps_indices_names[0]
        ]
        for gram_gt, gram_hat in zip(target_style_representation,
                                     current_style_representation):
            loss += (1 / len(target_style_representation)
                     ) * torch.nn.MSELoss(reduction='sum')(gram_gt[0],
                                                           gram_hat[0])
    loss.backward()
    with torch.no_grad():
        print(
            f'Iteration: {cnt}, current {"content" if should_reconstruct_content else "style"} loss={loss.item()}'
        )
        utils.save_and_maybe_display(
            optimizing_img,
            dump_path,
            config,
            cnt,
            num_of_iterations[config['optimizer']],
            should_display=False)
        cnt += 1
    return loss
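
This closure is meant to be handed to torch's L-BFGS optimizer, which may invoke it several times per step; a minimal usage sketch, consistent with the later examples in this file:

optimizer = torch.optim.LBFGS((optimizing_img,),
                              max_iter=num_of_iterations[config['optimizer']],
                              line_search_fn='strong_wolfe')
optimizer.step(closure)  # L-BFGS re-invokes closure internally as it iterates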
Example #4
    def train(self):
        total_step = len(self.data_loader)
        optimizer = Adam(self.transfer_net.parameters(), lr=self.lr)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       self.decay_epoch, 0.5)
        content_criterion = nn.MSELoss()
        style_criterion = nn.MSELoss()
        self.transfer_net.train()
        self.vgg.eval()

        for epoch in range(self.epoch, self.num_epoch):
            if not os.path.exists(
                    os.path.join(self.sample_dir, self.style_image_name,
                                 f"{epoch}")):
                os.makedirs(
                    os.path.join(self.sample_dir, self.style_image_name,
                                 f"{epoch}"))
            for step, image in enumerate(self.data_loader):

                optimizer.zero_grad()
                image = image.to(self.device)
                transformed_image = self.transfer_net(image)

                image_feature = self.vgg(image)
                transformed_image_feature = self.vgg(transformed_image)

                content_loss = self.content_weight * content_criterion(
                    image_feature.relu2_2, transformed_image_feature.relu2_2)

                style_loss = 0
                for ft_y, gm_s in zip(transformed_image_feature,
                                      self.gram_style):
                    gm_y = gram_matrix(ft_y)
                    style_loss += style_criterion(gm_y,
                                                  gm_s[:self.batch_size, :, :])
                style_loss *= self.style_weight

                total_loss = content_loss + style_loss

                total_loss.backward(retain_graph=True)
                optimizer.step()

                if step % 10 == 0:
                    print(
                        f"[Epoch {epoch}/{self.num_epoch}] [Batch {step}/{total_step}] "
                        f"[Style loss: {style_loss.item():.4}] [Content loss loss: {content_loss.item():.4}]"
                    )
                    if step % 100 == 0:
                        image = torch.cat((image, transformed_image), dim=2)
                        save_image(image,
                                   os.path.join(self.sample_dir,
                                                self.style_image_name,
                                                f"{epoch}", f"{step}.png"),
                                   normalize=False)

            torch.save(
                self.transfer_net.state_dict(),
                os.path.join(self.checkpoint_dir, self.style_image_name,
                             f"TransferNet_{epoch}.pth"))
            lr_scheduler.step()
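
self.vgg above is used both by attribute access (image_feature.relu2_2) and as an iterable of feature maps. A minimal sketch of such a wrapper, assuming the common VGG16 namedtuple pattern (the slice boundaries below are illustrative, not the author's exact layers):

from collections import namedtuple

import torch
from torchvision import models

class Vgg16(torch.nn.Module):
    # hypothetical feature extractor: returns a namedtuple so callers can
    # index activations by name (out.relu2_2) or iterate over all of them
    def __init__(self):
        super().__init__()
        features = models.vgg16(pretrained=True).features
        self.slice1 = features[:4]     # conv1_1 .. relu1_2
        self.slice2 = features[4:9]    # .. relu2_2
        self.slice3 = features[9:16]   # .. relu3_3
        self.slice4 = features[16:23]  # .. relu4_3
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        h1 = self.slice1(x)
        h2 = self.slice2(h1)
        h3 = self.slice3(h2)
        h4 = self.slice4(h3)
        outputs = namedtuple('VggOutputs',
                             ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
        return outputs(h1, h2, h3, h4)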
    def tuning_step(optimizing_img):
        # Finds the current representation
        set_of_feature_maps = model(optimizing_img)
        if should_reconstruct_content:
            current_representation = set_of_feature_maps[
                content_feature_maps_index].squeeze(axis=0)
        else:
            current_representation = [
                utils.gram_matrix(fmaps)
                for i, fmaps in enumerate(set_of_feature_maps)
                if i in style_feature_maps_indices
            ]

        # Computes the loss between current and target representations
        loss = 0.0
        if should_reconstruct_content:
            loss = torch.nn.MSELoss(reduction='mean')(target_representation,
                                                      current_representation)
        else:
            for gram_gt, gram_hat in zip(target_representation,
                                         current_representation):
                loss += (1 / len(target_representation)) * torch.nn.MSELoss(
                    reduction='sum')(gram_gt[0], gram_hat[0])

        # Computes gradients
        loss.backward()
        # Updates parameters and zeroes gradients
        optimizer.step()
        optimizer.zero_grad()
        # Returns the loss
        return loss.item(), current_representation
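
tuning_step closes over model, optimizer, the target representation, and the chosen layer indices, so it is naturally produced by a factory. A sketch of such a make_tuning_step, with parameter names assumed from the call sites later in this file (the body just reproduces the closure above for self-containedness):

def make_tuning_step(model, optimizer, target_representation,
                     should_reconstruct_content, content_feature_maps_index,
                     style_feature_maps_indices):
    # hypothetical factory: binds the arguments and returns the closure
    def tuning_step(optimizing_img):
        set_of_feature_maps = model(optimizing_img)
        if should_reconstruct_content:
            current_representation = set_of_feature_maps[
                content_feature_maps_index].squeeze(axis=0)
            loss = torch.nn.MSELoss(reduction='mean')(
                target_representation, current_representation)
        else:
            current_representation = [
                utils.gram_matrix(fmaps)
                for i, fmaps in enumerate(set_of_feature_maps)
                if i in style_feature_maps_indices
            ]
            loss = sum(
                torch.nn.MSELoss(reduction='sum')(gram_gt[0], gram_hat[0])
                for gram_gt, gram_hat in zip(target_representation,
                                             current_representation)
            ) / len(target_representation)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        return loss.item(), current_representation
    return tuning_step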
Example #6
    def forward(self, input, mask, output, gt):
        loss_dict = {}
        output_comp = mask * input + (1 - mask) * output

        if output.shape[1] == 3:
            feat_output_comp = self.extractor(output_comp)
            feat_output = self.extractor(output)
            feat_gt = self.extractor(gt)
        elif output.shape[1] == 1:
            feat_output_comp = self.extractor(torch.cat([output_comp] * 3, 1))
            feat_output = self.extractor(torch.cat([output] * 3, 1))
            feat_gt = self.extractor(torch.cat([gt] * 3, 1))
        else:
            raise ValueError('only gray (1-channel) and RGB (3-channel) images are supported')

        loss_dict['prc'] = 0.0
        for i in range(3):
            loss_dict['prc'] += self.l1(feat_output[i], feat_gt[i])
            loss_dict['prc'] += self.l1(feat_output_comp[i], feat_gt[i])

        if self.kbe_only:
            loss_dict['color'] = self.l1(output, gt)
        else:
            loss_dict['hole'] = self.l1((1 - mask) * output, (1 - mask) * gt)
            loss_dict['valid'] = self.l1(mask * output, mask * gt)

            loss_dict['style'] = 0.0
            for i in range(3):
                loss_dict['style'] += self.l1(gram_matrix(feat_output[i]),
                                              gram_matrix(feat_gt[i]))
                loss_dict['style'] += self.l1(gram_matrix(feat_output_comp[i]),
                                              gram_matrix(feat_gt[i]))

            loss_dict['tv'] = total_variation_loss(output_comp)

        return loss_dict
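
total_variation_loss is not defined in this snippet; a minimal anisotropic sketch, assuming NCHW image batches (the author's normalization may differ):

import torch

def total_variation_loss(image):
    # hypothetical helper: mean absolute difference between neighboring pixels
    return (torch.mean(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) +
            torch.mean(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :])))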
Example #7
    def load_feature_style(self):
        if not os.path.exists(self.style_dir):
            os.makedirs(self.style_dir)
        if not os.listdir(self.style_dir):
            raise Exception(f"[!] No image for style transfer")

        image_name = glob(
            os.path.join(self.style_dir, f"{self.style_image_name}.*"))
        if not image_name:
            raise Exception(
                f"[!] No image for {self.style_image_name} transfer")

        image = load_image(image_name[0], size=self.image_size)
        image = transforms.Compose([
            transforms.CenterCrop(min(image.size[0], image.size[1])),
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
        ])(image)
        image = image.repeat(self.batch_size, 1, 1, 1)
        image = image.to(self.device)
        style_image = self.vgg(image)
        self.gram_style = [gram_matrix(y) for y in style_image]
def neural_style_transfer(config):
    content_img_path = os.path.join(config['content_images_dir'],
                                    config['content_img_name'])
    style_img_path = os.path.join(config['style_images_dir'],
                                  config['style_img_name'])

    out_dir_name = 'combined_' + os.path.split(content_img_path)[1].split(
        '.')[0] + '_' + os.path.split(style_img_path)[1].split('.')[0]
    dump_path = os.path.join(config['output_img_dir'], out_dir_name)
    os.makedirs(dump_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    content_img = utils.prepare_img(content_img_path, config['height'], device)
    style_img = utils.prepare_img(style_img_path, config['height'], device)

    if config['init_method'] == 'random':
        # white_noise_img = np.random.uniform(-90., 90., content_img.shape).astype(np.float32)
        gaussian_noise_img = np.random.normal(loc=0,
                                              scale=90.,
                                              size=content_img.shape).astype(
                                                  np.float32)
        init_img = torch.from_numpy(gaussian_noise_img).float().to(device)
    elif config['init_method'] == 'content':
        init_img = content_img
    else:
        # the init image must have the same dimensions as the content image - this is a hard constraint:
        # feature maps need to be of the same size for the content image and the init image
        style_img_resized = utils.prepare_img(
            style_img_path, np.asarray(content_img.shape[2:]), device)
        init_img = style_img_resized

    # we are tuning optimizing_img's pixels! (that's why requires_grad=True)
    optimizing_img = Variable(init_img, requires_grad=True)

    neural_net, content_feature_maps_index_name, style_feature_maps_indices_names = utils.prepare_model(
        config['model'], device)
    print(f'Using {config["model"]} in the optimization procedure.')

    content_img_set_of_feature_maps = neural_net(content_img)
    style_img_set_of_feature_maps = neural_net(style_img)

    target_content_representation = content_img_set_of_feature_maps[
        content_feature_maps_index_name[0]].squeeze(axis=0)
    target_style_representation = [
        utils.gram_matrix(x)
        for cnt, x in enumerate(style_img_set_of_feature_maps)
        if cnt in style_feature_maps_indices_names[0]
    ]
    target_representations = [
        target_content_representation, target_style_representation
    ]

    # magic numbers are generally a big no-no - some things here are left hardcoded by design, to avoid clutter
    num_of_iterations = {
        "lbfgs": 1000,
        "adam": 3000,
    }

    #
    # Start of optimization procedure
    #
    if config['optimizer'] == 'adam':
        optimizer = Adam((optimizing_img, ), lr=1e1)
        tuning_step = make_tuning_step(neural_net, optimizer,
                                       target_representations,
                                       content_feature_maps_index_name[0],
                                       style_feature_maps_indices_names[0],
                                       config)
        for cnt in range(num_of_iterations[config['optimizer']]):
            total_loss, content_loss, style_loss, tv_loss = tuning_step(
                optimizing_img)
            with torch.no_grad():
                print(
                    f'Adam | iteration: {cnt:03}, total loss={total_loss.item():12.4f}, content_loss={config["content_weight"] * content_loss.item():12.4f}, style loss={config["style_weight"] * style_loss.item():12.4f}, tv loss={config["tv_weight"] * tv_loss.item():12.4f}'
                )
                utils.save_and_maybe_display(
                    optimizing_img,
                    dump_path,
                    config,
                    cnt,
                    num_of_iterations[config['optimizer']],
                    should_display=False)
    elif config['optimizer'] == 'lbfgs':
        # line_search_fn does not seem to have significant impact on result
        optimizer = LBFGS((optimizing_img, ),
                          max_iter=num_of_iterations['lbfgs'],
                          line_search_fn='strong_wolfe')
        cnt = 0

        def closure():
            nonlocal cnt
            if torch.is_grad_enabled():
                optimizer.zero_grad()
            total_loss, content_loss, style_loss, tv_loss = build_loss(
                neural_net, optimizing_img, target_representations,
                content_feature_maps_index_name[0],
                style_feature_maps_indices_names[0], config)
            if total_loss.requires_grad:
                total_loss.backward()
            with torch.no_grad():
                print(
                    f'L-BFGS | iteration: {cnt:03}, total loss={total_loss.item():12.4f}, content_loss={config["content_weight"] * content_loss.item():12.4f}, style loss={config["style_weight"] * style_loss.item():12.4f}, tv loss={config["tv_weight"] * tv_loss.item():12.4f}'
                )
                utils.save_and_maybe_display(
                    optimizing_img,
                    dump_path,
                    config,
                    cnt,
                    num_of_iterations[config['optimizer']],
                    should_display=False)

            cnt += 1
            return total_loss

        optimizer.step(closure)

    return dump_path
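
For reference, an illustrative invocation - the keys are exactly those neural_style_transfer reads (utils.save_and_maybe_display may read further keys not shown), while the file names and weight values below are placeholders, not the author's defaults:

config = {
    'content_images_dir': 'data/content-images',
    'content_img_name': 'lion.jpg',
    'style_images_dir': 'data/style-images',
    'style_img_name': 'starry_night.jpg',
    'output_img_dir': 'data/output-images',
    'height': 400,
    'init_method': 'content',  # 'random' | 'content' | 'style'
    'model': 'vgg19',
    'optimizer': 'lbfgs',      # 'lbfgs' | 'adam'
    'content_weight': 1e5,
    'style_weight': 3e4,
    'tv_weight': 1e0,
}
results_dir = neural_style_transfer(config)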
Example #9
def train(**kwargs):
    opt._parse(kwargs)

    device = t.device('cuda') if opt.use_gpu else t.device('cpu')
    vis = Visualizer(opt.env)

    # Data loading
    transform = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x * 255)
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, transform)
    dataloader = data.DataLoader(dataset, opt.batch_size)

    # style transformer network
    transformer = TransformerNet()
    if opt.model_path:
        transformer.load_state_dict(
            t.load(opt.model_path, map_location=lambda _s, _: _s))
    transformer.to(device)

    # Vgg16 for Perceptual Loss
    vgg = Vgg16().eval()
    vgg.to(device)
    for param in vgg.parameters():
        param.requires_grad = False

    # Optimizer: use Adam
    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    # Get style image
    style = utils.get_style_data(opt.style_path)
    vis.img('style', (style.data[0] * 0.225 + 0.45).clamp(min=0, max=1))
    style = style.to(device)

    # print("style.shape: ", style.shape)

    # gram matrix for style image
    with t.no_grad():
        features_style = vgg(style)
        gram_style = [utils.gram_matrix(y) for y in features_style]

    # Loss meter
    style_meter = tnt.meter.AverageValueMeter()
    content_meter = tnt.meter.AverageValueMeter()

    for epoch in range(opt.epoches):
        content_meter.reset()
        style_meter.reset()

        for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):

            # Train
            optimizer.zero_grad()
            x = x.to(device)
            y = transformer(x)
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)
            features_y = vgg(y)
            features_x = vgg(x)

            # content loss
            content_loss = opt.content_weight * F.mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            # style loss
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gram_y = utils.gram_matrix(ft_y)
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            # Loss smooth for visualization
            content_meter.add(content_loss.item())
            style_meter.add(style_loss.item())

            if (ii + 1) % opt.plot_every == 0:
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                # visualization
                vis.plot('content_loss', content_meter.value()[0])
                vis.plot('style_loss', style_meter.value()[0])
                # denorm input/output, since we have applied (utils.normalize_batch)
                vis.img('output',
                        (y.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))
                vis.img('input', (x.data.cpu()[0] * 0.225 + 0.45).clamp(min=0,
                                                                        max=1))

        # save checkpoint
        vis.save([opt.env])
        t.save(transformer.state_dict(), 'checkpoints/%s_style.pth' % epoch)
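
utils.normalize_batch above is undefined in this snippet. Since the transform pipeline scales inputs to [0, 255], one plausible sketch is plain ImageNet normalization (the author's version may differ):

def normalize_batch(batch):
    # hypothetical helper: maps a [0, 255] NCHW batch to ImageNet statistics
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    return (batch / 255.0 - mean) / std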
def reconstruct_image_from_representation(config):
    should_reconstruct_content = config['should_reconstruct_content']
    should_visualize_representation = config['should_visualize_representation']
    dump_path = os.path.join(config['output_img_dir'],
                             ('c' if should_reconstruct_content else 's') +
                             '_reconstruction_' + config['optimizer'])
    dump_path = os.path.join(
        dump_path, config['content_img_name'].split('.')[0] if
        should_reconstruct_content else config['style_img_name'].split('.')[0])
    os.makedirs(dump_path, exist_ok=True)

    content_img_path = os.path.join(config['content_images_dir'],
                                    config['content_img_name'])
    style_img_path = os.path.join(config['style_images_dir'],
                                  config['style_img_name'])
    img_path = content_img_path if should_reconstruct_content else style_img_path

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    img = utils.prepare_img(img_path, config['height'], device)

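    # two noise inits are prepared below; white noise is the one actually used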
    gaussian_noise_img = np.random.normal(loc=0, scale=90.,
                                          size=img.shape).astype(np.float32)
    white_noise_img = np.random.uniform(-90., 90.,
                                        img.shape).astype(np.float32)
    init_img = torch.from_numpy(white_noise_img).float().to(device)
    optimizing_img = Variable(init_img, requires_grad=True)

    # indices pick relevant feature maps (say conv4_1, relu1_1, etc.)
    neural_net, content_feature_maps_index_name, style_feature_maps_indices_names = utils.prepare_model(
        config['model'], device)

    # don't want to expose everything that's not crucial so some things are hardcoded
    num_of_iterations = {'adam': 3000, 'lbfgs': 350}

    set_of_feature_maps = neural_net(img)

    #
    # Visualize feature maps and Gram matrices (depending whether you're reconstructing content or style img)
    #
    if should_reconstruct_content:
        target_content_representation = set_of_feature_maps[
            content_feature_maps_index_name[0]].squeeze(axis=0)
        if should_visualize_representation:
            num_of_feature_maps = target_content_representation.size()[0]
            print(f'Number of feature maps: {num_of_feature_maps}')
            for i in range(num_of_feature_maps):
                feature_map = target_content_representation[i].to(
                    'cpu').numpy()
                feature_map = np.uint8(utils.get_uint8_range(feature_map))
                plt.imshow(feature_map)
                plt.title(
                    f'Feature map {i+1}/{num_of_feature_maps} from layer {content_feature_maps_index_name[1]} (model={config["model"]}) for {config["content_img_name"]} image.'
                )
                plt.show()
                filename = f'fm_{config["model"]}_{content_feature_maps_index_name[1]}_{str(i).zfill(config["img_format"][0])}{config["img_format"][1]}'
                utils.save_image(feature_map,
                                 os.path.join(dump_path, filename))
    else:
        target_style_representation = [
            utils.gram_matrix(fmaps)
            for i, fmaps in enumerate(set_of_feature_maps)
            if i in style_feature_maps_indices_names[0]
        ]
        if should_visualize_representation:
            num_of_gram_matrices = len(target_style_representation)
            print(f'Number of Gram matrices: {num_of_gram_matrices}')
            for i in range(num_of_gram_matrices):
                Gram_matrix = target_style_representation[i].squeeze(
                    axis=0).to('cpu').numpy()
                Gram_matrix = np.uint8(utils.get_uint8_range(Gram_matrix))
                plt.imshow(Gram_matrix)
                plt.title(
                    f'Gram matrix from layer {style_feature_maps_indices_names[1][i]} (model={config["model"]}) for {config["style_img_name"]} image.'
                )
                plt.show()
                filename = f'gram_{config["model"]}_{style_feature_maps_indices_names[1][i]}_{str(i).zfill(config["img_format"][0])}{config["img_format"][1]}'
                utils.save_image(Gram_matrix,
                                 os.path.join(dump_path, filename))

    #
    # Start of optimization procedure
    #
    if config['optimizer'] == 'adam':
        optimizer = Adam((optimizing_img, ))
        target_representation = target_content_representation if should_reconstruct_content else target_style_representation
        tuning_step = make_tuning_step(neural_net, optimizer,
                                       target_representation,
                                       should_reconstruct_content,
                                       content_feature_maps_index_name[0],
                                       style_feature_maps_indices_names[0])
        for it in range(num_of_iterations[config['optimizer']]):
            loss, _ = tuning_step(optimizing_img)
            with torch.no_grad():
                print(
                    f'Iteration: {it}, current {"content" if should_reconstruct_content else "style"} loss={loss:10.8f}'
                )
                utils.save_and_maybe_display(
                    optimizing_img,
                    dump_path,
                    config,
                    it,
                    num_of_iterations[config['optimizer']],
                    should_display=False)
    elif config['optimizer'] == 'lbfgs':
        cnt = 0

        # closure is a function required by L-BFGS optimizer
        def closure():
            nonlocal cnt
            optimizer.zero_grad()
            loss = 0.0
            if should_reconstruct_content:
                loss = torch.nn.MSELoss(reduction='mean')(
                    target_content_representation, neural_net(optimizing_img)[
                        content_feature_maps_index_name[0]].squeeze(axis=0))
            else:
                current_set_of_feature_maps = neural_net(optimizing_img)
                current_style_representation = [
                    utils.gram_matrix(fmaps)
                    for i, fmaps in enumerate(current_set_of_feature_maps)
                    if i in style_feature_maps_indices_names[0]
                ]
                for gram_gt, gram_hat in zip(target_style_representation,
                                             current_style_representation):
                    loss += (1 / len(target_style_representation)
                             ) * torch.nn.MSELoss(reduction='sum')(gram_gt[0],
                                                                   gram_hat[0])
            loss.backward()
            with torch.no_grad():
                print(
                    f'Iteration: {cnt}, current {"content" if should_reconstruct_content else "style"} loss={loss.item()}'
                )
                utils.save_and_maybe_display(
                    optimizing_img,
                    dump_path,
                    config,
                    cnt,
                    num_of_iterations[config['optimizer']],
                    should_display=False)
                cnt += 1
            return loss

        optimizer = torch.optim.LBFGS(
            (optimizing_img, ),
            max_iter=num_of_iterations[config['optimizer']],
            line_search_fn='strong_wolfe')
        optimizer.step(closure)

    return dump_path
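
utils.get_uint8_range used above is a small visualization helper; a minimal sketch, assuming it min-max scales a numpy array into [0, 255]:

import numpy as np

def get_uint8_range(x):
    # hypothetical helper: min-max scales an array into the [0, 255] range
    if isinstance(x, np.ndarray):
        x -= np.min(x)
        x /= np.max(x)
        x *= 255
        return x
    raise ValueError(f'Expected numpy array, got {type(x)}')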
Example #11
def train(training_config):
    writer = SummaryWriter()  # (tensorboard) writer will output to ./runs/ directory by default
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # prepare data loader
    train_loader = utils.get_training_data_loader(training_config)

    # prepare neural networks
    transformer_net = TransformerNet().train().to(device)
    perceptual_loss_net = PerceptualLossNet(requires_grad=False).to(device)

    optimizer = LBFGS(transformer_net.parameters(),
                      line_search_fn='strong_wolfe')

    # Calculate style image's Gram matrices (style representation)
    # Built over feature maps as produced by the perceptual net - VGG16
    style_img_path = os.path.join(training_config['style_images_path'],
                                  training_config['style_img_name'])
    style_img = utils.prepare_img(style_img_path,
                                  target_shape=None,
                                  device=device,
                                  batch_size=training_config['batch_size'])
    style_img_set_of_feature_maps = perceptual_loss_net(style_img)
    target_style_representation = [
        utils.gram_matrix(x) for x in style_img_set_of_feature_maps
    ]

    utils.print_header(training_config)
    # Tracking loss metrics; NST is ill-posed, so we can only track the loss and the visual appearance of the stylized images
    acc_content_loss, acc_style_loss, acc_tv_loss = [0., 0., 0.]
    ts = time.time()
    for epoch in range(training_config['num_of_epochs']):
        for batch_id, (content_batch, _) in enumerate(train_loader):
            # step1: Feed content batch through transformer net
            content_batch = content_batch.to(device)
            stylized_batch = transformer_net(content_batch)

            # step2: Feed content and stylized batch through perceptual net (VGG16)
            content_batch_set_of_feature_maps = perceptual_loss_net(
                content_batch)
            stylized_batch_set_of_feature_maps = perceptual_loss_net(
                stylized_batch)

            # step3: Calculate content representations and content loss
            target_content_representation = content_batch_set_of_feature_maps.relu2_2
            current_content_representation = stylized_batch_set_of_feature_maps.relu2_2
            content_loss = training_config['content_weight'] * torch.nn.MSELoss(
                reduction='mean')(target_content_representation,
                                  current_content_representation)

            # step4: Calculate style representation and style loss
            style_loss = 0.0
            current_style_representation = [
                utils.gram_matrix(x)
                for x in stylized_batch_set_of_feature_maps
            ]
            for gram_gt, gram_hat in zip(target_style_representation,
                                         current_style_representation):
                style_loss += torch.nn.MSELoss(reduction='mean')(gram_gt,
                                                                 gram_hat)
            style_loss /= len(target_style_representation)
            style_loss *= training_config['style_weight']

            # step5: Calculate total variation loss - enforces image smoothness
            tv_loss = training_config['tv_weight'] * utils.total_variation(
                stylized_batch)

            # step6: Combine losses and do a backprop
            total_loss = content_loss + style_loss + tv_loss
            total_loss.backward()

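            # note: backward() already ran above - this closure only hands the
            # precomputed loss to L-BFGS instead of re-evaluating the model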
            def closure():
                nonlocal total_loss
                optimizer.zero_grad()
                return total_loss

            optimizer.step(closure)

            #
            # Logging and checkpoint creation
            #
            acc_content_loss += content_loss.item()
            acc_style_loss += style_loss.item()
            acc_tv_loss += tv_loss.item()

            if training_config['enable_tensorboard']:
                # log scalars
                writer.add_scalar('Loss/content-loss', content_loss.item(),
                                  len(train_loader) * epoch + batch_id + 1)
                writer.add_scalar('Loss/style-loss', style_loss.item(),
                                  len(train_loader) * epoch + batch_id + 1)
                writer.add_scalar('Loss/tv-loss', tv_loss.item(),
                                  len(train_loader) * epoch + batch_id + 1)
                writer.add_scalars(
                    'Statistics/min-max-mean-median', {
                        'min': torch.min(stylized_batch),
                        'max': torch.max(stylized_batch),
                        'mean': torch.mean(stylized_batch),
                        'median': torch.median(stylized_batch)
                    },
                    len(train_loader) * epoch + batch_id + 1)
                # log stylized image
                if batch_id % training_config['image_log_freq'] == 0:
                    stylized = utils.post_process_image(
                        stylized_batch[0].detach().to('cpu').numpy())
                    stylized = np.moveaxis(
                        stylized, 2, 0)  # writer expects channel first image
                    writer.add_image('stylized_img', stylized,
                                     len(train_loader) * epoch + batch_id + 1)

            if training_config[
                    'console_log_freq'] is not None and batch_id % training_config[
                        'console_log_freq'] == 0:
                print(
                    f'time elapsed={(time.time() - ts) / 60:.2f}[min]|epoch={epoch + 1}|batch=[{batch_id + 1}/{len(train_loader)}]|c-loss={acc_content_loss / training_config["console_log_freq"]}|s-loss={acc_style_loss / training_config["console_log_freq"]}|tv-loss={acc_tv_loss / training_config["console_log_freq"]}|total loss={(acc_content_loss + acc_style_loss + acc_tv_loss) / training_config["console_log_freq"]}'
                )
                acc_content_loss, acc_style_loss, acc_tv_loss = [0., 0., 0.]

            if training_config['checkpoint_freq'] is not None and (
                    batch_id + 1) % training_config['checkpoint_freq'] == 0:
                training_state = utils.get_training_metadata(training_config)
                training_state["state_dict"] = transformer_net.state_dict()
                training_state["optimizer_state"] = optimizer.state_dict()
                ckpt_model_name = f"ckpt_style_{training_config['style_img_name'].split('.')[0]}_cw_{str(training_config['content_weight'])}_sw_{str(training_config['style_weight'])}_tw_{str(training_config['tv_weight'])}_epoch_{epoch}_batch_{batch_id}.pth"
                torch.save(
                    training_state,
                    os.path.join(training_config['checkpoints_path'],
                                 ckpt_model_name))

    #
    # Save model with additional metadata - like which commit was used to train the model, style/content weights, etc.
    #
    training_state = utils.get_training_metadata(training_config)
    training_state["state_dict"] = transformer_net.state_dict()
    training_state["optimizer_state"] = optimizer.state_dict()
    model_name = f"style_{training_config['style_img_name'].split('.')[0]}_datapoints_{training_state['num_of_datapoints']}_cw_{str(training_config['content_weight'])}_sw_{str(training_config['style_weight'])}_tw_{str(training_config['tv_weight'])}.pth"
    torch.save(
        training_state,
        os.path.join(training_config['model_binaries_path'], model_name))