def main(_):
    imglob = glob.glob(os.path.join(FLAGS.data_root, FLAGS.dataset, '*'))
    mses = {}
    psnrs = {}
    ssims = {}
    experiment_folder = get_experiment_folder()

    # Save FLAGS to yml.
    yaml.dump(FLAGS.flag_values_dict(),
              open(os.path.join(experiment_folder, 'FLAGS.yml'), 'w'))

    for i, im in enumerate(imglob):
        print('Image: ' + str(i))
        image_name = im.split('/')[-1].split('.')[0]
        img_dataset = dataio.ImageFile(im)
        img = PIL.Image.open(im)
        image_resolution = (img.size[1] // FLAGS.downscaling_factor,
                            img.size[0] // FLAGS.downscaling_factor)
        # run_name = image_name + '_layers' + str(layers) + '_units' + str(hidden_units) + '_model' + FLAGS.model_type

        coord_dataset = dataio.Implicit2DWrapper(img_dataset,
                                                 sidelength=image_resolution)
        dataloader = DataLoader(coord_dataset,
                                shuffle=True,
                                batch_size=FLAGS.batch_size,
                                pin_memory=True,
                                num_workers=0)
        linear_decay = {
            'img_loss': trainingAC.LinearDecaySchedule(1, 1, FLAGS.epochs)
        }

        # Define the model.
        if FLAGS.model_type == 'mlp':
            model = modules.SingleBVPNet_INR(
                type=FLAGS.activation,
                mode=FLAGS.encoding,
                sidelength=image_resolution,
                out_features=img_dataset.img_channels,
                hidden_features=FLAGS.hidden_dims,
                num_hidden_layers=FLAGS.hidden_layers,
                encoding_scale=FLAGS.encoding_scale,
                batch_norm=FLAGS.bn,
                ff_dims=FLAGS.ff_dims)
        elif FLAGS.model_type == 'multi_tapered':
            model = modules.MultiScale_INR(
                type=FLAGS.activation,
                mode=FLAGS.encoding,
                sidelength=image_resolution,
                out_features=img_dataset.img_channels,
                hidden_features=FLAGS.hidden_dims,
                num_hidden_layers=FLAGS.hidden_layers,
                encoding_scale=FLAGS.encoding_scale,
                tapered=True,
                downsample=False,
                ff_dims=FLAGS.ff_dims)
        elif FLAGS.model_type == 'multi':
            model = modules.MultiScale_INR(
                type=FLAGS.activation,
                mode=FLAGS.encoding,
                sidelength=image_resolution,
                out_features=img_dataset.img_channels,
                hidden_features=FLAGS.hidden_dims,
                num_hidden_layers=FLAGS.hidden_layers,
                encoding_scale=FLAGS.encoding_scale,
                tapered=False,
                downsample=False,
                ff_dims=FLAGS.ff_dims)
        elif FLAGS.model_type == 'parallel':
            model = modules.Parallel_INR(
                type=FLAGS.activation,
                mode=FLAGS.encoding,
                sidelength=image_resolution,
                out_features=img_dataset.img_channels,
                hidden_features=[
                    FLAGS.hidden_dims // 4, FLAGS.hidden_dims // 2,
                    FLAGS.hidden_dims
                ],
                num_hidden_layers=FLAGS.hidden_layers,
                encoding_scale=FLAGS.encoding_scale)
        elif FLAGS.model_type == 'mixture':
            model = modules.INR_Mixture(
                type=FLAGS.activation,
                mode=FLAGS.encoding,
                sidelength=image_resolution,
                out_features=img_dataset.img_channels,
                hidden_features=FLAGS.hidden_dims,
                num_hidden_layers=FLAGS.hidden_layers,
                encoding_scale=FLAGS.encoding_scale,
                batch_norm=FLAGS.bn,
                ff_dims=FLAGS.ff_dims,
                num_components=FLAGS.num_components)
        model.cuda()

        root_path = os.path.join(experiment_folder, image_name)

        # Define the loss.
        if FLAGS.loss == 'mse':
            loss_fn = partial(loss_functions.image_mse, None)
        elif FLAGS.loss == 'log_mse':
            loss_fn = image_log_mse
        summary_fn = partial(siren_utils.write_image_summary, image_resolution)

        if FLAGS.model_type == 'parallel':
            trainingAC.train_phased(
                model=model,
                train_dataloader=dataloader,
                epochs=FLAGS.epochs,
                lr=FLAGS.lr,
                steps_til_summary=FLAGS.steps_til_summary,
                epochs_til_checkpoint=FLAGS.epochs_til_ckpt,
                model_dir=root_path,
                loss_fn=loss_fn,
                summary_fn=summary_fn,
                l1_reg=FLAGS.l1_reg,
                spec_reg=FLAGS.spec_reg,
                phased=FLAGS.phased,
                intermediate_losses=FLAGS.intermediate_losses)
        else:
            trainingAC.train(model=model,
                             train_dataloader=dataloader,
                             epochs=FLAGS.epochs,
                             lr=FLAGS.lr,
                             steps_til_summary=FLAGS.steps_til_summary,
                             epochs_til_checkpoint=FLAGS.epochs_til_ckpt,
                             model_dir=root_path,
                             loss_fn=loss_fn,
                             summary_fn=summary_fn,
                             l1_reg=FLAGS.l1_reg,
                             spec_reg=FLAGS.spec_reg,
                             loss_schedules=None)

        mse, ssim, psnr = utils.check_metrics_full(dataloader, model,
                                                   image_resolution)
        mses[image_name] = mse
        psnrs[image_name] = psnr
        ssims[image_name] = ssim

    metrics = {'mse': mses, 'psnr': psnrs, 'ssim': ssims}
    with open(os.path.join(experiment_folder, 'result.json'), 'w') as fp:
        json.dump(metrics, fp)
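
# Hypothetical sketch: the 'log_mse' branch above uses an `image_log_mse` loss
# that is not defined in this excerpt. Assuming the SIREN-style interface where
# the prediction lives in model_output['model_out'] and the target in gt['img']
# (matching loss_functions.image_mse), and with torch already imported at the
# top of the file, it could look roughly like this:
def image_log_mse(model_output, gt):
    mse = ((model_output['model_out'] - gt['img']) ** 2).mean()
    # Epsilon guards against log(0) on a perfectly reconstructed batch.
    return {'img_loss': torch.log(mse + 1e-12)}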
# Backfill defaults for flags that may be missing from older FLAGS.yml files.
if 'bn' not in TRAINING_FLAGS:
    TRAINING_FLAGS['bn'] = False
if 'intermediate_losses' not in TRAINING_FLAGS:
    TRAINING_FLAGS['intermediate_losses'] = False
if 'phased' not in TRAINING_FLAGS:
    TRAINING_FLAGS['phased'] = False
if 'ff_dims' not in TRAINING_FLAGS:
    TRAINING_FLAGS['ff_dims'] = None
if 'num_components' not in TRAINING_FLAGS:
    TRAINING_FLAGS['num_components'] = 1

if TRAINING_FLAGS['model_type'] == 'mlp':
    model = modules.SingleBVPNet_INR(
        type=TRAINING_FLAGS['activation'],
        mode=TRAINING_FLAGS['encoding'],
        sidelength=image_resolution,
        out_features=img_dataset.img_channels,
        hidden_features=TRAINING_FLAGS['hidden_dims'],
        num_hidden_layers=TRAINING_FLAGS['hidden_layers'],
        encoding_scale=s,
        batch_norm=TRAINING_FLAGS['bn'],
        ff_dims=TRAINING_FLAGS['ff_dims'])
elif TRAINING_FLAGS['model_type'] == 'multi_tapered':
    model = modules.MultiScale_INR(
        type=TRAINING_FLAGS['activation'],
        mode=TRAINING_FLAGS['encoding'],
        sidelength=image_resolution,
        out_features=img_dataset.img_channels,
        hidden_features=TRAINING_FLAGS['hidden_dims'],
        num_hidden_layers=TRAINING_FLAGS['hidden_layers'],
        encoding_scale=s,
        tapered=True,
        ff_dims=TRAINING_FLAGS['ff_dims'])
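
# Hypothetical sketch of where TRAINING_FLAGS and `s` above could come from:
# presumably the FLAGS.yml dumped by the training script is reloaded here
# (`exp_folder` is a placeholder path, and `s` the encoding scale to evaluate).
#
#     with open(os.path.join(exp_folder, 'FLAGS.yml'), 'r') as f:
#         TRAINING_FLAGS = yaml.safe_load(f)
#     s = TRAINING_FLAGS.get('encoding_scale', 1.0)  # or an externally chosen scale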
def main(_):
    imglob_maml = glob.glob(
        os.path.join(FLAGS.data_root, FLAGS.maml_dataset, '*'))
    imglob = glob.glob(os.path.join(FLAGS.data_root, FLAGS.dataset, '*'))
    mses = {}
    psnrs = {}
    ssims = {}
    experiment_folder = get_experiment_folder()
    maml_folder = get_maml_folder()

    # Save FLAGS to yml.
    yaml.dump(FLAGS.flag_values_dict(),
              open(os.path.join(experiment_folder, 'FLAGS.yml'), 'w'))

    img_dataset = []
    for i, im in enumerate(imglob_maml):
        image_name = im.split('/')[-1].split('.')[0]
        img_dataset.append(dataio.ImageFile(im))
        img = PIL.Image.open(im)
        image_resolution = (img.size[1] // FLAGS.downscaling_factor,
                            img.size[0] // FLAGS.downscaling_factor)
        # run_name = image_name + '_layers' + str(layers) + '_units' + str(hidden_units) + '_model' + FLAGS.model_type

    coord_dataset = dataio.Implicit2DListWrapper(img_dataset,
                                                 sidelength=image_resolution)
    dataloader = DataLoader(coord_dataset,
                            shuffle=True,
                            batch_size=FLAGS.batch_size,
                            pin_memory=True,
                            num_workers=0)
    # linear_decay = {'img_loss': trainingAC.LinearDecaySchedule(1, 1, FLAGS.epochs)}

    # Define the model.
    if FLAGS.model_type == 'mlp':
        model = modules.SingleBVPNet_INR(
            type=FLAGS.activation,
            mode=FLAGS.encoding,
            sidelength=image_resolution,
            out_features=img_dataset[0].img_channels,
            hidden_features=FLAGS.hidden_dims,
            num_hidden_layers=FLAGS.hidden_layers,
            encoding_scale=FLAGS.encoding_scale,
            batch_norm=FLAGS.bn,
            ff_dims=FLAGS.ff_dims)
    elif FLAGS.model_type == 'multi_tapered':
        model = modules.MultiScale_INR(
            type=FLAGS.activation,
            mode=FLAGS.encoding,
            sidelength=image_resolution,
            out_features=img_dataset[0].img_channels,
            hidden_features=FLAGS.hidden_dims,
            num_hidden_layers=FLAGS.hidden_layers,
            encoding_scale=FLAGS.encoding_scale,
            tapered=True,
            downsample=False,
            ff_dims=FLAGS.ff_dims)
    elif FLAGS.model_type == 'multi':
        model = modules.MultiScale_INR(
            type=FLAGS.activation,
            mode=FLAGS.encoding,
            sidelength=image_resolution,
            out_features=img_dataset[0].img_channels,
            hidden_features=FLAGS.hidden_dims,
            num_hidden_layers=FLAGS.hidden_layers,
            encoding_scale=FLAGS.encoding_scale,
            tapered=False,
            downsample=False,
            ff_dims=FLAGS.ff_dims)
    elif FLAGS.model_type == 'parallel':
        model = modules.Parallel_INR(
            type=FLAGS.activation,
            mode=FLAGS.encoding,
            sidelength=image_resolution,
            out_features=img_dataset[0].img_channels,
            hidden_features=[
                FLAGS.hidden_dims // 4, FLAGS.hidden_dims // 2,
                FLAGS.hidden_dims
            ],
            num_hidden_layers=FLAGS.hidden_layers,
            encoding_scale=FLAGS.encoding_scale)
    elif FLAGS.model_type == 'mixture':
        model = modules.INR_Mixture(
            type=FLAGS.activation,
            mode=FLAGS.encoding,
            sidelength=image_resolution,
            out_features=img_dataset[0].img_channels,
            hidden_features=FLAGS.hidden_dims,
            num_hidden_layers=FLAGS.hidden_layers,
            encoding_scale=FLAGS.encoding_scale,
            batch_norm=FLAGS.bn,
            ff_dims=FLAGS.ff_dims,
            num_components=FLAGS.num_components)

    # exp_root = 'exp/maml'
    # experiment_names = [i.split('/')[-4] for i in
    #                     glob.glob(exp_root + '/KODAK21_epochs10000_lr0.0001_mlp_hdims100_hlayer2_mlp_sine_l1_reg0.001/maml/checkpoints/')]
    # state_dict = torch.load(os.path.join('KODAK21_epochs10000_lr0.0001_mlp_hdims100_hlayer2_mlp_sine_l1_reg0.001',
    #                                      'maml' + '/checkpoints/model_best_.pth'),
    #                         map_location='cpu').load_state_dict(state_dict, strict=True)

    model.cuda()
    root_path = maml_folder

    # Define the loss.
    if FLAGS.loss == 'mse':
        loss_fn = partial(loss_functions.image_mse, None)
    elif FLAGS.loss == 'log_mse':
        loss_fn = image_log_mse
    summary_fn = partial(siren_utils.write_image_summary, image_resolution)

    # Reuse an existing MAML initialization if one is found; otherwise run
    # MAML meta-training from scratch.
    try:
        state_dict = torch.load(os.path.join(maml_folder,
                                             'checkpoints/model_maml.pth'),
                                map_location='cpu')
        model.load_state_dict(state_dict, strict=True)
    except:
        print("No matching model found, training from scratch.")
        yaml.dump(FLAGS.flag_values_dict(),
                  open(os.path.join(maml_folder, 'FLAGS.yml'), 'w'))
        trainingMAML.train(model=model,
                           train_dataloader=dataloader,
                           maml_iterations=FLAGS.maml_iterations,
                           inner_lr=FLAGS.inner_lr,
                           outer_lr=FLAGS.outer_lr,
                           steps_til_summary=FLAGS.steps_til_summary,
                           epochs_til_checkpoint=FLAGS.epochs_til_ckpt,
                           model_dir=root_path,
                           loss_fn=loss_fn,
                           summary_fn=summary_fn,
                           maml_batch_size=FLAGS.maml_batch_size,
                           maml_adaptation_steps=FLAGS.maml_adaptation_steps)

    ref_model = copy.deepcopy(model)
    l1_loss_fn = partial(losses.model_l1_diff, ref_model)
    torch.save(model.state_dict(),
               os.path.join(experiment_folder, 'model_maml.pth'))

    # Fine-tune the MAML initialization on each test image.
    for i, im in enumerate(imglob):
        print('Image: ' + str(i))
        image_name = im.split('/')[-1].split('.')[0]
        img_dataset = dataio.ImageFile(im)
        img = PIL.Image.open(im)
        image_resolution = (img.size[1] // FLAGS.downscaling_factor,
                            img.size[0] // FLAGS.downscaling_factor)
        # run_name = image_name + '_layers' + str(layers) + '_units' + str(hidden_units) + '_model' + FLAGS.model_type

        coord_dataset = dataio.Implicit2DWrapper(img_dataset,
                                                 sidelength=image_resolution)
        root_path = os.path.join(experiment_folder, image_name)
        dataloader = DataLoader(coord_dataset,
                                shuffle=True,
                                batch_size=FLAGS.batch_size,
                                pin_memory=True,
                                num_workers=0)
        trainingAC.train(model=model,
                         train_dataloader=dataloader,
                         epochs=FLAGS.epochs,
                         lr=FLAGS.lr,
                         steps_til_summary=FLAGS.steps_til_summary,
                         epochs_til_checkpoint=FLAGS.epochs_til_ckpt,
                         model_dir=root_path,
                         loss_fn=loss_fn,
                         l1_loss_fn=l1_loss_fn,
                         summary_fn=summary_fn,
                         l1_reg=FLAGS.l1_reg,
                         spec_reg=FLAGS.spec_reg)

        mse, ssim, psnr = utils.check_metrics_full(dataloader, model,
                                                   image_resolution)
        # mses[image_name] = mse
        # psnrs[image_name] = psnr
        # ssims[image_name] = ssim
        metrics = {'mse': mse, 'psnr': psnr, 'ssim': ssim}
        with open(os.path.join(root_path, 'result.json'), 'w') as fp:
            json.dump(metrics, fp)
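
# Hypothetical sketch of `losses.model_l1_diff` (an assumption, not necessarily
# the repository's actual implementation): given how it is bound with
# `partial(losses.model_l1_diff, ref_model)` and passed as `l1_loss_fn` above,
# a natural reading is an L1 penalty between the fine-tuned weights and the
# frozen MAML initialization, so per-image updates stay close to the shared
# reference model and the weight residuals remain compressible.
def model_l1_diff(ref_model, model):
    diff = 0.
    for p, p_ref in zip(model.parameters(), ref_model.parameters()):
        diff = diff + (p - p_ref.detach()).abs().sum()
    return diff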