Beispiel #1
0
def load_model(config, checkpoint_file):
    '''
    Loads a pretrained HRNet model from disk.
    Args:
        config: dict, configuration file; only config["network"] is read here
            and passed to the HRNet constructor
        checkpoint_file: str, path of a saved HRNet state_dict
    Returns:
        model: HRNet, a pytorch model placed on CUDA when available,
            otherwise on the CPU
    '''
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = HRNet(config["network"]).to(device)
    # map_location remaps stored tensors onto the chosen device. Without it,
    # a checkpoint saved on a GPU cannot be loaded on a CPU-only machine,
    # because torch.load tries to restore tensors onto their original CUDA device.
    state_dict = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(state_dict)
    return model
Beispiel #2
0
def _read_baseline_cpsnrs(data_directory):
    """
    Reads baseline cPSNR scores from <data_directory>/norm.csv, if that file exists.
    Args:
        data_directory: str, root directory of the dataset
    Returns:
        the result of readBaselineCPSNR on norm.csv, or None when norm.csv is absent
    """
    norm_file = os.path.join(data_directory, "norm.csv")
    if not os.path.exists(norm_file):
        return None
    return readBaselineCPSNR(norm_file)


def _build_dataloaders(config, train_list, val_list):
    """
    Builds the training and validation dataloaders.
    Args:
        config: dict, configuration file; config["training"] supplies batch_size,
            n_workers, n_views, min_L and beta, and is passed to ImagesetDataset
        train_list: imageset directories used for training
        val_list: imageset directories used for validation
    Returns:
        dict with keys 'train' and 'val' mapping to DataLoader objects
    Side effects:
        sets config["training"]["create_patches"] = False
    """
    training_config = config["training"]
    batch_size = training_config["batch_size"]
    n_workers = training_config["n_workers"]
    n_views = training_config["n_views"]
    min_L = training_config["min_L"]  # minimum number of views
    beta = training_config["beta"]
    # Pinned host memory only speeds up host->CUDA copies; skip it without a GPU.
    pin_memory = torch.cuda.is_available()

    # Give the training set its own shallow copy of the settings. ImagesetDataset
    # may keep a reference to the dict it receives, and the validation-only
    # override below must not leak into the training data.
    train_dataset = ImagesetDataset(imset_dir=train_list,
                                    config=dict(training_config),
                                    top_k=n_views,
                                    beta=beta)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=n_workers,
                                  collate_fn=collateFunction(min_L=min_L),
                                  pin_memory=pin_memory)

    # Validation data is loaded without patch creation. The flag is still written
    # into the caller's config (as before), so later readers of config see False.
    training_config["create_patches"] = False
    val_dataset = ImagesetDataset(imset_dir=val_list,
                                  config=training_config,
                                  top_k=n_views,
                                  beta=beta)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=1,
                                shuffle=False,
                                num_workers=n_workers,
                                collate_fn=collateFunction(min_L=min_L),
                                pin_memory=pin_memory)

    return {'train': train_dataloader, 'val': val_dataloader}


def main(config):
    """
    Given a configuration, trains HRNet and ShiftNet for Multi-Frame Super Resolution (MFSR), and saves best model.
    Args:
        config: dict, configuration file; reads config["network"],
            config["paths"]["prefix"] and config["training"]. As a side effect,
            config["training"]["create_patches"] is set to False.
    """

    # Reproducibility options
    np.random.seed(0)  # RNG seeds
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Initialize the networks after seeding, so any random weight init is reproducible.
    fusion_model = HRNet(config["network"])
    regis_model = ShiftNet()

    # A single optimizer updates the parameters of both networks jointly.
    optimizer = optim.Adam(list(fusion_model.parameters()) +
                           list(regis_model.parameters()),
                           lr=config["training"]["lr"])

    # ESA dataset
    data_directory = config["paths"]["prefix"]
    baseline_cpsnrs = _read_baseline_cpsnrs(data_directory)

    train_set_directories = getImageSetDirectories(
        os.path.join(data_directory, "train"))
    train_list, val_list = train_test_split(train_set_directories,
                                            test_size=config['training']['val_proportion'],
                                            random_state=1,
                                            shuffle=True)

    dataloaders = _build_dataloaders(config, train_list, val_list)

    # Train model
    torch.cuda.empty_cache()
    trainAndGetBestModel(fusion_model, regis_model, optimizer, dataloaders,
                         baseline_cpsnrs, config)