def load_model(config, checkpoint_file):
    '''
    Loads a pretrained HRNet model from a checkpoint on disk.

    The model is placed on CUDA when available, otherwise on CPU. The checkpoint
    tensors are mapped onto that same device. This means a checkpoint written
    from a GPU run can still be loaded on a CPU-only machine.

    Args:
        config: dict, configuration; config["network"] is passed to HRNet.
        checkpoint_file: str, path to a file holding a saved state_dict.
    Returns:
        model: HRNet, a pytorch model with the checkpoint weights loaded.
            No .eval() call is made here; callers doing inference should
            switch the mode themselves.
    '''
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = HRNet(config["network"]).to(device)
    # Without map_location, torch.load restores tensors onto the device they
    # were saved from, which fails for CUDA tensors when CUDA is unavailable.
    state_dict = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(state_dict)
    return model
def _build_dataloaders(training_config, train_list, val_list):
    """
    Builds the train and validation DataLoaders for the given imageset lists.

    Args:
        training_config: dict, the config["training"] section. Reads
            batch_size, n_workers, n_views, min_L and beta. It is not modified.
        train_list: list of imageset directories used for training.
        val_list: list of imageset directories used for validation.
    Returns:
        dict with keys 'train' and 'val' mapping to DataLoader instances.
    """
    batch_size = training_config["batch_size"]
    n_workers = training_config["n_workers"]
    n_views = training_config["n_views"]
    min_L = training_config["min_L"]  # minimum number of views
    beta = training_config["beta"]

    train_dataset = ImagesetDataset(imset_dir=train_list, config=training_config,
                                    top_k=n_views, beta=beta)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                                  shuffle=True, num_workers=n_workers,
                                  collate_fn=collateFunction(min_L=min_L),
                                  pin_memory=True)

    # Validation runs without patch creation. Give it a shallow copy instead of
    # writing into training_config: train_dataset holds that same dict object,
    # so an in-place write would also affect training if ImagesetDataset reads
    # "create_patches" lazily (not visible here -- TODO confirm).
    val_config = dict(training_config, create_patches=False)
    val_dataset = ImagesetDataset(imset_dir=val_list, config=val_config,
                                  top_k=n_views, beta=beta)
    val_dataloader = DataLoader(val_dataset, batch_size=1,
                                shuffle=False, num_workers=n_workers,
                                collate_fn=collateFunction(min_L=min_L),
                                pin_memory=True)

    return {'train': train_dataloader, 'val': val_dataloader}


def main(config):
    """
    Given a configuration, trains HRNet (fusion) and ShiftNet (registration)
    for Multi-Frame Super Resolution (MFSR) by delegating to
    trainAndGetBestModel.

    Args:
        config: dict, configuration. Reads config["network"] and
            config["paths"]["prefix"]. From config["training"] it reads lr,
            val_proportion and the DataLoader settings. The whole dict is
            forwarded to trainAndGetBestModel. It is not modified here.
    """
    # Reproducibility options
    np.random.seed(0)  # RNG seeds
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Initialize the network based on the network configuration
    fusion_model = HRNet(config["network"])
    regis_model = ShiftNet()
    # A single optimizer updates the parameters of both networks jointly.
    optimizer = optim.Adam(list(fusion_model.parameters()) + list(regis_model.parameters()),
                           lr=config["training"]["lr"])

    # ESA dataset
    data_directory = config["paths"]["prefix"]
    # Baseline cPSNR scores are optional; None is passed on when norm.csv is absent.
    norm_csv = os.path.join(data_directory, "norm.csv")
    baseline_cpsnrs = readBaselineCPSNR(norm_csv) if os.path.exists(norm_csv) else None

    train_set_directories = getImageSetDirectories(os.path.join(data_directory, "train"))
    val_proportion = config['training']['val_proportion']
    train_list, val_list = train_test_split(train_set_directories,
                                            test_size=val_proportion,
                                            random_state=1, shuffle=True)

    # Dataloaders
    dataloaders = _build_dataloaders(config["training"], train_list, val_list)

    # Train model
    torch.cuda.empty_cache()
    trainAndGetBestModel(fusion_model, regis_model, optimizer, dataloaders,
                         baseline_cpsnrs, config)