model_path_list = [
    prefix + 'moment_alignment_model_0_1_2_0' + suffix,
    prefix + 'moment_alignment_model_0_1_3_0' + suffix,
    prefix + 'moment_alignment_model_0_1_4_0' + suffix,
    prefix + 'moment_alignment_model_0_1_5_0' + suffix,
]
model_configuration = {
    'do_incremental_training': False,
    'use_very_shallow_model_list': [False for _ in range(len(model_path_list))],
    'load_model': False
}
model_list = [None for _ in range(len(model_path_list))]
for i in range(len(model_list)):
    moment_alignment_model = net.get_moment_alignment_model(
        model_configuration, moment_mode=i + 2, use_list=True, list_index=i)
    checkpoint = torch.load(model_path_list[i], map_location='cpu')
    moment_alignment_model.load_state_dict(checkpoint)
    model_list[i] = moment_alignment_model.to(device)

prefix = '/Users/Johannes/Desktop/encoder_decoder_exp/' if local else prefix
suffix = '.pth'
encoder_decoder_model_paths = {
    'encoder_model_path': prefix + 'encoder_1_25_6_state_dict' + suffix,
    'decoder_model_path': prefix + 'decoder_1_25_6_state_dict' + suffix,
}
encoder = net.get_trained_encoder(encoder_decoder_model_paths)
decoder = net.get_trained_decoder(encoder_decoder_model_paths)
encoder.to(device)
decoder.to(device)
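# For reference, a minimal sketch of what a per-feature-map moment computation
# like u.compute_moments_batches / utils.compute_moments_batches might look
# like. This is an illustrative assumption, not the project's actual
# implementation: it returns, per feature map in the batch, the mean and std
# followed by the standardized central moments up to last_moment.
def compute_moments_sketch(feature_map_batch, last_moment=2):
    # flatten the spatial dimensions: (N, H, W) -> (N, H*W)
    flat = feature_map_batch.view(feature_map_batch.size(0), -1)
    mean = flat.mean(dim=1, keepdim=True)
    std = flat.std(dim=1, keepdim=True)
    moments = [mean, std]
    for k in range(3, last_moment + 1):
        # k-th standardized moment (k=3: skewness, k=4: kurtosis, ...)
        moments.append((((flat - mean) / (std + 1e-8)) ** k).mean(dim=1, keepdim=True))
    return torch.cat(moments, dim=1)  # shape (N, last_moment)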
def test(configuration):
    """
    test the moment alignment solution (for 2 moments) in comparison to the
    analytical solution that can be computed for mean and std
    :param configuration: the config file
    :return:
    """
    encoder = net.get_trained_encoder(configuration)
    decoder = net.get_trained_decoder(configuration)
    analytical_ada_in_module = net.AdaptiveInstanceNormalization()
    pretrained_model_path = configuration['pretrained_model_path']
    content_images_path = configuration['content_images_path']
    style_images_path = configuration['style_images_path']
    loader = configuration['loader']
    unloader = configuration['unloader']
    image_test_saving_path = configuration['image_saving_path']
    moment_mode = configuration['moment_mode']

    print('loading the moment alignment model from {}'.format(pretrained_model_path))
    moment_alignment_model = net.get_moment_alignment_model(configuration, moment_mode)
    print(moment_alignment_model)
    checkpoint = torch.load(pretrained_model_path, map_location=device)
    moment_alignment_model.load_state_dict(checkpoint)

    content_image_files = [
        '{}/{}'.format(content_images_path, file_name)
        for file_name in os.listdir(content_images_path)
    ]
    style_image_files = [
        '{}/{}'.format(style_images_path, file_name)
        for file_name in os.listdir(style_images_path)
    ]
    number_content_images = len(content_image_files)
    number_style_images = len(style_image_files)
    for i in range(number_style_images):
        print("test_image {} at {}".format(i + 1, style_image_files[i]))
    for i in range(number_content_images):
        print("test_image {} at {}".format(i + 1, content_image_files[i]))

    iterations = 0
    mean_percentages = [0, 0, 0, 0, 0, 0, 0]
    for j in range(number_content_images):
        for i in range(number_style_images):
            style_image = data_loader.image_loader(style_image_files[i], loader)
            content_image = data_loader.image_loader(content_image_files[j], loader)
            with torch.no_grad():
                content_feature_maps = encoder(content_image)['r41']
                style_feature_maps = encoder(style_image)['r41']
                content_feature_map_batch_loader = data_loader.get_batch(content_feature_maps, 512)
                style_feature_map_batch_loader = data_loader.get_batch(style_feature_maps, 512)
                content_feature_map_batch = next(content_feature_map_batch_loader).to(device)
                style_feature_map_batch = next(style_feature_map_batch_loader).to(device)

                style_feature_map_batch_moments = u.compute_moments_batches(style_feature_map_batch)
                content_feature_map_batch_moments = u.compute_moments_batches(content_feature_map_batch)
                result_feature_maps = moment_alignment_model(
                    content_feature_map_batch,
                    content_feature_map_batch_moments,
                    style_feature_map_batch_moments)
                analytical_feature_maps = analytical_ada_in_module(
                    content_feature_map_batch, style_feature_map_batch)

                a_0, a_001, a_001_l, a_01, a_01_l, a_1, a_1_l = \
                    get_distance(analytical_feature_maps, result_feature_maps)
                iterations += 1
                mean_percentages[0] += a_0
                mean_percentages[1] += a_001
                mean_percentages[2] += a_001_l
                mean_percentages[3] += a_01
                mean_percentages[4] += a_01_l
                mean_percentages[5] += a_1
                mean_percentages[6] += a_1_l

                # decode both solutions once and reuse the images below
                result_image = decoder(result_feature_maps.view(1, 512, 32, 32))
                analytical_image = decoder(analytical_feature_maps.view(1, 512, 32, 32))
                # u.imshow(analytical_image, transforms.ToPILImage())
                utils.save_image(
                    [data_loader.imnorm(content_image, unloader),
                     data_loader.imnorm(style_image, unloader),
                     data_loader.imnorm(result_image, None),
                     data_loader.imnorm(analytical_image, None)],
                    '{}/A_image_{}_{}.jpeg'.format(image_test_saving_path, i, j),
                    normalize=False,
                    pad_value=1)
                utils.save_image(
                    [data_loader.imnorm(content_image, unloader),
                     data_loader.imnorm(result_image, None),
                     data_loader.imnorm(analytical_image, None)],
                    '{}/B_image_{}_{}.jpeg'.format(image_test_saving_path, i, j),
                    normalize=False,
                    pad_value=1)
                utils.save_image(
                    [data_loader.imnorm(style_image, unloader),
                     data_loader.imnorm(result_image, None),
                     data_loader.imnorm(analytical_image, None)],
                    '{}/C_image_{}_{}.jpeg'.format(image_test_saving_path, i, j),
                    normalize=False,
                    pad_value=1)
                utils.save_image(
                    [data_loader.imnorm(result_image, None),
                     data_loader.imnorm(analytical_image, None)],
                    '{}/D_image_{}_{}.jpeg'.format(image_test_saving_path, i, j),
                    normalize=False,
                    pad_value=1)

    print('averaging percentages')
    mean_percentages = [p / iterations for p in mean_percentages]
    print(mean_percentages)
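# The analytical baseline used above is standard adaptive instance
# normalization (AdaIN, Huang & Belongie 2017): normalize each content feature
# map to zero mean / unit std, then shift and scale to the style statistics.
# A minimal sketch of such a module, assuming feature map batches of shape
# (N, H, W) as produced by data_loader.get_batch above;
# net.AdaptiveInstanceNormalization may differ in details:
class AdaptiveInstanceNormalizationSketch(torch.nn.Module):
    def forward(self, content, style):
        c = content.view(content.size(0), -1)
        s = style.view(style.size(0), -1)
        c_mean, c_std = c.mean(dim=1, keepdim=True), c.std(dim=1, keepdim=True)
        s_mean, s_std = s.mean(dim=1, keepdim=True), s.std(dim=1, keepdim=True)
        # align mean and std of the content features to the style features
        out = s_std * (c - c_mean) / (c_std + 1e-8) + s_mean
        return out.view_as(content)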
def test(configuration):
    """
    test loop to produce images with multiple models
    :param configuration: the config file
    :return:
    """
    encoder = net.get_trained_encoder(configuration)
    decoder = net.get_trained_decoder(configuration)
    model_path_list = configuration['model_path_list']
    content_images_path = configuration['content_images_path']
    style_images_path = configuration['style_images_path']
    loader = configuration['loader']
    image_saving_path = configuration['image_saving_path']
    mode_list = configuration['mode_list']

    model_list = [None for _ in range(len(model_path_list))]
    for i in range(len(model_list)):
        moment_alignment_model = net.get_moment_alignment_model(
            configuration, moment_mode=mode_list[i], use_list=True, list_index=i)
        checkpoint = torch.load(model_path_list[i], map_location='cpu')
        moment_alignment_model.load_state_dict(checkpoint)
        model_list[i] = moment_alignment_model.to(device)

    content_image_files = [
        '{}/{}'.format(content_images_path, file_name)
        for file_name in sorted(os.listdir(content_images_path))
    ]
    style_image_files = [
        '{}/{}'.format(style_images_path, file_name)
        for file_name in sorted(os.listdir(style_images_path))
    ]
    number_content_images = len(content_image_files)
    number_style_images = len(style_image_files)

    for i in range(number_content_images):
        for j in range(number_style_images):
            print('at content image {}, style image {}'.format(i, j))
            with torch.no_grad():
                content_image = data_loader.image_loader(content_image_files[i], loader)
                style_image = data_loader.image_loader(style_image_files[j], loader)
                content_feature_maps = encoder(content_image)['r41']
                style_feature_maps = encoder(style_image)['r41']
                result_images = [None for _ in range(len(model_list))]
                for k in range(len(model_list)):
                    content_feature_map_batch_loader = data_loader.get_batch(content_feature_maps, 512)
                    style_feature_map_batch_loader = data_loader.get_batch(style_feature_maps, 512)
                    content_feature_map_batch = next(content_feature_map_batch_loader).to(device)
                    style_feature_map_batch = next(style_feature_map_batch_loader).to(device)
                    style_feature_map_batch_moments = utils.compute_moments_batches(
                        style_feature_map_batch, last_moment=7)
                    content_feature_map_batch_moments = utils.compute_moments_batches(
                        content_feature_map_batch, last_moment=7)
                    result_images[k] = decoder(model_list[k](
                        content_feature_map_batch,
                        content_feature_map_batch_moments,
                        style_feature_map_batch_moments).view(1, 512, 32, 32)).squeeze(0)

                # save all images in one row
                u.save_image(
                    [data_loader.imnorm(content_image, None),
                     data_loader.imnorm(style_image, None)] + result_images,
                    '{}/moment_alignment_test_image_A_{}_{}.jpeg'.format(image_saving_path, i, j),
                    normalize=False,
                    scale_each=False,
                    pad_value=1)
                # save all images in two rows
                u.save_image(
                    [data_loader.imnorm(content_image, None),
                     data_loader.imnorm(style_image, None)] +
                    [torch.ones(3, 256, 256) for _ in range(2)] + result_images,
                    '{}/moment_alignment_test_image_B_{}_{}.jpeg'.format(image_saving_path, i, j),
                    normalize=False,
                    scale_each=False,
                    pad_value=1,
                    nrow=4)
                # save all result images in one row
                u.save_image(
                    result_images,
                    '{}/moment_alignment_test_image_C_{}_{}.jpeg'.format(image_saving_path, i, j),
                    normalize=False,
                    scale_each=False,
                    pad_value=1)
                # save all result images in one row + content image
                u.save_image(
                    [data_loader.imnorm(content_image, None)] + result_images,
                    '{}/moment_alignment_test_image_D_{}_{}.jpeg'.format(image_saving_path, i, j),
                    normalize=False,
                    scale_each=False,
                    pad_value=1)
                # save all result images in one row + style image
                u.save_image(
                    [data_loader.imnorm(style_image, None)] + result_images,
                    '{}/moment_alignment_test_image_E_{}_{}.jpeg'.format(image_saving_path, i, j),
                    normalize=False,
                    scale_each=False,
                    pad_value=1)
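# Hypothetical invocation of the multi-model test loop above; every path and
# value here is a placeholder, and the real configuration may carry additional
# keys consumed by net.get_moment_alignment_model:
# configuration = {
#     'model_path_list': model_path_list,
#     'mode_list': [2, 3, 4, 5],  # one moment_mode per entry in model_path_list
#     'content_images_path': '/path/to/content_images',
#     'style_images_path': '/path/to/style_images',
#     'image_saving_path': '/path/to/output',
#     'loader': loader,
# }
# test(configuration)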
def test(configuration):
    """
    test loop
    :param configuration: the config file
    :return:
    """
    analytical_ada_in_module = net.AdaptiveInstanceNormalization()
    encoder = net.get_trained_encoder(configuration)
    decoder = net.get_trained_decoder(configuration)
    pretrained_model_path = configuration['pretrained_model_path']
    print('loading the moment alignment model from {}'.format(pretrained_model_path))
    content_images_path = configuration['content_images_path']
    style_images_path = configuration['style_images_path']
    loader = configuration['loader']
    unloader = configuration['unloader']
    image_saving_path = configuration['image_saving_path']
    moment_mode = configuration['moment_mode']
    # whether to run the trained moment alignment module or the analytical
    # AdaIN baseline
    use_MA_module = configuration.get('use_MA_module', True)

    moment_alignment_model = net.get_moment_alignment_model(configuration, moment_mode)
    print(moment_alignment_model)
    checkpoint = torch.load(pretrained_model_path, map_location=device)
    moment_alignment_model.load_state_dict(checkpoint)
    aligned_moment_loss = net.get_loss(configuration, moment_mode=moment_mode,
                                       lambda_1=0, lambda_2=10)

    content_image_files = [
        '{}/{}'.format(content_images_path, file_name)
        for file_name in os.listdir(content_images_path)
    ]
    style_image_files = [
        '{}/{}'.format(style_images_path, file_name)
        for file_name in os.listdir(style_images_path)
    ]
    number_content_images = len(content_image_files)
    number_style_images = len(style_image_files)
    for i in range(number_style_images):
        print("test_image {} at {}".format(i + 1, style_image_files[i]))
    for i in range(number_content_images):
        print("test_image {} at {}".format(i + 1, content_image_files[i]))

    for j in range(number_content_images):
        for i in range(number_style_images):
            style_image = data_loader.image_loader(style_image_files[i], loader)
            content_image = data_loader.image_loader(content_image_files[j], loader)
            with torch.no_grad():
                content_feature_maps = encoder(content_image)['r41']
                style_feature_maps = encoder(style_image)['r41']
                content_feature_map_batch_loader = data_loader.get_batch(content_feature_maps, 512)
                style_feature_map_batch_loader = data_loader.get_batch(style_feature_maps, 512)
                content_feature_map_batch = next(content_feature_map_batch_loader).to(device)
                style_feature_map_batch = next(style_feature_map_batch_loader).to(device)
                if use_MA_module:
                    style_feature_map_batch_moments = u.compute_moments_batches(
                        style_feature_map_batch, last_moment=7)
                    content_feature_map_batch_moments = u.compute_moments_batches(
                        content_feature_map_batch, last_moment=7)
                    out = moment_alignment_model(
                        content_feature_map_batch,
                        content_feature_map_batch_moments,
                        style_feature_map_batch_moments,
                        is_test=True)
                    out_feature_map_batch_moments = u.compute_moments_batches(out, last_moment=7)
                    print_some_moments(style_feature_map_batch_moments,
                                       content_feature_map_batch_moments,
                                       out_feature_map_batch_moments)
                    loss, moment_loss, reconstruction_loss = aligned_moment_loss(
                        content_feature_map_batch,
                        style_feature_map_batch,
                        content_feature_map_batch_moments,
                        style_feature_map_batch_moments,
                        out,
                        is_test=True)
                    print('loss: {}, moment_loss: {}, reconstruction_loss: {}'.format(
                        loss.item(), moment_loss.item(), reconstruction_loss.item()))
                else:
                    # fall back to the analytical AdaIN solution
                    out = analytical_ada_in_module(content_feature_map_batch, style_feature_map_batch)

                utils.save_image(
                    [data_loader.imnorm(content_image, unloader),
                     data_loader.imnorm(style_image, unloader),
                     data_loader.imnorm(decoder(out.view(1, 512, 32, 32)), None)],
                    '{}/A_image_{}_{}.jpeg'.format(image_saving_path, i, j),
                    normalize=False)
                utils.save_image(
                    [data_loader.imnorm(decoder(out.view(1, 512, 32, 32)), None)],
                    '{}/B_image_{}_{}.jpeg'.format(image_saving_path, i, j),
                    normalize=False)
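# The criterion built by net.get_loss is called as
# loss, moment_loss, reconstruction_loss = criterion(...). A plausible sketch
# of such a weighted two-term objective, reusing compute_moments_sketch from
# above; this is an assumption about its structure, not the project's actual
# loss module:
def moment_alignment_loss_sketch(content_feature_map_batch, style_feature_map_batch,
                                 content_moments, style_moments, out,
                                 lambda_1=1.0, lambda_2=10.0, last_moment=2):
    # reconstruction term: keep the aligned output close to the content features
    reconstruction_loss = torch.nn.functional.mse_loss(out, content_feature_map_batch)
    # moment term: pull the output's statistics toward the style statistics
    # (the actual loss may also compare against style_feature_map_batch directly;
    # content_moments is kept only for signature parity)
    out_moments = compute_moments_sketch(out, last_moment=last_moment)
    moment_loss = torch.nn.functional.mse_loss(out_moments, style_moments)
    loss = lambda_1 * reconstruction_loss + lambda_2 * moment_loss
    return loss, moment_loss, reconstruction_loss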
def train(configuration):
    """
    this is the main training loop
    :param configuration: the config file
    :return:
    """
    epochs = configuration['epochs']
    print('going to train for {} epochs'.format(epochs))
    step_printing_interval = configuration['step_printing_interval']
    print('writing to console every {} steps'.format(step_printing_interval))
    image_saving_interval = configuration['image_saving_interval']
    print('saving images every {} steps'.format(image_saving_interval))
    epoch_saving_interval = configuration['epoch_saving_interval']
    print('saving the model every {} iterations'.format(epoch_saving_interval))
    validation_interval = configuration['validation_interval']
    print('validating the model every {} iterations'.format(validation_interval))
    image_saving_path = configuration['image_saving_path']
    print('saving images to {}'.format(image_saving_path))
    loader = configuration['loader']
    model_saving_path = configuration['model_saving_path']
    print('saving models to {}'.format(model_saving_path))
    # tensorboardX_path = configuration['tensorboardX_path']
    # writer = SummaryWriter(logdir='{}/runs'.format(tensorboardX_path))
    # print('saving tensorboardX logs to {}'.format(tensorboardX_path))

    loss_writer = LossWriter(
        os.path.join(configuration['folder_structure'].get_parent_folder(), './loss/loss'),
        buffer_size=100)
    loss_writer.write_header(columns=[
        'epoch', 'all_training_iteration', 'loss', 'moment_loss', 'reconstruction_loss'
    ])
    validation_loss_writer = LossWriter(
        os.path.join(configuration['folder_structure'].get_parent_folder(), './loss/validation_loss'))
    validation_loss_writer.write_header(columns=[
        'validation_iteration', 'loss', 'moment_loss', 'reconstruction_loss'
    ])

    # batch_size is the number of images to sample
    batch_size = 1
    feature_map_batch_size = int(configuration['feature_map_batch_size'])
    print('training in batches of {} feature maps'.format(feature_map_batch_size))
    coco_data_path_train = configuration['coco_data_path_train']
    painter_by_numbers_data_path_train = configuration['painter_by_numbers_data_path_train']
    print('using {} and {} for training'.format(
        coco_data_path_train, painter_by_numbers_data_path_train))
    coco_data_path_val = configuration['coco_data_path_val']
    painter_by_numbers_data_path_val = configuration['painter_by_numbers_data_path_val']
    print('using {} and {} for validation'.format(
        coco_data_path_val, painter_by_numbers_data_path_val))
    train_dataloader = data_loader.get_concat_dataloader(
        coco_data_path_train, painter_by_numbers_data_path_train, batch_size, loader=loader)
    print('got train dataloader')
    val_dataloader = data_loader.get_concat_dataloader(
        coco_data_path_val, painter_by_numbers_data_path_val, batch_size, loader=loader)
    print('got val dataloader')

    lambda_1 = configuration['lambda_1']
    print('lambda 1: {}'.format(lambda_1))
    lambda_2 = configuration['lambda_2']
    print('lambda 2: {}'.format(lambda_2))
    loss_moment_mode = configuration['moment_mode']
    net_moment_mode = configuration['moment_mode']
    print('loss is sum of the first {} moments'.format(loss_moment_mode))
    print('net accepts {} in-channels'.format(net_moment_mode))
    unloader = configuration['unloader']
    print('got the unloader')
    do_validation = configuration['do_validation']
    print('doing validation: {}'.format(do_validation))

    moment_alignment_model = net.get_moment_alignment_model(
        configuration, moment_mode=loss_moment_mode)
    print('got model')
    print(moment_alignment_model)
    decoder = net.get_trained_decoder(configuration)
    print('got decoder')
    print(decoder)
    print('params that require grad')
    for name, param in moment_alignment_model.named_parameters():
        if param.requires_grad:
            print(name)
    criterion = net.get_loss(configuration, moment_mode=loss_moment_mode,
                             lambda_1=lambda_1, lambda_2=lambda_2)
    print('got moment loss module')
    print(criterion)
    encoder = net.get_trained_encoder(configuration)
    print('got encoder')
    print(encoder)

    try:
        optimizer = optim.Adam(moment_alignment_model.parameters(), lr=configuration['lr'])
    except AttributeError:
        # the model may be wrapped (e.g. in nn.DataParallel), so fall back to .module
        optimizer = optim.Adam(moment_alignment_model.module.parameters(), lr=configuration['lr'])
    print('got optimizer')
    schedule = QuadraticSchedule(timesteps=10000000,
                                 initial=configuration['lr'],
                                 final=configuration['lr'] / 10.)
    print('got schedule')

    print('making iterable from train dataloader')
    train_data_loader = iter(train_dataloader)
    print('train data loader iterable')
    outer_training_iteration = -1
    all_training_iteration = 0
    number_of_validation = 0
    current_validation_loss = float('inf')
    for epoch in range(1, epochs + 1):
        print('epoch: {}'.format(epoch))
        # this is the outer training loop (sampling images)
        print('training model ...')
        while True:
            try:
                data = next(train_data_loader)
                outer_training_iteration += 1
            except StopIteration:
                print('got to the end of the dataloader (StopIteration)')
                train_data_loader = iter(train_dataloader)
                break
            except Exception:
                print('something went wrong with the dataloader, continuing')
                continue

            if do_validation:
                # validate the model every validation_interval iterations
                if outer_training_iteration % validation_interval == 0:
                    print('making iterable from val dataloader')
                    val_data_loader = iter(val_dataloader)
                    print('val data loader iterable')
                    validation_loss = validate(number_of_validation, criterion, encoder,
                                               moment_alignment_model, val_data_loader,
                                               feature_map_batch_size, validation_loss_writer)
                    number_of_validation += 1
                    if validation_loss < current_validation_loss:
                        utils.save_current_best_model(epoch, moment_alignment_model,
                                                      configuration['model_saving_path'])
                        print('got a better model')
                        current_validation_loss = validation_loss
                        print('set the new validation loss to the current one')
                    else:
                        print('this model is actually worse than the best one')

            # get the content_image batch
            content_image = data.get('coco').get('image')
            content_image = content_image.to(device)
            # get the style_image batch
            style_image = data.get('painter_by_numbers').get('image')
            style_image = style_image.to(device)
            style_feature_maps = encoder(style_image)['r41'].to(device)
            content_feature_maps = encoder(content_image)['r41'].to(device)
            result_feature_maps = torch.zeros(1, 1, 32, 32)
            content_feature_map_batch_loader = data_loader.get_batch(
                content_feature_maps, feature_map_batch_size)
            style_feature_map_batch_loader = data_loader.get_batch(
                style_feature_maps, feature_map_batch_size)

            # this is the inner training loop (feature maps)
            while True:
                try:
                    content_feature_map_batch = next(content_feature_map_batch_loader).to(device)
                    style_feature_map_batch = next(style_feature_map_batch_loader).to(device)
                    all_training_iteration += 1
                except StopIteration:
                    break
                except Exception:
                    continue
                do_print = all_training_iteration % step_printing_interval == 0
                optimizer.zero_grad()
                style_feature_map_batch_moments = utils.compute_moments_batches(
                    style_feature_map_batch, last_moment=net_moment_mode)
                content_feature_map_batch_moments = utils.compute_moments_batches(
                    content_feature_map_batch, last_moment=net_moment_mode)
                out = moment_alignment_model(content_feature_map_batch,
                                             content_feature_map_batch_moments,
                                             style_feature_map_batch_moments)
                loss, moment_loss, reconstruction_loss = criterion(
                    content_feature_map_batch,
                    style_feature_map_batch,
                    content_feature_map_batch_moments,
                    style_feature_map_batch_moments,
                    out,
                    last_moment=loss_moment_mode)
                if do_print:
                    set_lr(optimizer, lr=schedule.get(all_training_iteration / step_printing_interval))
                    loss_writer.write_row([
                        epoch, all_training_iteration, loss.item(),
                        moment_loss.item(), reconstruction_loss.item()
                    ])
                # backprop
                loss.backward()
                optimizer.step()
                result_feature_maps = torch.cat(
                    [result_feature_maps, out.cpu().view(1, -1, 32, 32)], 1)
                # if do_print:
                #     print('loss: {:4f}'.format(loss.item()))
                #     writer.add_scalar('data/training_loss', loss.item(), all_training_iteration)
                #     writer.add_scalar('data/training_moment_loss', moment_loss.item(), all_training_iteration)
                #     writer.add_scalar('data/training_reconstruction_loss', reconstruction_loss.item(),
                #                       all_training_iteration)

            result_feature_maps = result_feature_maps[:, 1:513, :, :]
            result_img = decoder(result_feature_maps.to(device))
            if outer_training_iteration % image_saving_interval == 0:
                u.save_image(
                    [data_loader.imnorm(content_image, unloader),
                     data_loader.imnorm(style_image, unloader),
                     data_loader.imnorm(result_img, None)],
                    '{}/image_{}_{}__{}_{}.jpeg'.format(
                        image_saving_path, epoch, outer_training_iteration // epoch,
                        lambda_1, lambda_2),
                    normalize=False)
            # save the current model every epoch_saving_interval iterations
            if outer_training_iteration % epoch_saving_interval == 0:
                utils.save_current_model(lambda_1, lambda_2,
                                         moment_alignment_model.state_dict(),
                                         optimizer.state_dict(),
                                         configuration['model_saving_path'])
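# QuadraticSchedule and set_lr are imported/defined elsewhere; minimal sketches
# consistent with how they are called in train() above (names and signatures
# here are assumptions, not the project's actual helpers):
class QuadraticScheduleSketch:
    def __init__(self, timesteps, initial, final):
        self.timesteps = timesteps
        self.initial = initial
        self.final = final

    def get(self, t):
        # quadratic decay from initial to final over timesteps steps
        progress = min(float(t) / self.timesteps, 1.0)
        return self.final + (self.initial - self.final) * (1.0 - progress) ** 2

def set_lr_sketch(optimizer, lr):
    # overwrite the learning rate of every parameter group in place
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr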