def build_loss(neural_net, optimizing_img, target_representations, content_feature_maps_index, style_feature_maps_indices, config):
    target_content_representation = target_representations[0]
    target_style_representation = target_representations[1]

    current_set_of_feature_maps = neural_net(optimizing_img)

    current_content_representation = current_set_of_feature_maps[content_feature_maps_index].squeeze(axis=0)
    content_loss = torch.nn.MSELoss(reduction='mean')(target_content_representation, current_content_representation)

    style_loss = 0.0
    current_style_representation = [utils.gram_matrix(x) for cnt, x in enumerate(current_set_of_feature_maps) if cnt in style_feature_maps_indices]
    for gram_gt, gram_hat in zip(target_style_representation, current_style_representation):
        style_loss += torch.nn.MSELoss(reduction='sum')(gram_gt[0], gram_hat[0])
    style_loss /= len(target_style_representation)

    tv_loss = utils.total_variation(optimizing_img)

    total_loss = config['content_weight'] * content_loss + config['style_weight'] * style_loss + config['tv_weight'] * tv_loss

    return total_loss, content_loss, style_loss, tv_loss
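# build_loss above leans on two helpers from utils whose definitions are not
# part of this collection. A minimal sketch of what they typically compute in
# this kind of codebase (an assumption, not the verbatim utils module). Note
# the batch dimension: gram_matrix returns one (ch x ch) matrix per batch
# element, which is why build_loss indexes gram_gt[0].
import torch


def gram_matrix(x, should_normalize=True):
    # x holds feature maps of shape (b, ch, h, w)
    b, ch, h, w = x.size()
    features = x.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t)  # (b, ch, ch)
    if should_normalize:
        gram /= ch * h * w
    return gram


def total_variation(img):
    # Sum of absolute differences between horizontally and vertically
    # neighboring pixels - a standard smoothness prior
    return torch.sum(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:])) + \
           torch.sum(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]))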
def call(self, inputs):
    # inputs are expected in [0, 1]; VGG19's preprocess_input wants [0, 255]
    inputs = inputs * 255.0
    preprocessed_input = tf.keras.applications.vgg19.preprocess_input(inputs)
    outputs = self.vgg(preprocessed_input)
    style_outputs, content_outputs = (outputs[:self.num_style_layers],
                                      outputs[self.num_style_layers:])

    style_outputs = [gram_matrix(style_output) for style_output in style_outputs]

    content_dict = {content_name: value
                    for content_name, value in zip(self.content_layers, content_outputs)}

    style_dict = {style_name: value
                  for style_name, value in zip(self.style_layers, style_outputs)}

    return {"content": content_dict, "style": style_dict}
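# The Keras extractor above calls a free-standing gram_matrix. A sketch of the
# standard TF2 formulation it presumably pairs with: contract the two spatial
# dimensions with einsum and normalize by the number of spatial locations.
import tensorflow as tf


def gram_matrix(input_tensor):
    # (batch, h, w, ch) x (batch, h, w, ch) -> (batch, ch, ch)
    result = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
    input_shape = tf.shape(input_tensor)
    num_locations = tf.cast(input_shape[1] * input_shape[2], tf.float32)
    return result / num_locations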
def train(self):
    total_step = len(self.data_loader)
    optimizer = Adam(self.transfer_net.parameters(), lr=self.lr)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, self.decay_epoch, 0.5)
    content_criterion = nn.MSELoss()
    style_criterion = nn.MSELoss()

    self.transfer_net.train()
    self.vgg.eval()

    for epoch in range(self.epoch, self.num_epoch):
        if not os.path.exists(os.path.join(self.sample_dir, self.style_image_name, f"{epoch}")):
            os.makedirs(os.path.join(self.sample_dir, self.style_image_name, f"{epoch}"))

        for step, image in enumerate(self.data_loader):
            optimizer.zero_grad()
            image = image.to(self.device)
            transformed_image = self.transfer_net(image)
            image_feature = self.vgg(image)
            transformed_image_feature = self.vgg(transformed_image)

            # Content loss on the relu2_2 feature maps
            content_loss = self.content_weight * content_criterion(
                image_feature.relu2_2, transformed_image_feature.relu2_2)

            # Style loss: MSE between Gram matrices of the stylized batch and the style image
            style_loss = 0
            for ft_y, gm_s in zip(transformed_image_feature, self.gram_style):
                gm_y = gram_matrix(ft_y)
                style_loss += style_criterion(gm_y, gm_s[:self.batch_size, :, :])
            style_loss *= self.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward(retain_graph=True)
            optimizer.step()

            if step % 10 == 0:
                print(
                    f"[Epoch {epoch}/{self.num_epoch}] [Batch {step}/{total_step}] "
                    f"[Style loss: {style_loss.item():.4}] [Content loss: {content_loss.item():.4}]")

            if step % 100 == 0:
                image = torch.cat((image, transformed_image), dim=2)
                save_image(image,
                           os.path.join(self.sample_dir, self.style_image_name, f"{epoch}", f"{step}.png"),
                           normalize=False)

        torch.save(self.transfer_net.state_dict(),
                   os.path.join(self.checkpoint_dir, self.style_image_name, f"TransferNet_{epoch}.pth"))
        lr_scheduler.step()
def tuning_step(optimizing_img):
    # Finds the current representation
    set_of_feature_maps = model(optimizing_img)
    if should_reconstruct_content:
        current_representation = set_of_feature_maps[content_feature_maps_index].squeeze(axis=0)
    else:
        current_representation = [utils.gram_matrix(fmaps) for i, fmaps in enumerate(set_of_feature_maps) if i in style_feature_maps_indices]

    # Computes the loss between current and target representations
    loss = 0.0
    if should_reconstruct_content:
        loss = torch.nn.MSELoss(reduction='mean')(target_representation, current_representation)
    else:
        for gram_gt, gram_hat in zip(target_representation, current_representation):
            loss += (1 / len(target_representation)) * torch.nn.MSELoss(reduction='sum')(gram_gt[0], gram_hat[0])

    # Computes gradients
    loss.backward()
    # Updates parameters and zeroes gradients
    optimizer.step()
    optimizer.zero_grad()
    # Returns the loss
    return loss.item(), current_representation
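# tuning_step references model, optimizer and target_representation that it
# does not receive as arguments, so it is presumably the inner function of a
# make_tuning_step factory. A sketch of that wiring, with the signature
# inferred from the make_tuning_step call in reconstruct_image_from_representation
# below (the body is elided - it is exactly the function above):
def make_tuning_step(model, optimizer, target_representation, should_reconstruct_content,
                     content_feature_maps_index, style_feature_maps_indices):
    def tuning_step(optimizing_img):
        ...  # body as defined above, closing over the factory's arguments
    return tuning_step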
def forward(self, input, mask, output, gt):
    loss_dict = {}
    # Composite image: known (masked) pixels from the input, the rest from the network output
    output_comp = mask * input + (1 - mask) * output

    if output.shape[1] == 3:
        feat_output_comp = self.extractor(output_comp)
        feat_output = self.extractor(output)
        feat_gt = self.extractor(gt)
    elif output.shape[1] == 1:
        # Grayscale: replicate the single channel so the extractor sees 3 channels
        feat_output_comp = self.extractor(torch.cat([output_comp] * 3, 1))
        feat_output = self.extractor(torch.cat([output] * 3, 1))
        feat_gt = self.extractor(torch.cat([gt] * 3, 1))
    else:
        raise ValueError('only gray and RGB images are supported')

    # Perceptual loss over the three extracted feature levels
    loss_dict['prc'] = 0.0
    for i in range(3):
        loss_dict['prc'] += self.l1(feat_output[i], feat_gt[i])
        loss_dict['prc'] += self.l1(feat_output_comp[i], feat_gt[i])

    if self.kbe_only:
        loss_dict['color'] = self.l1(output, gt)
    else:
        loss_dict['hole'] = self.l1((1 - mask) * output, (1 - mask) * gt)
        loss_dict['valid'] = self.l1(mask * output, mask * gt)

    # Style loss: L1 between Gram matrices of predicted and ground-truth features
    loss_dict['style'] = 0.0
    for i in range(3):
        loss_dict['style'] += self.l1(gram_matrix(feat_output[i]), gram_matrix(feat_gt[i]))
        loss_dict['style'] += self.l1(gram_matrix(feat_output_comp[i]), gram_matrix(feat_gt[i]))

    loss_dict['tv'] = total_variation_loss(output_comp)

    return loss_dict
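# The inpainting loss above ends with total_variation_loss(output_comp). A
# sketch of the common mean-based formulation used alongside this kind of loss
# (an assumption; the helper itself is not part of this collection):
import torch


def total_variation_loss(image):
    # Mean absolute difference between neighboring pixels, horizontally and vertically
    return torch.mean(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + \
           torch.mean(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))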
def load_feature_style(self):
    if not os.path.exists(self.style_dir):
        os.makedirs(self.style_dir)

    if not os.listdir(self.style_dir):
        raise Exception("[!] No image for style transfer")

    image_name = glob(os.path.join(self.style_dir, f"{self.style_image_name}.*"))
    if not image_name:
        raise Exception(f"[!] No image for {self.style_image_name} transfer")

    image = load_image(image_name[0], size=self.image_size)
    image = transforms.Compose([
        transforms.CenterCrop(min(image.size[0], image.size[1])),
        transforms.Resize(self.image_size),
        transforms.ToTensor(),
    ])(image)
    # Repeat the style image along the batch dimension so it matches the content batch
    image = image.repeat(self.batch_size, 1, 1, 1)
    image = image.to(self.device)

    style_image = self.vgg(image)
    self.gram_style = [gram_matrix(y) for y in style_image]
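# load_feature_style calls load_image and then reads image.size, so the helper
# must return a PIL Image. A hedged sketch of such a helper; the
# aspect-preserving resize is an assumption, not the verbatim implementation:
from PIL import Image


def load_image(path, size=None):
    image = Image.open(path).convert('RGB')
    if size is not None:
        # Scale the shorter side to `size`, preserving the aspect ratio, so the
        # CenterCrop + Resize pipeline above yields a (size x size) tensor
        ratio = size / min(image.size)
        image = image.resize((round(image.size[0] * ratio), round(image.size[1] * ratio)), Image.LANCZOS)
    return image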
def neural_style_transfer(config):
    content_img_path = os.path.join(config['content_images_dir'], config['content_img_name'])
    style_img_path = os.path.join(config['style_images_dir'], config['style_img_name'])

    out_dir_name = 'combined_' + os.path.split(content_img_path)[1].split('.')[0] + '_' + os.path.split(style_img_path)[1].split('.')[0]
    dump_path = os.path.join(config['output_img_dir'], out_dir_name)
    os.makedirs(dump_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    content_img = utils.prepare_img(content_img_path, config['height'], device)
    style_img = utils.prepare_img(style_img_path, config['height'], device)

    if config['init_method'] == 'random':
        # white_noise_img = np.random.uniform(-90., 90., content_img.shape).astype(np.float32)
        gaussian_noise_img = np.random.normal(loc=0, scale=90., size=content_img.shape).astype(np.float32)
        init_img = torch.from_numpy(gaussian_noise_img).float().to(device)
    elif config['init_method'] == 'content':
        init_img = content_img
    else:
        # init image has same dimension as content image - this is a hard constraint
        # feature maps need to be of same size for content image and init image
        style_img_resized = utils.prepare_img(style_img_path, np.asarray(content_img.shape[2:]), device)
        init_img = style_img_resized

    # we are tuning optimizing_img's pixels! (that's why requires_grad=True)
    optimizing_img = Variable(init_img, requires_grad=True)

    neural_net, content_feature_maps_index_name, style_feature_maps_indices_names = utils.prepare_model(config['model'], device)
    print(f'Using {config["model"]} in the optimization procedure.')

    content_img_set_of_feature_maps = neural_net(content_img)
    style_img_set_of_feature_maps = neural_net(style_img)

    target_content_representation = content_img_set_of_feature_maps[content_feature_maps_index_name[0]].squeeze(axis=0)
    target_style_representation = [utils.gram_matrix(x) for cnt, x in enumerate(style_img_set_of_feature_maps) if cnt in style_feature_maps_indices_names[0]]
    target_representations = [target_content_representation, target_style_representation]

    # magic numbers in general are a big no no - some things in this code are left like this by design to avoid clutter
    num_of_iterations = {
        "lbfgs": 1000,
        "adam": 3000,
    }

    #
    # Start of optimization procedure
    #
    if config['optimizer'] == 'adam':
        optimizer = Adam((optimizing_img,), lr=1e1)
        tuning_step = make_tuning_step(neural_net, optimizer, target_representations, content_feature_maps_index_name[0], style_feature_maps_indices_names[0], config)
        for cnt in range(num_of_iterations[config['optimizer']]):
            total_loss, content_loss, style_loss, tv_loss = tuning_step(optimizing_img)
            with torch.no_grad():
                print(f'Adam | iteration: {cnt:03}, total loss={total_loss.item():12.4f}, content_loss={config["content_weight"] * content_loss.item():12.4f}, style loss={config["style_weight"] * style_loss.item():12.4f}, tv loss={config["tv_weight"] * tv_loss.item():12.4f}')
                utils.save_and_maybe_display(optimizing_img, dump_path, config, cnt, num_of_iterations[config['optimizer']], should_display=False)
    elif config['optimizer'] == 'lbfgs':
        # line_search_fn does not seem to have significant impact on result
        optimizer = LBFGS((optimizing_img,), max_iter=num_of_iterations['lbfgs'], line_search_fn='strong_wolfe')
        cnt = 0

        def closure():
            nonlocal cnt
            if torch.is_grad_enabled():
                optimizer.zero_grad()
            total_loss, content_loss, style_loss, tv_loss = build_loss(neural_net, optimizing_img, target_representations, content_feature_maps_index_name[0], style_feature_maps_indices_names[0], config)
            if total_loss.requires_grad:
                total_loss.backward()
            with torch.no_grad():
                print(f'L-BFGS | iteration: {cnt:03}, total loss={total_loss.item():12.4f}, content_loss={config["content_weight"] * content_loss.item():12.4f}, style loss={config["style_weight"] * style_loss.item():12.4f}, tv loss={config["tv_weight"] * tv_loss.item():12.4f}')
                utils.save_and_maybe_display(optimizing_img, dump_path, config, cnt, num_of_iterations[config['optimizer']], should_display=False)
            cnt += 1
            return total_loss

        optimizer.step(closure)

    return dump_path
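# A hypothetical minimal config for driving neural_style_transfer. The key
# names are taken from the lookups inside the function; the values (and any
# extra keys the utils helpers might require, such as an image format or a
# saving frequency for save_and_maybe_display) are assumptions:
config = {
    'content_images_dir': 'data/content-images',
    'style_images_dir': 'data/style-images',
    'output_img_dir': 'data/output-images',
    'content_img_name': 'lion.jpg',
    'style_img_name': 'starry_night.jpg',
    'height': 400,
    'model': 'vgg19',
    'init_method': 'content',   # 'random' | 'content' | 'style'
    'optimizer': 'lbfgs',       # 'lbfgs' | 'adam'
    'content_weight': 1e5,
    'style_weight': 3e4,
    'tv_weight': 1e0,
}
results_path = neural_style_transfer(config)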
def train(**kwargs):
    opt._parse(kwargs)
    device = t.device('cuda') if opt.use_gpu else t.device('cpu')
    vis = Visualizer(opt.env)

    # Data loading
    transform = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x * 255)
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, transform)
    dataloader = data.DataLoader(dataset, opt.batch_size)

    # style transformer network
    transformer = TransformerNet()
    if opt.model_path:
        transformer.load_state_dict(t.load(opt.model_path, map_location=lambda _s, _: _s))
    transformer.to(device)

    # Vgg16 for perceptual loss (frozen)
    vgg = Vgg16().eval()
    vgg.to(device)
    for param in vgg.parameters():
        param.requires_grad = False

    # Optimizer: use Adam
    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    # Get style image
    style = utils.get_style_data(opt.style_path)
    vis.img('style', (style.data[0] * 0.225 + 0.45).clamp(min=0, max=1))
    style = style.to(device)

    # gram matrices for style image
    with t.no_grad():
        features_style = vgg(style)
        gram_style = [utils.gram_matrix(y) for y in features_style]

    # Loss meters
    style_meter = tnt.meter.AverageValueMeter()
    content_meter = tnt.meter.AverageValueMeter()

    for epoch in range(opt.epoches):
        content_meter.reset()
        style_meter.reset()

        for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):
            # Train
            optimizer.zero_grad()
            x = x.to(device)
            y = transformer(x)
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            # content loss
            content_loss = opt.content_weight * F.mse_loss(features_y.relu2_2, features_x.relu2_2)

            # style loss
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gram_y = utils.gram_matrix(ft_y)
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            # Loss smoothing for visualization
            content_meter.add(content_loss.item())
            style_meter.add(style_loss.item())

            if (ii + 1) % opt.plot_every == 0:
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                # visualization
                vis.plot('content_loss', content_meter.value()[0])
                vis.plot('style_loss', style_meter.value()[0])
                # denorm input/output, since we have applied utils.normalize_batch
                vis.img('output', (y.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))
                vis.img('input', (x.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))

        # save checkpoint
        vis.save([opt.env])
        t.save(transformer.state_dict(), 'checkpoints/%s_style.pth' % epoch)
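# The training loop above pushes both x and y through utils.normalize_batch
# before VGG, and later denormalizes for display with (* 0.225 + 0.45). A
# sketch of a helper consistent with that convention (inputs arrive in
# [0, 255] because of the Lambda(x * 255) transform); the per-channel
# constants are the usual ImageNet statistics and are an assumption here:
def normalize_batch(batch):
    # (b, 3, h, w) in [0, 255] -> ImageNet-normalized float tensor
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)
    return (batch / 255.0 - mean) / std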
def reconstruct_image_from_representation(config):
    should_reconstruct_content = config['should_reconstruct_content']
    should_visualize_representation = config['should_visualize_representation']
    dump_path = os.path.join(config['output_img_dir'], ('c' if should_reconstruct_content else 's') + '_reconstruction_' + config['optimizer'])
    dump_path = os.path.join(dump_path, config['content_img_name'].split('.')[0] if should_reconstruct_content else config['style_img_name'].split('.')[0])
    os.makedirs(dump_path, exist_ok=True)

    content_img_path = os.path.join(config['content_images_dir'], config['content_img_name'])
    style_img_path = os.path.join(config['style_images_dir'], config['style_img_name'])
    img_path = content_img_path if should_reconstruct_content else style_img_path

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    img = utils.prepare_img(img_path, config['height'], device)

    gaussian_noise_img = np.random.normal(loc=0, scale=90., size=img.shape).astype(np.float32)
    white_noise_img = np.random.uniform(-90., 90., img.shape).astype(np.float32)
    init_img = torch.from_numpy(white_noise_img).float().to(device)
    optimizing_img = Variable(init_img, requires_grad=True)

    # indices pick relevant feature maps (say conv4_1, relu1_1, etc.)
    neural_net, content_feature_maps_index_name, style_feature_maps_indices_names = utils.prepare_model(config['model'], device)

    # don't want to expose everything that's not crucial so some things are hardcoded
    num_of_iterations = {'adam': 3000, 'lbfgs': 350}

    set_of_feature_maps = neural_net(img)

    #
    # Visualize feature maps and Gram matrices (depending on whether you're reconstructing the content or style img)
    #
    if should_reconstruct_content:
        target_content_representation = set_of_feature_maps[content_feature_maps_index_name[0]].squeeze(axis=0)
        if should_visualize_representation:
            num_of_feature_maps = target_content_representation.size()[0]
            print(f'Number of feature maps: {num_of_feature_maps}')
            for i in range(num_of_feature_maps):
                feature_map = target_content_representation[i].to('cpu').numpy()
                feature_map = np.uint8(utils.get_uint8_range(feature_map))
                plt.imshow(feature_map)
                plt.title(f'Feature map {i+1}/{num_of_feature_maps} from layer {content_feature_maps_index_name[1]} (model={config["model"]}) for {config["content_img_name"]} image.')
                plt.show()
                filename = f'fm_{config["model"]}_{content_feature_maps_index_name[1]}_{str(i).zfill(config["img_format"][0])}{config["img_format"][1]}'
                utils.save_image(feature_map, os.path.join(dump_path, filename))
    else:
        target_style_representation = [utils.gram_matrix(fmaps) for i, fmaps in enumerate(set_of_feature_maps) if i in style_feature_maps_indices_names[0]]
        if should_visualize_representation:
            num_of_gram_matrices = len(target_style_representation)
            print(f'Number of Gram matrices: {num_of_gram_matrices}')
            for i in range(num_of_gram_matrices):
                Gram_matrix = target_style_representation[i].squeeze(axis=0).to('cpu').numpy()
                Gram_matrix = np.uint8(utils.get_uint8_range(Gram_matrix))
                plt.imshow(Gram_matrix)
                plt.title(f'Gram matrix from layer {style_feature_maps_indices_names[1][i]} (model={config["model"]}) for {config["style_img_name"]} image.')
                plt.show()
                filename = f'gram_{config["model"]}_{style_feature_maps_indices_names[1][i]}_{str(i).zfill(config["img_format"][0])}{config["img_format"][1]}'
                utils.save_image(Gram_matrix, os.path.join(dump_path, filename))

    #
    # Start of optimization procedure
    #
    if config['optimizer'] == 'adam':
        optimizer = Adam((optimizing_img,))
        target_representation = target_content_representation if should_reconstruct_content else target_style_representation
        tuning_step = make_tuning_step(neural_net, optimizer, target_representation, should_reconstruct_content, content_feature_maps_index_name[0], style_feature_maps_indices_names[0])
        for it in range(num_of_iterations[config['optimizer']]):
            loss, _ = tuning_step(optimizing_img)
            with torch.no_grad():
                print(f'Iteration: {it}, current {"content" if should_reconstruct_content else "style"} loss={loss:10.8f}')
                utils.save_and_maybe_display(optimizing_img, dump_path, config, it, num_of_iterations[config['optimizer']], should_display=False)
    elif config['optimizer'] == 'lbfgs':
        cnt = 0

        # closure is a function required by the L-BFGS optimizer
        def closure():
            nonlocal cnt
            optimizer.zero_grad()
            loss = 0.0
            if should_reconstruct_content:
                loss = torch.nn.MSELoss(reduction='mean')(target_content_representation, neural_net(optimizing_img)[content_feature_maps_index_name[0]].squeeze(axis=0))
            else:
                current_set_of_feature_maps = neural_net(optimizing_img)
                current_style_representation = [utils.gram_matrix(fmaps) for i, fmaps in enumerate(current_set_of_feature_maps) if i in style_feature_maps_indices_names[0]]
                for gram_gt, gram_hat in zip(target_style_representation, current_style_representation):
                    loss += (1 / len(target_style_representation)) * torch.nn.MSELoss(reduction='sum')(gram_gt[0], gram_hat[0])
            loss.backward()
            with torch.no_grad():
                print(f'Iteration: {cnt}, current {"content" if should_reconstruct_content else "style"} loss={loss.item()}')
                utils.save_and_maybe_display(optimizing_img, dump_path, config, cnt, num_of_iterations[config['optimizer']], should_display=False)
            cnt += 1
            return loss

        optimizer = torch.optim.LBFGS((optimizing_img,), max_iter=num_of_iterations[config['optimizer']], line_search_fn='strong_wolfe')
        optimizer.step(closure)

    return dump_path
def train(training_config):
    writer = SummaryWriter()  # (tensorboard) writer will output to ./runs/ directory by default
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # prepare data loader
    train_loader = utils.get_training_data_loader(training_config)

    # prepare neural networks
    transformer_net = TransformerNet().train().to(device)
    perceptual_loss_net = PerceptualLossNet(requires_grad=False).to(device)

    optimizer = LBFGS(transformer_net.parameters(), line_search_fn='strong_wolfe')

    # Calculate style image's Gram matrices (style representation)
    # Built over feature maps as produced by the perceptual net - VGG16
    style_img_path = os.path.join(training_config['style_images_path'], training_config['style_img_name'])
    style_img = utils.prepare_img(style_img_path, target_shape=None, device=device, batch_size=training_config['batch_size'])
    style_img_set_of_feature_maps = perceptual_loss_net(style_img)
    target_style_representation = [utils.gram_matrix(x) for x in style_img_set_of_feature_maps]

    utils.print_header(training_config)
    # Tracking loss metrics - NST is ill-posed, so we can only track loss and visual appearance of the stylized images
    acc_content_loss, acc_style_loss, acc_tv_loss = [0., 0., 0.]
    ts = time.time()
    for epoch in range(training_config['num_of_epochs']):
        for batch_id, (content_batch, _) in enumerate(train_loader):
            # step1: Feed content batch through transformer net
            content_batch = content_batch.to(device)
            stylized_batch = transformer_net(content_batch)

            # step2: Feed content and stylized batch through perceptual net (VGG16)
            content_batch_set_of_feature_maps = perceptual_loss_net(content_batch)
            stylized_batch_set_of_feature_maps = perceptual_loss_net(stylized_batch)

            # step3: Calculate content representations and content loss
            target_content_representation = content_batch_set_of_feature_maps.relu2_2
            current_content_representation = stylized_batch_set_of_feature_maps.relu2_2
            content_loss = training_config['content_weight'] * torch.nn.MSELoss(reduction='mean')(target_content_representation, current_content_representation)

            # step4: Calculate style representation and style loss
            style_loss = 0.0
            current_style_representation = [utils.gram_matrix(x) for x in stylized_batch_set_of_feature_maps]
            for gram_gt, gram_hat in zip(target_style_representation, current_style_representation):
                style_loss += torch.nn.MSELoss(reduction='mean')(gram_gt, gram_hat)
            style_loss /= len(target_style_representation)
            style_loss *= training_config['style_weight']

            # step5: Calculate total variation loss - enforces image smoothness
            tv_loss = training_config['tv_weight'] * utils.total_variation(stylized_batch)

            # step6: Combine losses and do a backprop
            total_loss = content_loss + style_loss + tv_loss
            total_loss.backward()

            # LBFGS expects a closure; here it simply hands back the loss computed above
            def closure():
                nonlocal total_loss
                optimizer.zero_grad()
                return total_loss

            optimizer.step(closure)

            #
            # Logging and checkpoint creation
            #
            acc_content_loss += content_loss.item()
            acc_style_loss += style_loss.item()
            acc_tv_loss += tv_loss.item()

            if training_config['enable_tensorboard']:
                # log scalars
                writer.add_scalar('Loss/content-loss', content_loss.item(), len(train_loader) * epoch + batch_id + 1)
                writer.add_scalar('Loss/style-loss', style_loss.item(), len(train_loader) * epoch + batch_id + 1)
                writer.add_scalar('Loss/tv-loss', tv_loss.item(), len(train_loader) * epoch + batch_id + 1)
                writer.add_scalars('Statistics/min-max-mean-median', {
                    'min': torch.min(stylized_batch),
                    'max': torch.max(stylized_batch),
                    'mean': torch.mean(stylized_batch),
                    'median': torch.median(stylized_batch)
                }, len(train_loader) * epoch + batch_id + 1)

                # log stylized image
                if batch_id % training_config['image_log_freq'] == 0:
                    stylized = utils.post_process_image(stylized_batch[0].detach().to('cpu').numpy())
                    stylized = np.moveaxis(stylized, 2, 0)  # writer expects channel-first image
                    writer.add_image('stylized_img', stylized, len(train_loader) * epoch + batch_id + 1)

            if training_config['console_log_freq'] is not None and batch_id % training_config['console_log_freq'] == 0:
                print(f'time elapsed={(time.time() - ts) / 60:.2f}[min]|epoch={epoch + 1}|batch=[{batch_id + 1}/{len(train_loader)}]|c-loss={acc_content_loss / training_config["console_log_freq"]}|s-loss={acc_style_loss / training_config["console_log_freq"]}|tv-loss={acc_tv_loss / training_config["console_log_freq"]}|total loss={(acc_content_loss + acc_style_loss + acc_tv_loss) / training_config["console_log_freq"]}')
                acc_content_loss, acc_style_loss, acc_tv_loss = [0., 0., 0.]

            if training_config['checkpoint_freq'] is not None and (batch_id + 1) % training_config['checkpoint_freq'] == 0:
                training_state = utils.get_training_metadata(training_config)
                training_state["state_dict"] = transformer_net.state_dict()
                training_state["optimizer_state"] = optimizer.state_dict()
                ckpt_model_name = f"ckpt_style_{training_config['style_img_name'].split('.')[0]}_cw_{str(training_config['content_weight'])}_sw_{str(training_config['style_weight'])}_tw_{str(training_config['tv_weight'])}_epoch_{epoch}_batch_{batch_id}.pth"
                torch.save(training_state, os.path.join(training_config['checkpoints_path'], ckpt_model_name))

    #
    # Save model with additional metadata - like which commit was used to train the model, style/content weights, etc.
    #
    training_state = utils.get_training_metadata(training_config)
    training_state["state_dict"] = transformer_net.state_dict()
    training_state["optimizer_state"] = optimizer.state_dict()
    model_name = f"style_{training_config['style_img_name'].split('.')[0]}_datapoints_{training_state['num_of_datapoints']}_cw_{str(training_config['content_weight'])}_sw_{str(training_config['style_weight'])}_tw_{str(training_config['tv_weight'])}.pth"
    torch.save(training_state, os.path.join(training_config['model_binaries_path'], model_name))