def set_hyperparameters(params,
                        num_classes,
                        model,
                        checkpoint,
                        dontcare_val,
                        loss_fn,
                        optimizer,
                        class_weights=None,
                        inference: str = ''):
    """
    Function to set hyperparameters based on values provided in yaml config file.
    If none provided, default functions values may be used.
    :param params: (dict) Parameters found in the yaml config file
    :param num_classes: (int) number of classes for current task (not used by this function)
    :param model: initialized model
    :param checkpoint: (dict) state dict as loaded by model_choice.py
    :param dontcare_val: value in label to ignore during loss calculation
    :param loss_fn: loss type, forwarded to MultiClassCriterion as ``loss_type``
    :param optimizer: optimizer type, forwarded to create_optimizer() as ``mode``
    :param class_weights: sequence of class weights for loss function (None or empty means unweighted)
    :param inference: (str) path to inference checkpoint (used in load_from_checkpoint())
    :return: model, criterion, optimizer, lr_scheduler
    """
    # set mandatory hyperparameters values with those in config file if they exist
    training_params = params['training']
    lr = get_key_def('learning_rate', training_params, None)
    weight_decay = get_key_def('weight_decay', training_params, None)
    step_size = get_key_def('step_size', training_params, None)
    gamma = get_key_def('gamma', training_params, None)

    # Explicit None/length test instead of truthiness: the truth value of a
    # multi-element numpy array or tensor is ambiguous and raises ValueError.
    if class_weights is not None and len(class_weights) > 0:
        class_weights = torch.tensor(class_weights)
    else:
        class_weights = None

    # Loss function
    criterion = MultiClassCriterion(loss_type=loss_fn,
                                    ignore_index=dontcare_val,
                                    weight=class_weights)

    # Optimizer
    # The ``optimizer`` argument holds the optimizer type; it is rebound below
    # to the optimizer instance that is eventually returned.
    opt_fn = optimizer
    optimizer = create_optimizer(params=model.parameters(),
                                 mode=opt_fn,
                                 base_lr=lr,
                                 weight_decay=weight_decay)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer,
                                             step_size=step_size,
                                             gamma=gamma)

    if checkpoint:
        tqdm.write('Loading checkpoint...')
        model, optimizer = load_from_checkpoint(checkpoint,
                                                model,
                                                optimizer=optimizer,
                                                inference=inference)

    return model, criterion, optimizer, lr_scheduler
Exemple #2
0
def set_hyperparameters(params, num_classes, model, checkpoint):
    """
    Function to set hyperparameters based on values provided in yaml config file.
    If none provided, default functions values may be used.
    :param params: (dict) Parameters found in the yaml config file
    :param num_classes: (int) number of classes for current task
    :param model: Model loaded from model_choice.py
    :param checkpoint: (dict) state dict as loaded by model_choice.py
    :return: model, criterion, optimizer, lr_scheduler
    """
    training_params = params['training']

    # set mandatory hyperparameters values with those in config file if they exist
    lr = get_key_def('learning_rate', training_params, None,
                     "missing mandatory learning rate parameter")
    weight_decay = get_key_def('weight_decay', training_params, None,
                               "missing mandatory weight decay parameter")
    step_size = get_key_def('step_size', training_params, None,
                            "missing mandatory step size parameter")
    gamma = get_key_def('gamma', training_params, None,
                        "missing mandatory gamma parameter")

    # optional hyperparameters. Set to None if not in config file.
    # Looked up through get_key_def so that a config without a 'class_weights'
    # entry yields None instead of raising KeyError.
    raw_class_weights = get_key_def('class_weights', training_params, None)
    if raw_class_weights:
        class_weights = torch.tensor(raw_class_weights)
        verify_weights(num_classes, class_weights)
    else:
        class_weights = None
    ignore_index = get_key_def('ignore_index', training_params, -1)

    # Loss function ('loss_fn' is mandatory: a missing key raises KeyError)
    criterion = MultiClassCriterion(loss_type=training_params['loss_fn'],
                                    ignore_index=ignore_index,
                                    weight=class_weights)

    # Optimizer ('optimizer' is mandatory: a missing key raises KeyError)
    opt_fn = training_params['optimizer']
    optimizer = create_optimizer(params=model.parameters(),
                                 mode=opt_fn,
                                 base_lr=lr,
                                 weight_decay=weight_decay)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer,
                                             step_size=step_size,
                                             gamma=gamma)

    if checkpoint:
        tqdm.write('Loading checkpoint...')
        model, optimizer = load_from_checkpoint(checkpoint,
                                                model,
                                                optimizer=optimizer)

    return model, criterion, optimizer, lr_scheduler
Exemple #3
0
def create_dataloader(samples_folder, batch_size, num_devices, params):
    """
    Build the DataLoader objects used for the training, validation and test subsets.
    :param samples_folder: path to the folder containing the .hdf5 sample files (searched recursively)
    :param batch_size: (int) batch size
    :param num_devices: (int) number of GPUs used
    :param params: (dict) Parameters found in the yaml config file.
    :return: trn_dataloader, val_dataloader, tst_dataloader (the latter is None when there are no test samples)
    """
    debug = get_key_def('debug_mode', params['global'], False)
    dontcare_val = get_key_def("ignore_index", params["training"], -1)

    # Sanity checks on the samples folder and on the amount of available samples
    assert samples_folder.is_dir(), f'Could not locate: {samples_folder}'
    hdf5_paths = list(samples_folder.glob('**/*.hdf5'))
    assert len(hdf5_paths) >= 1, f"Couldn't locate .hdf5 files in {samples_folder}"
    num_samples = get_num_samples(samples_path=samples_folder, params=params)
    assert not (num_samples['trn'] < batch_size or num_samples['val'] < batch_size), \
        "Number of samples in .hdf5 files is less than batch size"
    print(f"Number of samples : {num_samples}\n")

    meta_map = get_key_def("meta_map", params["global"], {})
    num_bands = get_key_def("number_of_bands", params["global"], {})
    # A non-empty meta_map selects the metadata-aware dataset class
    if meta_map:
        dataset_constr = functools.partial(CreateDataset.MetaSegmentationDataset, meta_map=meta_map)
    else:
        dataset_constr = CreateDataset.SegmentationDataset

    def _make_dataset(subset):
        # TODO: transform the aug.compose_transforms in a class with radiom, geom, totensor in def
        radiom = aug.compose_transforms(params, subset, type='radiometric')
        geom = aug.compose_transforms(params, subset, type='geometric', ignore_index=dontcare_val)
        to_tensor = aug.compose_transforms(params, subset, type='totensor')
        return dataset_constr(samples_folder, subset, num_bands,
                              max_sample_count=num_samples[subset],
                              dontcare=dontcare_val,
                              radiom_transform=radiom,
                              geom_transform=geom,
                              totensor_transform=to_tensor,
                              debug=debug)

    trn_dataset, val_dataset, tst_dataset = [_make_dataset(name) for name in ("trn", "val", "tst")]

    # https://discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader/813/5
    if num_devices > 1:
        num_workers = num_devices * 4
    else:
        num_workers = 4

    common_kwargs = dict(batch_size=batch_size, num_workers=num_workers, drop_last=True)

    # Shuffle must be set to True.
    trn_dataloader = DataLoader(trn_dataset, shuffle=True, **common_kwargs)
    # Using batch_metrics with shuffle=False on val dataset will always measure metrics on the same portion of the
    # val samples. Shuffle should be set to True.
    val_dataloader = DataLoader(val_dataset, shuffle=True, **common_kwargs)
    # The test loader is only built when test samples exist
    if num_samples['tst'] > 0:
        tst_dataloader = DataLoader(tst_dataset, shuffle=False, **common_kwargs)
    else:
        tst_dataloader = None

    return trn_dataloader, val_dataloader, tst_dataloader
Exemple #4
0
def get_num_samples(samples_path, params, dontcare):
    """
    Function to retrieve number of samples, either from config file or directly from hdf5 file,
    and to compute balanced sample weights for the training subset.
    :param samples_path: Path object pointing to the samples folder (``joinpath`` is called on it)
    :param params: (dict) Parameters found in the yaml config file.
    :param dontcare: not used by this function
    :return: (tuple) ``(num_samples, samples_weight)`` where ``num_samples`` is a dict with the number of
             samples for trn, val and tst, and ``samples_weight`` is the output of
             ``compute_sample_weight('balanced', ...)`` over the training samples, or None if the training
             subset holds no sample.
    :raises IndexError: if the configuration requests more samples than the hdf5 file contains.
    """
    num_samples = {'trn': 0, 'val': 0, 'tst': 0}
    # One string key per training sample, made of the unique label values found in that sample
    weights = []
    for subset in ['trn', 'val', 'tst']:
        config_key = f"num_{subset}_samples"
        from_config = get_key_def(config_key, params['training'], None) is not None

        # Open each file once; the handle serves both the sample count and the label scan.
        with h5py.File(samples_path.joinpath(f"{subset}_samples.hdf5"), 'r') as hdf5_file:
            labels = hdf5_file['map_img']
            file_num_samples = len(labels)

            if from_config:
                num_samples[subset] = params['training'][config_key]
                if num_samples[subset] > file_num_samples:
                    raise IndexError(f"The number of training samples in the configuration file ({num_samples[subset]}) "
                                     f"exceeds the number of samples in the hdf5 training dataset ({file_num_samples}).")
            else:
                num_samples[subset] = file_num_samples

            if subset == 'trn':
                for x in range(num_samples[subset]):
                    unique_labels = np.unique(labels[x])
                    weights.append(''.join(str(int(value)) for value in unique_labels))

    # Computed once, after every training sample has been collected: only the result
    # over the full list is returned, so recomputing it per sample was wasted work.
    samples_weight = compute_sample_weight('balanced', weights) if weights else None

    return num_samples, samples_weight
def create_files_and_datasets(params, samples_folder):
    """
    Function to create the hdf5 files (trn, val and tst).
    Each file is opened in write mode and receives four empty datasets that can grow along
    their first axis: "sat_img", "map_img", "meta_idx" and "metadata".
    :param params: (dict) Parameters found in the yaml config file.
    :param samples_folder: (str) Path to the output folder.
    :return: (list) open h5py.File objects, in order: trn, val and tst. They are returned open;
             closing them is left to the caller.
    :raises AssertionError: if no band is left once meta layers are subtracted from number_of_bands.
    """
    samples_size = params['global']['samples_size']
    number_of_bands = params['global']['number_of_bands']
    meta_map = get_key_def('meta_map', params['global'], {})
    # Meta layers are counted in number_of_bands but are not stored in "sat_img"
    real_num_bands = number_of_bands - MetaSegmentationDataset.get_meta_layer_count(
        meta_map)
    assert real_num_bands > 0, "invalid number of bands when accounting for meta layers"
    hdf5_files = []
    try:
        for subset in ["trn", "val", "tst"]:
            hdf5_file = h5py.File(
                os.path.join(samples_folder, f"{subset}_samples.hdf5"), "w")
            # Register the handle right away so it is closed if a dataset creation below fails
            hdf5_files.append(hdf5_file)
            hdf5_file.create_dataset(
                "sat_img", (0, samples_size, samples_size, real_num_bands),
                np.float32,
                maxshape=(None, samples_size, samples_size, real_num_bands))
            hdf5_file.create_dataset("map_img", (0, samples_size, samples_size),
                                     np.int16,
                                     maxshape=(None, samples_size, samples_size))
            hdf5_file.create_dataset("meta_idx", (0, 1),
                                     dtype=np.int16,
                                     maxshape=(None, 1))
            hdf5_file.create_dataset("metadata", (0, 1),
                                     dtype=h5py.string_dtype(),
                                     maxshape=(None, 1))
    except Exception:
        # Do not leak the handles opened so far; the original error is re-raised unchanged
        for opened_file in hdf5_files:
            opened_file.close()
        raise
    return hdf5_files
Exemple #6
0
def get_num_samples(samples_path, params):
    """
    Determine how many samples to use for each of the trn, val and tst subsets.
    The count comes from the "num_<subset>_samples" entry of the training section when it is set,
    otherwise from the length of the "map_img" dataset of the matching hdf5 file.
    :param samples_path: Path object pointing to the samples folder (``joinpath`` is called on it)
    :param params: (dict) Parameters found in the yaml config file.
    :return: (dict) number of samples for trn, val and tst.
    :raises IndexError: if the configuration requests more samples than the hdf5 file contains.
    """
    counts = {'trn': 0, 'val': 0, 'tst': 0}

    for subset in ('trn', 'val', 'tst'):
        config_key = f"num_{subset}_samples"
        set_in_config = get_key_def(config_key, params['training'], None) is not None
        if set_in_config:
            counts[subset] = params['training'][config_key]

        # The file is read in both cases: as the source of the count, or as its upper bound
        with h5py.File(samples_path.joinpath(f"{subset}_samples.hdf5"), 'r') as hdf5_file:
            available = len(hdf5_file['map_img'])

        if not set_in_config:
            counts[subset] = available
            continue

        if counts[subset] > available:
            raise IndexError(f"The number of training samples in the configuration file ({counts[subset]}) "
                             f"exceeds the number of samples in the hdf5 training dataset ({available}).")

    return counts
Exemple #7
0
def main(params: dict) -> None:
    """
    Identify the class to which each image belongs.
    Runs inference on a list of input images, either through classifier() when the task is
    'classification', or chunk by chunk through sem_seg_inference() followed by vis() when the
    task is 'segmentation'. The elapsed time is printed at the end.
    :param params: (dict) Parameters found in the yaml config file. The 'global', 'inference' and
                   'training' sections are read.
    :raises ValueError: if params['global']['task'] is neither 'classification' nor 'segmentation'.

    """
    # SET BASIC VARIABLES AND PATHS
    since = time.time()

    debug = get_key_def('debug_mode', params['global'], False)
    if debug:
        warnings.warn(f'Debug mode activated. Some debug features may mobilize extra disk space and cause delays in execution.')

    num_classes = params['global']['num_classes']
    task = params['global']['task']
    num_classes_corrected = add_background_to_num_class(task, num_classes)

    chunk_size = get_key_def('chunk_size', params['inference'], 512)
    overlap = get_key_def('overlap', params['inference'], 10)
    # 'overlap' is a percentage of chunk_size; convert it to a whole number of pixels
    nbr_pix_overlap = int(math.floor(overlap / 100 * chunk_size))
    num_bands = params['global']['number_of_bands']

    img_dir_or_csv = params['inference']['img_dir_or_csv_file']

    # Default output folder: 'inference_<num_bands>bands', next to the state dict file
    default_working_folder = Path(params['inference']['state_dict_path']).parent.joinpath(f'inference_{num_bands}bands')
    working_folder = get_key_def('working_folder', params['inference'], None)
    if working_folder:  # TODO: July 2020: deprecation started. Remove custom working_folder parameter as of Sept 2020?
        working_folder = Path(working_folder)
        warnings.warn(f"Deprecated parameter. Remove it in your future yamls as this folder is now created "
                      f"automatically in a logical path, "
                      f"i.e. [state_dict_path from inference section in yaml]/inference_[num_bands]bands")
    else:
        working_folder = default_working_folder
    Path.mkdir(working_folder, exist_ok=True)
    print(f'Inferences will be saved to: {working_folder}\n\n')

    # NOTE(review): 'bucket' is set to None here and never reassigned in this function, so every
    # 'if bucket:' branch below is unreachable as written. bucket_name is only forwarded to
    # list_input_images().
    bucket = None
    bucket_file_cache = []
    bucket_name = get_key_def('bucket_name', params['global'])

    # CONFIGURE MODEL
    model, state_dict_path, model_name = net(params, num_channels=num_classes_corrected, inference=True)

    # A falsy 'num_gpus' value (e.g. None or 0) is treated as zero requested devices
    num_devices = params['global']['num_gpus'] if params['global']['num_gpus'] else 0
    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    lst_device_ids = get_device_ids(num_devices) if torch.cuda.is_available() else []
    # Only the first returned device id is used; fall back to CPU when there is none
    device = torch.device(f'cuda:{lst_device_ids[0]}' if torch.cuda.is_available() and lst_device_ids else 'cpu')

    if lst_device_ids:
        print(f"Number of cuda devices requested: {num_devices}. Cuda devices available: {lst_device_ids}. Using {lst_device_ids[0]}\n\n")
    else:
        warnings.warn(f"No Cuda device available. This process will only run on CPU")

    # If moving the model to the selected device fails, retry once on cuda:0 (or CPU)
    try:
        model.to(device)
    except RuntimeError:
        print(f"Unable to use device. Trying device 0")
        device = torch.device(f'cuda:0' if torch.cuda.is_available() and lst_device_ids else 'cpu')
        model.to(device)

    # CREATE LIST OF INPUT IMAGES FOR INFERENCE
    list_img = list_input_images(img_dir_or_csv, bucket_name, glob_patterns=["*.tif", "*.TIF"])

    if task == 'classification':
        classifier(params, list_img, model, device, working_folder)  # FIXME: why don't we load from checkpoint in classification?

    elif task == 'segmentation':
        if bucket:
            bucket.download_file(state_dict_path, "saved_model.pth.tar")  # TODO: is this still valid?
            model, _ = load_from_checkpoint("saved_model.pth.tar", model)
        else:
            model, _ = load_from_checkpoint(state_dict_path, model)

        # NOTE(review): ignore_index and yaml_metadata are assigned here but not used afterwards
        # in this function.
        ignore_index = get_key_def('ignore_index', params['training'], -1)
        meta_map, yaml_metadata = get_key_def("meta_map", params["global"], {}), None

        # LOOP THROUGH LIST OF INPUT IMAGES
        with tqdm(list_img, desc='image list', position=0) as _tqdm:
            for info in _tqdm:
                img_name = Path(info['tif']).name
                if bucket:
                    # NOTE(review): local_img is a plain str in this branch, while the assert
                    # below calls local_img.is_file() — would need a Path if this branch is revived.
                    local_img = f"Images/{img_name}"
                    bucket.download_file(info['tif'], local_img)
                    inference_image = f"Classified_Images/{img_name.split('.')[0]}_inference.tif"
                    if info['meta']:
                        # Download each metadata file only once across images
                        if info['meta'] not in bucket_file_cache:
                            bucket_file_cache.append(info['meta'])
                            bucket.download_file(info['meta'], info['meta'].split('/')[-1])
                        info['meta'] = info['meta'].split('/')[-1]
                else:  # FIXME: else statement should support img['meta'] integration as well.
                    local_img = Path(info['tif'])
                    # Output name: image name up to its first dot, suffixed with '_inference.tif'.
                    # In this function inference_image is only read by the bucket upload below.
                    inference_image = working_folder.joinpath(f"{img_name.split('.')[0]}_inference.tif")

                assert local_img.is_file(), f"Could not open raster file at {local_img}"

                # Empty sample as dictionary
                inf_sample = {'sat_img': None, 'metadata': None}

                with rasterio.open(local_img, 'r') as raster_handle:
                    inf_sample['sat_img'], raster_handle_updated, dataset_nodata = image_reader_as_array(
                                    input_image=raster_handle,
                                    aux_vector_file=get_key_def('aux_vector_file', params['global'], None),
                                    aux_vector_attrib=get_key_def('aux_vector_attrib', params['global'], None),
                                    aux_vector_ids=get_key_def('aux_vector_ids', params['global'], None),
                                    aux_vector_dist_maps=get_key_def('aux_vector_dist_maps', params['global'], True),
                                    aux_vector_scale=get_key_def('aux_vector_scale', params['global'], None))

                # NOTE(review): raster_handle_updated is used after the 'with' block has closed
                # raster_handle — confirm image_reader_as_array() returns an object that stays usable.
                inf_sample['metadata'] = add_metadata_from_raster_to_sample(sat_img_arr=inf_sample['sat_img'],
                                                                            raster_handle=raster_handle_updated,
                                                                            meta_map=meta_map,
                                                                            raster_info=info)

                # Show the current image name, array shape and value range in the progress bar
                _tqdm.set_postfix(OrderedDict(img_name=img_name,
                                              img=inf_sample['sat_img'].shape,
                                              img_min_val=np.min(inf_sample['sat_img']),
                                              img_max_val=np.max(inf_sample['sat_img'])))

                # Bands available to the model = size of the array's third axis + meta layers from meta_map
                input_band_count = inf_sample['sat_img'].shape[2] + MetaSegmentationDataset.get_meta_layer_count(meta_map)
                if input_band_count > num_bands:  # TODO: move as new function in utils.verifications
                    # FIXME: Following statements should be reconsidered to better manage inconsistencies between
                    #  provided number of band and image number of band.
                    warnings.warn(f"Input image has more band than the number provided in the yaml file ({num_bands}). "
                                  f"Will use the first {num_bands} bands of the input image.")
                    inf_sample['sat_img'] = inf_sample['sat_img'][:, :, 0:num_bands]
                    print(f"Input image's new shape: {inf_sample['sat_img'].shape}")

                elif input_band_count < num_bands:
                    # Too few bands: warn and move on to the next image
                    warnings.warn(f"Skipping image: The number of bands requested in the yaml file ({num_bands})"
                                  f"can not be larger than the number of band in the input image ({input_band_count}).")
                    continue

                # START INFERENCES ON SUB-IMAGES
                # _tqdm.n (the progress bar's counter) is passed as the image index
                sem_seg_results_per_class = sem_seg_inference(model,
                                                              inf_sample['sat_img'],
                                                              nbr_pix_overlap,
                                                              chunk_size,
                                                              num_classes_corrected,
                                                              device,
                                                              meta_map,
                                                              inf_sample['metadata'],
                                                              output_path=working_folder,
                                                              index=_tqdm.n,
                                                              debug=debug)

                # CREATE GEOTIF FROM METADATA OF ORIGINAL IMAGE
                tqdm.write(f'Saving inference...\n')
                if get_key_def('heatmaps', params['inference'], False):
                    tqdm.write(f'Heatmaps will be saved.\n')
                vis(params, inf_sample['sat_img'], sem_seg_results_per_class, working_folder, inference_input_path=local_img, debug=debug)

                tqdm.write(f"\n\nSemantic segmentation of image {img_name} completed\n\n")
                if bucket:
                    bucket.upload_file(inference_image, os.path.join(working_folder, f"{img_name.split('.')[0]}_inference.tif"))
    else:
        raise ValueError(
            f"The task should be either classification or segmentation. The provided value is {params['global']['task']}")

    time_elapsed = time.time() - since
    print('Inference completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
Exemple #8
0
def main(params: dict):
    """
    Identify the class to which each image belongs.
    :param params: (dict) Parameters found in the yaml config file.

    """
    since = time.time()

    # MANDATORY PARAMETERS
    img_dir_or_csv = get_key_def('img_dir_or_csv_file',
                                 params['inference'],
                                 expected_type=str)
    state_dict = get_key_def('state_dict_path', params['inference'])
    task = get_key_def('task', params['global'], expected_type=str)
    if task not in ['classification', 'segmentation']:
        raise ValueError(
            f'Task should be either "classification" or "segmentation". Got {task}'
        )
    model_name = get_key_def('model_name', params['global'],
                             expected_type=str).lower()
    num_classes = get_key_def('num_classes',
                              params['global'],
                              expected_type=int)
    num_bands = get_key_def('number_of_bands',
                            params['global'],
                            expected_type=int)
    chunk_size = get_key_def('chunk_size',
                             params['inference'],
                             default=512,
                             expected_type=int)
    BGR_to_RGB = get_key_def('BGR_to_RGB',
                             params['global'],
                             expected_type=bool)

    # OPTIONAL PARAMETERS
    dontcare_val = get_key_def("ignore_index",
                               params["training"],
                               default=-1,
                               expected_type=int)
    num_devices = get_key_def('num_gpus',
                              params['global'],
                              default=0,
                              expected_type=int)
    default_max_used_ram = 25
    max_used_ram = get_key_def('max_used_ram',
                               params['global'],
                               default=default_max_used_ram,
                               expected_type=int)
    max_used_perc = get_key_def('max_used_perc',
                                params['global'],
                                default=25,
                                expected_type=int)
    scale = get_key_def('scale_data',
                        params['global'],
                        default=[0, 1],
                        expected_type=List)
    debug = get_key_def('debug_mode',
                        params['global'],
                        default=False,
                        expected_type=bool)
    raster_to_vec = get_key_def('ras2vec', params['inference'], False)

    # benchmark (ie when gkpgs are inputted along with imagery)
    dontcare = get_key_def("ignore_index", params["training"], -1)
    targ_ids = get_key_def('target_ids',
                           params['sample'],
                           None,
                           expected_type=List)

    # SETTING OUTPUT DIRECTORY
    working_folder = Path(
        params['inference']['state_dict_path']).parent.joinpath(
            f'inference_{num_bands}bands')
    Path.mkdir(working_folder, parents=True, exist_ok=True)

    # mlflow logging
    mlflow_uri = get_key_def('mlflow_uri',
                             params['global'],
                             default=None,
                             expected_type=str)
    if mlflow_uri and not Path(mlflow_uri).is_dir():
        warnings.warn(f'Mlflow uri path is not valid: {mlflow_uri}')
        mlflow_uri = None
    # SETUP LOGGING
    import logging.config  # See: https://docs.python.org/2.4/lib/logging-config-fileformat.html
    if mlflow_uri:
        log_config_path = Path('utils/logging.conf').absolute()
        logfile = f'{working_folder}/info.log'
        logfile_debug = f'{working_folder}/debug.log'
        console_level_logging = 'INFO' if not debug else 'DEBUG'
        logging.config.fileConfig(log_config_path,
                                  defaults={
                                      'logfilename': logfile,
                                      'logfilename_debug': logfile_debug,
                                      'console_level': console_level_logging
                                  })

        # import only if mlflow uri is set
        from mlflow import log_params, set_tracking_uri, set_experiment, start_run, log_artifact, log_metrics
        if not Path(mlflow_uri).is_dir():
            logging.warning(
                f"Couldn't locate mlflow uri directory {mlflow_uri}. Directory will be created."
            )
            Path(mlflow_uri).mkdir()
        set_tracking_uri(mlflow_uri)
        exp_name = get_key_def('mlflow_experiment_name',
                               params['global'],
                               default='gdl-inference',
                               expected_type=str)
        set_experiment(f'{exp_name}/{working_folder.name}')
        run_name = get_key_def('mlflow_run_name',
                               params['global'],
                               default='gdl',
                               expected_type=str)
        start_run(run_name=run_name)
        log_params(params['global'])
        log_params(params['inference'])
    else:
        # set a console logger as default
        logging.basicConfig(level=logging.DEBUG)
        logging.info(
            'No logging folder set for mlflow. Logging will be limited to console'
        )

    if debug:
        logging.warning(
            f'Debug mode activated. Some debug features may mobilize extra disk space and '
            f'cause delays in execution.')

    # Assert that all items in target_ids are integers (ex.: to benchmark single-class model with multi-class labels)
    if targ_ids:
        for item in targ_ids:
            if not isinstance(item, int):
                raise ValueError(
                    f'Target id "{item}" in target_ids is {type(item)}, expected int.'
                )

    logging.info(f'Inferences will be saved to: {working_folder}\n\n')
    if not (0 <= max_used_ram <= 100):
        logging.warning(
            f'Max used ram parameter should be a percentage. Got {max_used_ram}. '
            f'Will set default value of {default_max_used_ram} %')
        max_used_ram = default_max_used_ram

    # AWS
    bucket = None
    bucket_file_cache = []
    bucket_name = get_key_def('bucket_name', params['global'])

    # list of GPU devices that are available and unused. If no GPUs, returns empty dict
    gpu_devices_dict = get_device_ids(num_devices,
                                      max_used_ram_perc=max_used_ram,
                                      max_used_perc=max_used_perc)
    if gpu_devices_dict:
        logging.info(
            f"Number of cuda devices requested: {num_devices}. Cuda devices available: {gpu_devices_dict}. "
            f"Using {list(gpu_devices_dict.keys())[0]}\n\n")
        device = torch.device(
            f'cuda:{list(range(len(gpu_devices_dict.keys())))[0]}')
    else:
        logging.warning(
            f"No Cuda device available. This process will only run on CPU")
        device = torch.device('cpu')

    # CONFIGURE MODEL
    num_classes_backgr = add_background_to_num_class(task, num_classes)
    model, loaded_checkpoint, model_name = net(model_name=model_name,
                                               num_bands=num_bands,
                                               num_channels=num_classes_backgr,
                                               dontcare_val=dontcare_val,
                                               num_devices=1,
                                               net_params=params,
                                               inference_state_dict=state_dict)
    try:
        model.to(device)
    except RuntimeError:
        logging.info(f"Unable to use device 0")
        device = torch.device(f'cuda' if gpu_devices_dict else 'cpu')
        model.to(device)

    # CREATE LIST OF INPUT IMAGES FOR INFERENCE
    list_img = list_input_images(img_dir_or_csv,
                                 bucket_name,
                                 glob_patterns=["*.tif", "*.TIF"])

    # VALIDATION: anticipate problems with imagery and label (if provided) before entering main for loop
    valid_gpkg_set = set()
    for info in tqdm(list_img, desc='Validating imagery'):
        # validate_raster(info['tif'], num_bands, meta_map)
        if 'gpkg' in info.keys(
        ) and info['gpkg'] and info['gpkg'] not in valid_gpkg_set:
            validate_num_classes(vector_file=info['gpkg'],
                                 num_classes=num_classes,
                                 attribute_name=info['attribute_name'],
                                 ignore_index=dontcare,
                                 target_ids=targ_ids)
            assert_crs_match(info['tif'], info['gpkg'])
            valid_gpkg_set.add(info['gpkg'])

    logging.info('Successfully validated imagery')
    if valid_gpkg_set:
        logging.info('Successfully validated label data for benchmarking')

    if task == 'classification':
        classifier(
            params, list_img, model, device, working_folder
        )  # FIXME: why don't we load from checkpoint in classification?

    elif task == 'segmentation':
        gdf_ = []
        gpkg_name_ = []

        # TODO: Add verifications?
        if bucket:
            bucket.download_file(
                loaded_checkpoint,
                "saved_model.pth.tar")  # TODO: is this still valid?
            model, _ = load_from_checkpoint("saved_model.pth.tar", model)
        else:
            model, _ = load_from_checkpoint(loaded_checkpoint, model)
        # LOOP THROUGH LIST OF INPUT IMAGES
        for info in tqdm(list_img,
                         desc='Inferring from images',
                         position=0,
                         leave=True):
            with start_run(run_name=Path(info['tif']).name, nested=True):
                img_name = Path(info['tif']).name
                local_gpkg = Path(
                    info['gpkg']
                ) if 'gpkg' in info.keys() and info['gpkg'] else None
                gpkg_name = local_gpkg.stem if local_gpkg else None
                if bucket:
                    local_img = f"Images/{img_name}"
                    bucket.download_file(info['tif'], local_img)
                    inference_image = f"Classified_Images/{img_name.split('.')[0]}_inference.tif"
                    if info['meta']:
                        if info['meta'] not in bucket_file_cache:
                            bucket_file_cache.append(info['meta'])
                            bucket.download_file(info['meta'],
                                                 info['meta'].split('/')[-1])
                        info['meta'] = info['meta'].split('/')[-1]
                else:  # FIXME: else statement should support img['meta'] integration as well.
                    local_img = Path(info['tif'])
                    Path.mkdir(working_folder.joinpath(local_img.parent.name),
                               parents=True,
                               exist_ok=True)
                    inference_image = working_folder.joinpath(
                        local_img.parent.name,
                        f"{img_name.split('.')[0]}_inference.tif")
                temp_file = working_folder.joinpath(
                    local_img.parent.name, f"{img_name.split('.')[0]}.dat")
                raster = rasterio.open(local_img, 'r')
                logging.info(f'Reading original image: {raster.name}')
                inf_meta = raster.meta
                label = None
                if local_gpkg:
                    logging.info(f'Burning label as raster: {local_gpkg}')
                    local_img = clip_raster_with_gpkg(raster, local_gpkg)
                    raster.close()
                    raster = rasterio.open(local_img, 'r')
                    logging.info(f'Reading clipped image: {raster.name}')
                    inf_meta = raster.meta
                    label = vector_to_raster(
                        vector_file=local_gpkg,
                        input_image=raster,
                        out_shape=(inf_meta['height'], inf_meta['width']),
                        attribute_name=info['attribute_name'],
                        fill=0,  # background value in rasterized vector.
                        target_ids=targ_ids)
                    if debug:
                        logging.debug(
                            f'Unique values in loaded label as raster: {np.unique(label)}\n'
                            f'Shape of label as raster: {label.shape}')
                pred, gdf = segmentation(param=params,
                                         input_image=raster,
                                         label_arr=label,
                                         num_classes=num_classes_backgr,
                                         gpkg_name=gpkg_name,
                                         model=model,
                                         chunk_size=chunk_size,
                                         device=device,
                                         scale=scale,
                                         BGR_to_RGB=BGR_to_RGB,
                                         tp_mem=temp_file,
                                         debug=debug)
                if gdf is not None:
                    gdf_.append(gdf)
                    gpkg_name_.append(gpkg_name)
                if local_gpkg:
                    pixelMetrics = ComputePixelMetrics(label, pred,
                                                       num_classes_backgr)
                    log_metrics(pixelMetrics.update(pixelMetrics.iou))
                    log_metrics(pixelMetrics.update(pixelMetrics.dice))
                pred = pred[np.newaxis, :, :].astype(np.uint8)
                inf_meta.update({
                    "driver": "GTiff",
                    "height": pred.shape[1],
                    "width": pred.shape[2],
                    "count": pred.shape[0],
                    "dtype": 'uint8',
                    "compress": 'lzw'
                })
                logging.info(
                    f'Successfully inferred on {img_name}\nWriting to file: {inference_image}'
                )
                with rasterio.open(inference_image, 'w+', **inf_meta) as dest:
                    dest.write(pred)
                del pred
                try:
                    temp_file.unlink()
                except OSError as e:
                    logging.warning(f'File Error: {temp_file, e.strerror}')
                if raster_to_vec:
                    start_vec = time.time()
                    inference_vec = working_folder.joinpath(
                        local_img.parent.name,
                        f"{img_name.split('.')[0]}_inference.gpkg")
                    ras2vec(inference_image, inference_vec)
                    end_vec = time.time() - start_vec
                    logging.info(
                        'Vectorization completed in {:.0f}m {:.0f}s'.format(
                            end_vec // 60, end_vec % 60))

        if len(gdf_) >= 1:
            if not len(gdf_) == len(gpkg_name_):
                raise ValueError('benchmarking unable to complete')
            all_gdf = pd.concat(
                gdf_)  # Concatenate all geo data frame into one geo data frame
            all_gdf.reset_index(drop=True, inplace=True)
            gdf_x = gpd.GeoDataFrame(all_gdf)
            bench_gpkg = working_folder / "benchmark.gpkg"
            gdf_x.to_file(bench_gpkg, driver="GPKG", index=False)
            logging.info(
                f'Successfully wrote benchmark geopackage to: {bench_gpkg}')
        # log_artifact(working_folder)
    time_elapsed = time.time() - since
    logging.info('Inference Script completed in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
Exemple #9
0
                        '--param',
                        metavar='yaml_file',
                        nargs=1,
                        help='Path to parameters stored in yaml')
    parser.add_argument('-i',
                        '--input',
                        metavar='model_pth img_dir',
                        nargs=2,
                        help='model_path and image_dir')
    args = parser.parse_args()

    # if a yaml is inputted, get those parameters and get model state_dict to overwrite global parameters afterwards
    if args.param:
        input_params = read_parameters(args.param[0])
        model_ckpt = get_key_def('state_dict_path',
                                 input_params['inference'],
                                 expected_type=str)
        # load checkpoint
        checkpoint = load_checkpoint(model_ckpt)
        if 'params' in checkpoint.keys():
            params = checkpoint['params']
            # overwrite with inputted parameters
            compare_config_yamls(yaml1=params,
                                 yaml2=input_params,
                                 update_yaml1=True)
        else:
            warnings.warn(
                'No parameters found in checkpoint. Defaulting to parameters from inputted yaml.'
                'Use GDL version 1.3 or more.')
            params = input_params
        del checkpoint
Exemple #10
0
def create_csv():
    """
    Run the pixel inventory over every image listed in the preparation csv.

    Configuration is read from the module-level ``params`` dict ('sample' and
    'global' sections). Any existing ``prop_data.csv`` in the data path is removed
    first. Then, for each csv row, the GeoPackage is validated and burned to a label
    raster, the image is read, and a window of ``samples_size`` pixels is moved over
    the zero-padded label every ``samples_dist`` pixels. Each window is handed to
    pixel_inventory().
    """
    sample_cfg = params['sample']
    prep_csv_path = sample_cfg['prep_csv_file']
    stride = sample_cfg['samples_dist']
    global_cfg = params['global']
    tile_size = global_cfg['samples_size']
    out_dir = global_cfg['data_path']
    Path(out_dir).mkdir(exist_ok=True)
    num_classes = global_cfg['num_classes']
    rows = read_csv(prep_csv_path)

    # Start from a clean proportions file.
    prop_csv = out_dir + '/prop_data.csv'
    if os.path.isfile(prop_csv):
        os.remove(prop_csv)

    # (config key, default) pairs forwarded to image_reader_as_array as keyword arguments.
    aux_defaults = (
        ('aux_vector_file', None),
        ('aux_vector_attrib', None),
        ('aux_vector_ids', None),
        ('aux_vector_dist_maps', True),
        ('aux_vector_dist_log', True),
        ('aux_vector_scale', None),
    )

    with tqdm(rows) as progress:
        for row_info in progress:
            tif_path = row_info['tif']
            progress.set_postfix(OrderedDict(file=f'{tif_path}', sample_size=global_cfg['samples_size']))

            # Validate the number of classes in the vector file
            gpkg_path = row_info['gpkg']
            attribute = row_info['attribute_name']
            validate_num_classes(gpkg_path, num_classes, attribute)

            assert os.path.isfile(tif_path), f"could not open raster file at {tif_path}"
            with rasterio.open(tif_path, 'r') as src:
                # Burn the vector file into a label raster
                training_cfg = get_key_def('training', params, {})
                label = vector_to_raster(vector_file=row_info['gpkg'],
                                         input_image=src,
                                         attribute_name=row_info['attribute_name'],
                                         fill=get_key_def('ignore_idx', training_cfg, 0))

                # Read the input raster image
                aux_kwargs = {key: get_key_def(key, params['global'], default) for key, default in aux_defaults}
                image = image_reader_as_array(input_image=src, **aux_kwargs)

                # Mask the zeros from the input image into the label raster.
                if sample_cfg['mask_reference']:
                    label = images_to_samples.mask_image(image, label)

                # Add a trailing channel axis to the label.
                label = np.reshape(label, tuple(label.shape[:2]) + (1,))

                height, width, _ = image.shape

                # Half-tile zero padding on both spatial axes
                margin = int(tile_size / 2)
                padded_label = np.pad(label, ((margin, margin), (margin, margin), (0, 0)), mode='constant')

                for top in range(0, height, stride):
                    for left in range(0, width, stride):
                        window = padded_label[top:top + tile_size, left:left + tile_size, :]
                        pixel_inventory(np.squeeze(window, axis=2), tile_size,
                                        global_cfg['num_classes'] + 1,
                                        global_cfg['data_path'], row_info['dataset'])
Exemple #11
0
def main(params):
    """
    Training and validation datasets preparation.

    Process
    -------
    1. Read csv file and validate existence of all input files and GeoPackages.

    2. Do the following verifications:
        1. Assert number of bands found in raster is equal to desired number
           of bands.
        2. Check that `num_classes` is equal to number of classes detected in
           the specified attribute for each GeoPackage.
           Warning: this validation will not succeed if a Geopackage
                    contains only a subset of `num_classes` (e.g. 3 of 4).
        3. Assert Coordinate reference system between raster and gpkg match.

    3. Read csv file and for each line in the file, do the following:
        1. Read input image as array with utils.readers.image_reader_as_array().
            - If gpkg's extent is smaller than raster's extent,
              raster is clipped to gpkg's extent.
            - If gpkg's extent is bigger than raster's extent,
              gpkg is clipped to raster's extent.
        2. Convert GeoPackage vector information into the "label" raster with
           utils.utils.vector_to_raster(). The pixel value is determined by the
           attribute in the csv file.
        3. Create a new raster called "label" with the same properties as the
           input image.
        4. Read metadata and add to input as new bands (*more details to come*).
        5. Crop the arrays in smaller samples of the size `samples_size` of
           `your_conf.yaml`. Visual representation of this is provided at
            https://medium.com/the-downlinq/broad-area-satellite-imagery-semantic-segmentation-basiss-4a7ea2c8466f
        6. Write samples from input image and label into the "trn" or "tst"
           hdf5 file, depending on the value contained in the csv file (any
           other value raises ValueError). The "val" hdf5 file and
           `val_percent` are handed to samples_preparation(), which presumably
           diverts a share of the samples to it. Refer to samples_preparation().

    Side effects
    ------------
    - Stores the current git hash in params['global']['git_hash'].
    - Creates the data path and a new samples folder; if the samples folder
      already exists, a timestamp suffix is appended to the new folder name.
    - Resets params['training']['ignore_index'] to -1 when it is configured as 0.
    - When a bucket name is configured, inputs are downloaded from the bucket
      and the three hdf5 files are uploaded to it at the end.

    -------
    :param params: (dict) Parameters found in the yaml config file.
    :raises AssertionError: if the configured task is not 'segmentation'.
    :raises ValueError: if a csv row's dataset value is neither 'trn' nor 'tst'.
    """
    params['global']['git_hash'] = get_git_hash()
    now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    bucket_file_cache = []

    assert params['global'][
        'task'] == 'segmentation', f"images_to_samples.py isn't necessary when performing classification tasks"

    # SET BASIC VARIABLES AND PATHS. CREATE OUTPUT FOLDERS.
    bucket_name = get_key_def('bucket_name', params['global'])
    data_path = Path(params['global']['data_path'])
    Path.mkdir(data_path, exist_ok=True, parents=True)
    csv_file = params['sample']['prep_csv_file']
    val_percent = params['sample']['val_percent']
    samples_size = params["global"]["samples_size"]
    overlap = params["sample"]["overlap"]
    min_annot_perc = get_key_def('min_annotated_percent',
                                 params['sample']['sampling_method'],
                                 None,
                                 expected_type=int)
    num_bands = params['global']['number_of_bands']
    debug = get_key_def('debug_mode', params['global'], False)
    if debug:
        warnings.warn(f'Debug mode activate. Execution may take longer...')

    final_samples_folder = None

    sample_path_name = f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands'

    # AWS
    if bucket_name:
        s3 = boto3.resource('s3')
        bucket = s3.Bucket(bucket_name)
        bucket.download_file(csv_file, 'samples_prep.csv')
        list_data_prep = read_csv('samples_prep.csv')
        # data_path is a Path and therefore always truthy: the bucket destination
        # is always <data_path>/samples.
        final_samples_folder = data_path.joinpath("samples")
        # Local working folder (relative to the current directory). It must be a
        # Path: .is_dir(), mkdir() and the "/" joins below do not work on a str.
        samples_folder = Path(sample_path_name)

    else:
        list_data_prep = read_csv(csv_file)
        samples_folder = data_path.joinpath(sample_path_name)

    if samples_folder.is_dir():
        warnings.warn(
            f'Data path exists: {samples_folder}. Suffix will be added to directory name.'
        )
        samples_folder = Path(str(samples_folder) + '_' + now)
    else:
        tqdm.write(f'Writing samples to {samples_folder}')
    Path.mkdir(samples_folder, exist_ok=False
               )  # TODO: what if we want to append samples to existing hdf5?
    tqdm.write(f'Samples will be written to {samples_folder}\n\n')

    tqdm.write(f'\nSuccessfully read csv file: {Path(csv_file).stem}\n'
               f'Number of rows: {len(list_data_prep)}\n'
               f'Copying first entry:\n{list_data_prep[0]}\n')
    ignore_index = get_key_def('ignore_index', params['training'], -1)
    meta_map, metadata = get_key_def("meta_map", params["global"], {}), None

    # VALIDATION: (1) Assert num_classes parameters == num actual classes in gpkg and (2) check CRS match (tif and gpkg)
    # Each distinct gpkg is validated only once.
    valid_gpkg_set = set()
    for info in tqdm(list_data_prep, position=0):
        assert_num_bands(info['tif'], num_bands, meta_map)
        if info['gpkg'] not in valid_gpkg_set:
            # NOTE: gpkg_classes keeps the classes of the last validated gpkg only.
            gpkg_classes = validate_num_classes(
                info['gpkg'], params['global']['num_classes'],
                info['attribute_name'], ignore_index)
            assert_crs_match(info['tif'], info['gpkg'])
            valid_gpkg_set.add(info['gpkg'])

    if debug:
        # VALIDATION (debug only): Checking validity of features in vector files
        for info in tqdm(
                list_data_prep,
                position=0,
                desc=f"Checking validity of features in vector files"):
            invalid_features = validate_features_from_gpkg(
                info['gpkg'], info['attribute_name']
            )  # TODO: test this with invalid features.
            assert not invalid_features, f"{info['gpkg']}: Invalid geometry object(s) '{invalid_features}'"

    number_samples = {'trn': 0, 'val': 0, 'tst': 0}
    number_classes = 0

    class_prop = get_key_def('class_proportion',
                             params['sample']['sampling_method'],
                             None,
                             expected_type=dict)

    trn_hdf5, val_hdf5, tst_hdf5 = create_files_and_datasets(
        params, samples_folder)

    # From here on the hdf5 files are open: close them whatever happens below.
    try:
        # Set dontcare (aka ignore_index) value
        dontcare = get_key_def(
            "ignore_index", params["training"],
            -1)  # TODO: deduplicate with train_segmentation, l300
        if dontcare == 0:
            warnings.warn(
                "The 'dontcare' value (or 'ignore_index') used in the loss function cannot be zero;"
                " all valid class indices should be consecutive, and start at 0. The 'dontcare' value"
                " will be remapped to -1 while loading the dataset, and inside the config from now on."
            )
            # Only the config is updated; the local `dontcare` keeps the configured value.
            params["training"]["ignore_index"] = -1

        # creates pixel_classes dict and keys
        pixel_classes = {key: 0 for key in gpkg_classes}
        background_val = 0
        pixel_classes[background_val] = 0
        # class_prop is validated before the dontcare key is added to pixel_classes.
        class_prop = validate_class_prop_dict(pixel_classes, class_prop)
        pixel_classes[dontcare] = 0

        # For each row in csv: (1) burn vector file to raster, (2) read input raster image, (3) prepare samples
        with tqdm(list_data_prep,
                  position=0,
                  leave=False,
                  desc=f'Preparing samples') as _tqdm:
            for info in _tqdm:
                _tqdm.set_postfix(
                    OrderedDict(tif=f'{Path(info["tif"]).stem}',
                                sample_size=params['global']['samples_size']))
                try:
                    if bucket_name:
                        # Download inputs locally and point `info` at the local copies.
                        # gpkg and meta files are downloaded once (bucket_file_cache).
                        bucket.download_file(
                            info['tif'],
                            "Images/" + info['tif'].split('/')[-1])
                        info['tif'] = "Images/" + info['tif'].split('/')[-1]
                        if info['gpkg'] not in bucket_file_cache:
                            bucket_file_cache.append(info['gpkg'])
                            bucket.download_file(info['gpkg'],
                                                 info['gpkg'].split('/')[-1])
                        info['gpkg'] = info['gpkg'].split('/')[-1]
                        if info['meta']:
                            if info['meta'] not in bucket_file_cache:
                                bucket_file_cache.append(info['meta'])
                                bucket.download_file(
                                    info['meta'], info['meta'].split('/')[-1])
                            info['meta'] = info['meta'].split('/')[-1]

                    with rasterio.open(info['tif'], 'r') as raster:
                        # 1. Read the input raster image
                        # `raster` is rebound to the handle returned by the reader and is
                        # still used after this `with` block.
                        # NOTE(review): presumably the clipped raster; confirm it stays open.
                        np_input_image, raster, dataset_nodata = image_reader_as_array(
                            input_image=raster,
                            clip_gpkg=info['gpkg'],
                            aux_vector_file=get_key_def(
                                'aux_vector_file', params['global'], None),
                            aux_vector_attrib=get_key_def(
                                'aux_vector_attrib', params['global'], None),
                            aux_vector_ids=get_key_def(
                                'aux_vector_ids', params['global'], None),
                            aux_vector_dist_maps=get_key_def(
                                'aux_vector_dist_maps', params['global'],
                                True),
                            aux_vector_dist_log=get_key_def(
                                'aux_vector_dist_log', params['global'], True),
                            aux_vector_scale=get_key_def(
                                'aux_vector_scale', params['global'], None))

                        # 2. Burn vector file in a raster file
                        np_label_raster = vector_to_raster(
                            vector_file=info['gpkg'],
                            input_image=raster,
                            out_shape=np_input_image.shape[:2],
                            attribute_name=info['attribute_name'],
                            fill=background_val
                        )  # background value in rasterized vector.

                        if dataset_nodata is not None:
                            # 3. Set ignore_index value in label array where nodata in raster (only if nodata across all bands)
                            np_label_raster[dataset_nodata] = dontcare

                    if debug:
                        # Dump the image and the rasterized label as GeoTIFFs for inspection.
                        out_meta = raster.meta.copy()
                        np_image_debug = np_input_image.transpose(
                            2, 0, 1).astype(out_meta['dtype'])
                        out_meta.update({
                            "driver": "GTiff",
                            "height": np_image_debug.shape[1],
                            "width": np_image_debug.shape[2]
                        })
                        out_tif = samples_folder / f"np_input_image_{_tqdm.n}.tif"
                        print(f"DEBUG: writing clipped raster to {out_tif}")
                        with rasterio.open(out_tif, "w", **out_meta) as dest:
                            dest.write(np_image_debug)

                        out_meta = raster.meta.copy()
                        np_label_debug = np.expand_dims(
                            np_label_raster, axis=2).transpose(
                                2, 0, 1).astype(out_meta['dtype'])
                        out_meta.update({
                            "driver": "GTiff",
                            "height": np_label_debug.shape[1],
                            "width": np_label_debug.shape[2],
                            'count': 1
                        })
                        out_tif = samples_folder / f"np_label_rasterized_{_tqdm.n}.tif"
                        print(
                            f"DEBUG: writing final rasterized gpkg to {out_tif}"
                        )
                        with rasterio.open(out_tif, "w", **out_meta) as dest:
                            dest.write(np_label_debug)

                    # Mask the zeros from input image into label raster.
                    if params['sample']['mask_reference']:
                        np_label_raster = mask_image(np_input_image,
                                                     np_label_raster)

                    if info['dataset'] == 'trn':
                        out_file = trn_hdf5
                    elif info['dataset'] == 'tst':
                        out_file = tst_hdf5
                    else:
                        raise ValueError(
                            f"Dataset value must be trn or tst. Provided value is {info['dataset']}"
                        )
                    val_file = val_hdf5

                    metadata = add_metadata_from_raster_to_sample(
                        sat_img_arr=np_input_image,
                        raster_handle=raster,
                        meta_map=meta_map,
                        raster_info=info)
                    # Save label's per class pixel count to image metadata
                    # (negative label values are clipped to 0 before counting).
                    metadata['source_label_bincount'] = {
                        class_num: count
                        for class_num, count in enumerate(
                            np.bincount(
                                np_label_raster.clip(min=0).flatten()))
                        if count > 0
                    }  # TODO: add this to add_metadata_from[...] function?

                    # Add a trailing channel axis to the label.
                    np_label_raster = np.reshape(
                        np_label_raster, (np_label_raster.shape[0],
                                          np_label_raster.shape[1], 1))
                    # 3. Prepare samples!
                    number_samples, number_classes = samples_preparation(
                        in_img_array=np_input_image,
                        label_array=np_label_raster,
                        sample_size=samples_size,
                        overlap=overlap,
                        samples_count=number_samples,
                        num_classes=number_classes,
                        samples_file=out_file,
                        val_percent=val_percent,
                        val_sample_file=val_file,
                        dataset=info['dataset'],
                        pixel_classes=pixel_classes,
                        image_metadata=metadata,
                        dontcare=dontcare,
                        min_annot_perc=min_annot_perc,
                        class_prop=class_prop)

                    _tqdm.set_postfix(
                        OrderedDict(number_samples=number_samples))
                    out_file.flush()
                except OSError as e:
                    # Best effort: an unreadable input skips this row, not the whole run.
                    warnings.warn(
                        f'An error occurred while preparing samples with "{Path(info["tif"]).stem}" (tiff) and '
                        f'{Path(info["gpkg"]).stem} (gpkg). Error: "{e}"')
                    continue
    finally:
        trn_hdf5.close()
        val_hdf5.close()
        tst_hdf5.close()

    # adds up the number of pixels for each class in pixel_classes dict
    pixel_total = sum(pixel_classes.values())

    # prints the proportion of pixels of each class for the samples created
    for i in pixel_classes:
        prop = round((pixel_classes[i] / pixel_total) *
                     100, 1) if pixel_total > 0 else 0
        print('Pixels from class', i, ':', prop, '%')

    print("Number of samples created: ", number_samples)

    if bucket_name and final_samples_folder:
        print('Transfering Samples to the bucket')
        for hdf5_name in ("trn_samples.hdf5", "val_samples.hdf5",
                          "tst_samples.hdf5"):
            # Local filename as str; the bucket key always uses forward slashes.
            bucket.upload_file(str(samples_folder / hdf5_name),
                               (final_samples_folder / hdf5_name).as_posix())

    print("End of process")
def main(params, config_path):
    """
    Train and validate a classification model, then evaluate the best checkpoint on the test set.

    Steps performed:
    1. Read parameters from the config and prepare a fresh output folder
       (``<data_path>/model/<config stem>``). An existing folder is archived under a name suffixed
       with its last modification time.
    2. Create the dataloaders and instantiate model, criterion, optimizer and lr scheduler.
    3. For each epoch, train then validate; a checkpoint is written whenever the validation loss improves.
    4. Reload the saved checkpoint (when at least one epoch was requested) and evaluate on the test
       dataloader, if there is one.

    :param params: (dict) Parameters found in the yaml config file.
    :param config_path: (Path) Path to the yaml config file. ``.stem`` is read from it, so a
                        ``pathlib.Path`` (not a plain str) is expected.
    """
    # MANDATORY PARAMETERS
    num_classes = get_key_def('num_classes', params['global'], expected_type=int)
    num_bands = get_key_def('number_of_bands', params['global'], expected_type=int)
    batch_size = get_key_def('batch_size', params['training'], expected_type=int)
    num_epochs = get_key_def('num_epochs', params['training'], expected_type=int)
    model_name = get_key_def('model_name', params['global'], expected_type=str).lower()
    # NOTE: read here but not used further in this function.
    BGR_to_RGB = get_key_def('BGR_to_RGB', params['global'], expected_type=bool)

    # parameters to find hdf5 samples
    data_path = Path(get_key_def('data_path', params['global'], './data', expected_type=str))
    assert data_path.is_dir(), f'Could not locate data path {data_path}'

    # OPTIONAL PARAMETERS
    # basics
    debug = get_key_def('debug_mode', params['global'], default=False, expected_type=bool)
    task = get_key_def('task', params['global'], default='classification', expected_type=str)
    assert task == 'classification', f"The task should be classification. The provided value is {task}"
    dontcare_val = get_key_def("ignore_index", params["training"], default=-1, expected_type=int)
    batch_metrics = get_key_def('batch_metrics', params['training'], default=1, expected_type=int)
    # NOTE: read here but not used further in this function.
    meta_map = get_key_def("meta_map", params["global"], default={})
    bucket_name = get_key_def('bucket_name', params['global'])  # AWS

    # model params
    loss_fn = get_key_def('loss_fn', params['training'], default='CrossEntropy', expected_type=str)
    optimizer = get_key_def('optimizer', params['training'], default='adam', expected_type=str)
    pretrained = get_key_def('pretrained', params['training'], default=True, expected_type=bool)
    train_state_dict_path = get_key_def('state_dict_path', params['training'], default=None, expected_type=str)
    dropout_prob = get_key_def('dropout_prob', params['training'], default=None, expected_type=float)

    # gpu parameters
    num_devices = get_key_def('num_gpus', params['global'], default=0, expected_type=int)
    max_used_ram = get_key_def('max_used_ram', params['global'], default=2000, expected_type=int)
    max_used_perc = get_key_def('max_used_perc', params['global'], default=15, expected_type=int)

    # automatic model naming with unique id for each training
    model_id = config_path.stem
    output_path = data_path.joinpath('model') / model_id
    if output_path.is_dir():
        # Keep results of a previous run: move them aside under a time-stamped name.
        last_mod_time_suffix = datetime.fromtimestamp(output_path.stat().st_mtime).strftime('%Y%m%d-%H%M%S')
        archive_output_path = data_path.joinpath('model') / f"{model_id}_{last_mod_time_suffix}"
        shutil.move(output_path, archive_output_path)
    output_path.mkdir(parents=True, exist_ok=False)
    shutil.copy(str(config_path), str(output_path))  # copy yaml to output path where model will be saved
    tqdm.write(f'Model and log files will be saved to: {output_path}\n\n')

    if debug:
        warnings.warn('Debug mode activated. Some debug functions may cause delays in execution.')

    if bucket_name:
        bucket, bucket_output_path, output_path, data_path = download_s3_files(bucket_name=bucket_name,
                                                                               data_path=data_path,
                                                                               output_path=output_path,
                                                                               num_classes=num_classes)
    else:
        get_local_classes(num_classes, data_path, output_path)

    since = time.time()
    now = datetime.now().strftime("%Y-%m-%d_%H-%M")
    # Start from +inf so that the first finite validation loss always produces a checkpoint,
    # which the test-set evaluation below relies on.
    best_loss = float('inf')

    progress_log = Path(output_path) / 'progress.log'
    if not progress_log.exists():
        # Context manager guarantees the handle is flushed and closed once the header is written.
        with progress_log.open('w', buffering=1) as log_file:
            log_file.write(tsv_line('ep_idx', 'phase', 'iter', 'i_p_ep', 'time'))  # Add header

    trn_log = InformationLogger('trn')
    val_log = InformationLogger('val')
    tst_log = InformationLogger('tst')

    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    gpu_devices_dict = get_device_ids(num_devices,
                                      max_used_ram_perc=max_used_ram,
                                      max_used_perc=max_used_perc)
    # len() works for both a dict and the empty list mentioned above.
    num_devices = len(gpu_devices_dict)
    device = torch.device('cuda:0' if gpu_devices_dict else 'cpu')

    tqdm.write(f'Creating dataloaders from data in {Path(data_path)}...\n')
    trn_dataloader, val_dataloader, tst_dataloader = create_classif_dataloader(data_path=data_path,
                                                                               batch_size=batch_size,
                                                                               num_devices=num_devices,)

    # INSTANTIATE MODEL AND LOAD CHECKPOINT FROM PATH
    model, model_name, criterion, optimizer, lr_scheduler = net(model_name=model_name,
                                                                num_bands=num_bands,
                                                                num_channels=num_classes,
                                                                dontcare_val=dontcare_val,
                                                                num_devices=num_devices,
                                                                train_state_dict_path=train_state_dict_path,
                                                                pretrained=pretrained,
                                                                dropout_prob=dropout_prob,
                                                                loss_fn=loss_fn,
                                                                optimizer=optimizer,
                                                                net_params=params)
    tqdm.write(f'Instantiated {model_name} model with {num_classes} output channels.\n')

    filename = os.path.join(output_path, 'checkpoint.pth.tar')

    for epoch in range(0, num_epochs):
        logging.info(f'\nEpoch {epoch}/{num_epochs - 1}\n{"-" * 20}')

        trn_report = train(train_loader=trn_dataloader,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           scheduler=lr_scheduler,
                           num_classes=num_classes,
                           batch_size=batch_size,
                           ep_idx=epoch,
                           progress_log=progress_log,
                           device=device,
                           debug=debug)
        trn_log.add_values(trn_report, epoch, ignore=['precision', 'recall', 'fscore', 'iou'])

        # Use the value resolved above (with its default) rather than re-indexing the raw config,
        # so validation and test evaluation agree and a missing key does not raise KeyError.
        val_report = evaluation(eval_loader=val_dataloader,
                                model=model,
                                criterion=criterion,
                                num_classes=num_classes,
                                batch_size=batch_size,
                                ep_idx=epoch,
                                progress_log=progress_log,
                                batch_metrics=batch_metrics,
                                dataset='val',
                                device=device,
                                debug=debug)
        val_loss = val_report['loss'].avg
        if batch_metrics is not None:
            val_log.add_values(val_report, epoch, ignore=['iou'])
        else:
            val_log.add_values(val_report, epoch, ignore=['precision', 'recall', 'fscore', 'iou'])

        if val_loss < best_loss:
            tqdm.write("save checkpoint\n")
            best_loss = val_loss
            # More info: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-torch-nn-dataparallel-models
            state_dict = model.module.state_dict() if num_devices > 1 else model.state_dict()
            torch.save({'epoch': epoch,
                        'params': params,
                        'model': state_dict,
                        'best_loss': best_loss,
                        'optimizer': optimizer.state_dict()}, filename)

            if bucket_name:
                bucket_filename = os.path.join(bucket_output_path, 'checkpoint.pth.tar')
                bucket.upload_file(filename, bucket_filename)

        if bucket_name:
            save_logs_to_bucket(bucket, bucket_output_path, output_path, batch_metrics)

        cur_elapsed = time.time() - since
        logging.info(f'Current elapsed time {cur_elapsed // 60:.0f}m {cur_elapsed % 60:.0f}s')

    # load checkpoint model and evaluate it on test dataset.
    if num_epochs > 0:  # if num_epochs is set to 0, model is loaded to evaluate on test set
        checkpoint = load_checkpoint(filename)
        model, _ = load_from_checkpoint(checkpoint, model)

    if tst_dataloader:
        tst_report = evaluation(eval_loader=tst_dataloader,
                                model=model,
                                criterion=criterion,
                                num_classes=num_classes,
                                batch_size=batch_size,
                                ep_idx=num_epochs,
                                progress_log=progress_log,
                                batch_metrics=batch_metrics,
                                dataset='tst',
                                device=device)
        tst_log.add_values(tst_report, num_epochs, ignore=['iou'])

        if bucket_name:
            bucket_filename = os.path.join(bucket_output_path, 'last_epoch.pth.tar')
            bucket.upload_file("output.txt", os.path.join(bucket_output_path, f"Logs/{now}_output.txt"))
            bucket.upload_file(filename, bucket_filename)

    time_elapsed = time.time() - since
    logging.info('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
Exemple #13
0
def main(params: dict):
    """
    Run inference on a list of images and, when a ground truth is provided, benchmark the predictions.

    For the classification task, the images are handed over to ``classifier``. For the segmentation
    task, the model is loaded from checkpoint, each image is segmented (optionally with smoothing)
    and the prediction is written as a one-band uint8 GeoTiff. When an image comes with a GeoPackage,
    pixel metrics are logged to mlflow and RGB renderings of the label and of the prediction are saved
    as PNG files in the working folder. Any other task value results in no inference being run.

    :param params: (dict) Parameters found in the yaml config file.
    """
    # SET BASIC VARIABLES AND PATHS
    since = time.time()
    task = params['global']['task']
    img_dir_or_csv = params['inference']['img_dir_or_csv_file']
    chunk_size = get_key_def('chunk_size', params['inference'], 512)
    prediction_with_smoothing = get_key_def('smooth_prediction', params['inference'], False)
    overlap = get_key_def('overlap', params['inference'], 2)
    num_classes = params['global']['num_classes']
    num_classes_corrected = add_background_to_num_class(task, num_classes)
    num_bands = params['global']['number_of_bands']
    # Outputs are written next to the state dict, in a folder named after the band count.
    working_folder = Path(params['inference']['state_dict_path']).parent.joinpath(f'inference_{num_bands}bands')
    num_devices = params['global']['num_gpus'] or 0
    working_folder.mkdir(parents=True, exist_ok=True)
    print(f'Inferences will be saved to: {working_folder}\n\n')

    # NOTE(review): `bucket` is never assigned anything but None in this function, so every
    # `if bucket:` branch below is currently inactive.
    bucket = None
    bucket_file_cache = []
    bucket_name = get_key_def('bucket_name', params['global'])

    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    lst_device_ids = get_device_ids(num_devices) if torch.cuda.is_available() else []
    device = torch.device(f'cuda:{lst_device_ids[0]}' if torch.cuda.is_available() and lst_device_ids else 'cpu')

    if lst_device_ids:
        print(
            f"Number of cuda devices requested: {num_devices}. Cuda devices available: {lst_device_ids}. Using {lst_device_ids[0]}\n\n"
        )
    else:
        warnings.warn("No Cuda device available. This process will only run on CPU")

    # CONFIGURE MODEL
    model, state_dict_path, model_name = net(params, num_channels=num_classes_corrected, inference=True)

    try:
        model.to(device)
    except RuntimeError:
        print("Unable to use device. Trying device 0")
        device = torch.device('cuda:0' if torch.cuda.is_available() and lst_device_ids else 'cpu')
        model.to(device)

    # mlflow tracking path + parameters logging
    set_tracking_uri(get_key_def('mlflow_uri', params['global'], default="./mlruns"))
    set_experiment('gdl-benchmarking/' + working_folder.name)
    log_params(params['global'])
    log_params(params['inference'])

    # CREATE LIST OF INPUT IMAGES FOR INFERENCE
    list_img = list_input_images(img_dir_or_csv, bucket_name, glob_patterns=["*.tif", "*.TIF"])

    if task == 'classification':
        # FIXME: why don't we load from checkpoint in classification?
        classifier(params, list_img, model, device, working_folder)

    elif task == 'segmentation':
        # TODO: Add verifications?
        if bucket:
            bucket.download_file(state_dict_path, "saved_model.pth.tar")  # TODO: is this still valid?
            model, _ = load_from_checkpoint("saved_model.pth.tar", model)
        else:
            model, _ = load_from_checkpoint(state_dict_path, model)
        # LOOP THROUGH LIST OF INPUT IMAGES
        with tqdm(list_img, desc='image list', position=0) as _tqdm:
            for info in _tqdm:
                img_name = Path(info['tif']).name
                local_gpkg = info['gpkg']
                if local_gpkg:
                    local_gpkg = Path(local_gpkg)
                if bucket:
                    local_img = f"Images/{img_name}"
                    bucket.download_file(info['tif'], local_img)
                    inference_image = f"Classified_Images/{img_name.split('.')[0]}_inference.tif"
                    if info['meta']:
                        # Each metadata file is downloaded once and then referenced by its local name.
                        if info['meta'] not in bucket_file_cache:
                            bucket_file_cache.append(info['meta'])
                            bucket.download_file(info['meta'], info['meta'].split('/')[-1])
                        info['meta'] = info['meta'].split('/')[-1]
                else:  # FIXME: else statement should support img['meta'] integration as well.
                    local_img = Path(info['tif'])
                    inference_image = working_folder.joinpath(f"{img_name.split('.')[0]}_inference.tif")
                    print(inference_image)
                # local_img is a str in the bucket branch: wrap it so .is_file() is always available.
                assert Path(local_img).is_file(), f"Could not locate raster file at {local_img}"
                with rasterio.open(local_img, 'r') as raster:
                    inf_meta = raster.meta
                    if prediction_with_smoothing:
                        print('Smoothening Predictions with 2D interpolation')
                        pred = segmentation_with_smoothing(raster, local_gpkg, model, chunk_size, overlap,
                                                           num_bands, device)
                    else:
                        pred = segmentation(raster, local_gpkg, model, chunk_size, num_bands, device)
                    if local_gpkg:
                        assert local_gpkg.is_file(), f"Could not locate gkpg file at {local_gpkg}"
                        label = vector_to_raster(vector_file=local_gpkg,
                                                 input_image=raster,
                                                 out_shape=pred.shape[:2],
                                                 attribute_name=info['attribute_name'],
                                                 fill=0)  # background value in rasterized vector.
                        with start_run(run_name=img_name, nested=True):
                            pixelMetrics = ComputePixelMetrics(label, pred, num_classes_corrected)
                            # Attributes are looked up one at a time, in the same order as before,
                            # right before each update/log call.
                            for metric_name in ('jaccard', 'dice', 'accuracy', 'precision', 'recall', 'matthews'):
                                log_metrics(pixelMetrics.update(getattr(pixelMetrics, metric_name)))

                        label_classes = np.unique(label)
                        # NOTE(review): `colors` is not defined in this function; presumably a
                        # module-level color table - confirm.
                        assert len(colors) >= len(label_classes), \
                            'Not enough colors and class names for number of classes in output'
                        # FIXME: color mapping scheme is hardcoded for now because of memory constraint; To be fixed.
                        label_rgb = ind2rgb(label, colors)
                        pred_rgb = ind2rgb(pred, colors)
                        # inference_image is a str in the bucket branch: wrap it so .stem is always available.
                        inference_stem = Path(inference_image).stem
                        Image.fromarray(label_rgb.astype(np.uint8), mode='RGB').save(
                            os.path.join(working_folder, 'label_rgb_' + inference_stem + '.png'))
                        Image.fromarray(pred_rgb.astype(np.uint8), mode='RGB').save(
                            os.path.join(working_folder, 'pred_rgb_' + inference_stem + '.png'))
                        del label_rgb, pred_rgb
                    # Add a leading band axis so the prediction can be written as a one-band raster.
                    pred = pred[np.newaxis, :, :]
                    inf_meta.update({
                        "driver": "GTiff",
                        "height": pred.shape[1],
                        "width": pred.shape[2],
                        "count": pred.shape[0],
                        "dtype": 'uint8'
                    })

                    with rasterio.open(inference_image, 'w+', **inf_meta) as dest:
                        dest.write(pred)
        log_artifact(working_folder)
    time_elapsed = time.time() - since
    print('Inference and Benchmarking completed in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
Exemple #14
0
def main(params):
    """
    Run inference (classification or semantic segmentation) on a list of images.

    The list of images comes either from a csv file (local or in an S3 bucket) or from every
    ``.tif`` / ``.TIF`` file of a local directory. For classification the images are handed over to
    ``classifier``. For segmentation the model is loaded from checkpoint, then each image is read,
    checked against the configured number of bands, passed to ``sem_seg_inference`` and the result
    is saved through ``vis``.

    :param params: (dict) Parameters found in the yaml config file.
    :raises ValueError: if the configured task is neither 'classification' nor 'segmentation'.
    """
    # SET BASIC VARIABLES AND PATHS
    since = time.time()

    debug = get_key_def('debug_mode', params['global'], False)
    if debug:
        warnings.warn('Debug mode activated. Some debug features may mobilize extra disk space and cause delays in execution.')

    task = params['global']['task']
    num_classes = params['global']['num_classes']
    if task == 'segmentation':
        # assume background is implicitly needed (makes no sense to predict with one class, for example.)
        # this will trigger some warnings elsewhere, but should succeed nonetheless
        num_classes_corrected = num_classes + 1  # + 1 for background # FIXME temporary patch for num_classes problem.
    elif task == 'classification':
        num_classes_corrected = num_classes
    else:
        # Fail right away: otherwise num_classes_corrected would be unbound when the model is built.
        raise ValueError(
            f"The task should be either classification or segmentation. The provided value is {task}")

    chunk_size = get_key_def('chunk_size', params['inference'], 512)
    overlap = get_key_def('overlap', params['inference'], 10)
    # overlap is a percentage of the chunk size; convert it to a number of pixels.
    nbr_pix_overlap = int(math.floor(overlap / 100 * chunk_size))
    num_bands = params['global']['number_of_bands']

    img_dir_or_csv = params['inference']['img_dir_or_csv_file']

    default_working_folder = Path(params['inference']['state_dict_path']).parent.joinpath(f'inference_{num_bands}bands')
    working_folder = Path(get_key_def('working_folder', params['inference'], default_working_folder))  # TODO: remove working_folder parameter in all templates
    working_folder.mkdir(exist_ok=True)
    print(f'Inferences will be saved to: {working_folder}\n\n')

    bucket = None
    bucket_file_cache = []
    # A config without any 'bucket_name' entry means local mode.
    bucket_name = get_key_def('bucket_name', params['global'], None)

    # CONFIGURE MODEL
    model, state_dict_path, model_name = net(params, num_channels=num_classes_corrected, inference=True)

    num_devices = params['global']['num_gpus'] or 0
    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    lst_device_ids = get_device_ids(num_devices) if torch.cuda.is_available() else []
    device = torch.device(f'cuda:{lst_device_ids[0]}' if torch.cuda.is_available() and lst_device_ids else 'cpu')

    if lst_device_ids:
        print(f"Number of cuda devices requested: {num_devices}. Cuda devices available: {lst_device_ids}. Using {lst_device_ids[0]}\n\n")
    else:
        warnings.warn("No Cuda device available. This process will only run on CPU")

    try:
        model.to(device)
    except RuntimeError:
        print("Unable to use device. Trying device 0")
        device = torch.device('cuda:0' if torch.cuda.is_available() and lst_device_ids else 'cpu')
        model.to(device)

    if bucket_name:
        s3 = boto3.resource('s3')
        bucket = s3.Bucket(bucket_name)
        if img_dir_or_csv.endswith('.csv'):
            bucket.download_file(img_dir_or_csv, 'img_csv_file.csv')
            list_img = read_csv('img_csv_file.csv', inference=True)
        else:
            raise NotImplementedError(
                'Specify a csv file containing images for inference. Directory input not implemented yet')
    else:
        if img_dir_or_csv.endswith('.csv'):
            list_img = read_csv(img_dir_or_csv, inference=True)
        else:
            img_dir = Path(img_dir_or_csv)
            assert img_dir.is_dir(), f'Could not find directory "{img_dir_or_csv}"'
            # Accept both lower and upper case extensions; the set removes duplicates on
            # case-insensitive file systems where both patterns match the same file.
            list_img_paths = sorted({*img_dir.glob('*.tif'), *img_dir.glob('*.TIF')})
            list_img = [{'tif': img_path} for img_path in list_img_paths]
            assert len(list_img) > 0, f'No .tif files found in {img_dir_or_csv}'

    if task == 'classification':
        classifier(params, list_img, model, device, working_folder)  # FIXME: why don't we load from checkpoint in classification?

    elif task == 'segmentation':
        if bucket:
            bucket.download_file(state_dict_path, "saved_model.pth.tar")
            model, _ = load_from_checkpoint("saved_model.pth.tar", model, inference=True)
        else:
            model, _ = load_from_checkpoint(state_dict_path, model, inference=True)

        with tqdm(list_img, desc='image list', position=0) as _tqdm:
            for img in _tqdm:
                img_name = Path(img['tif']).name
                if bucket:
                    local_img = f"Images/{img_name}"
                    bucket.download_file(img['tif'], local_img)
                    inference_image = f"Classified_Images/{img_name.split('.')[0]}_inference.tif"
                    if img['meta']:
                        # Each metadata file is downloaded once and then referenced by its local name.
                        if img['meta'] not in bucket_file_cache:
                            bucket_file_cache.append(img['meta'])
                            bucket.download_file(img['meta'], img['meta'].split('/')[-1])
                        img['meta'] = img['meta'].split('/')[-1]
                else:
                    local_img = Path(img['tif'])
                    inference_image = working_folder.joinpath(f"{img_name.split('.')[0]}_inference.tif")

                # local_img is a str in the bucket branch: wrap it so .is_file() is always available.
                assert Path(local_img).is_file(), f"Could not open raster file at {local_img}"

                scale = get_key_def('scale_data', params['global'], None)
                with rasterio.open(local_img, 'r') as raster:

                    np_input_image = image_reader_as_array(input_image=raster,
                                                           scale=scale,
                                                           aux_vector_file=get_key_def('aux_vector_file',
                                                                                       params['global'], None),
                                                           aux_vector_attrib=get_key_def('aux_vector_attrib',
                                                                                         params['global'], None),
                                                           aux_vector_ids=get_key_def('aux_vector_ids',
                                                                                      params['global'], None),
                                                           aux_vector_dist_maps=get_key_def('aux_vector_dist_maps',
                                                                                            params['global'], True),
                                                           aux_vector_scale=get_key_def('aux_vector_scale',
                                                                                        params['global'], None))

                meta_map, metadata = get_key_def("meta_map", params["global"], {}), None
                if meta_map:
                    # Images listed from a directory carry no 'meta' entry: .get() lets the
                    # assertion below report the problem instead of a bare KeyError.
                    img_meta = img.get('meta')
                    assert img_meta is not None and isinstance(img_meta, str) and os.path.isfile(img_meta), \
                        "global configuration requested metadata mapping onto loaded samples, but raster did not have available metadata"
                    metadata = read_parameters(img_meta)

                if debug:
                    _tqdm.set_postfix(OrderedDict(img_name=img_name,
                                                  img=np_input_image.shape,
                                                  img_min_val=np.min(np_input_image),
                                                  img_max_val=np.max(np_input_image)))

                input_band_count = np_input_image.shape[2] + MetaSegmentationDataset.get_meta_layer_count(meta_map)
                if input_band_count > num_bands:
                    # FIXME: Following statements should be reconsidered to better manage inconsistencies between
                    #  provided number of band and image number of band.
                    warnings.warn(f"Input image has more band than the number provided in the yaml file ({num_bands}). "
                                  f"Will use the first {num_bands} bands of the input image.")
                    np_input_image = np_input_image[:, :, 0:num_bands]
                    print(f"Input image's new shape: {np_input_image.shape}")

                elif input_band_count < num_bands:
                    warnings.warn(f"Skipping image: The number of bands requested in the yaml file ({num_bands}) "
                                  f"can not be larger than the number of band in the input image ({input_band_count}).")
                    continue

                # START INFERENCES ON SUB-IMAGES
                sem_seg_results_per_class = sem_seg_inference(model, np_input_image, nbr_pix_overlap, chunk_size,
                                                              num_classes_corrected, device, meta_map, metadata,
                                                              output_path=working_folder, index=_tqdm.n, debug=debug)

                # CREATE GEOTIF FROM METADATA OF ORIGINAL IMAGE
                tqdm.write('Saving inference...\n')
                if get_key_def('heatmaps', params['inference'], False):
                    tqdm.write('Heatmaps will be saved.\n')
                vis(params, np_input_image, sem_seg_results_per_class, working_folder,
                    inference_input_path=local_img, debug=debug)

                tqdm.write(f"\n\nSemantic segmentation of image {img_name} completed\n\n")
                if bucket:
                    bucket.upload_file(inference_image,
                                       os.path.join(working_folder, f"{img_name.split('.')[0]}_inference.tif"))

    time_elapsed = time.time() - since
    print('Inference completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
def main(params):
    """
    Training and validation datasets preparation.

    For every (raster, GeoPackage) pair listed in the csv file, the vector file is burned into a
    label raster, the image is read as an array and both are handed over to ``samples_preparation``,
    which writes into the trn, val or tst HDF5 file according to the 'dataset' value of the csv row.
    When a bucket name is configured, inputs are downloaded from the bucket and the resulting HDF5
    files are uploaded back to it at the end.

    :param params: (dict) Parameters found in the yaml config file.
    :raises ValueError: if a csv row has a 'dataset' value other than 'trn', 'val' or 'tst'.
    """
    bucket_file_cache = []
    bucket_name = params['global']['bucket_name']
    data_path = params['global']['data_path']
    if data_path:
        # The bucket branch below explicitly supports an empty data_path, so only create it when set.
        Path(data_path).mkdir(exist_ok=True)
    csv_file = params['sample']['prep_csv_file']

    final_samples_folder = None
    if bucket_name:
        s3 = boto3.resource('s3')
        bucket = s3.Bucket(bucket_name)
        bucket.download_file(csv_file, 'samples_prep.csv')
        list_data_prep = read_csv('samples_prep.csv')
        if data_path:
            final_samples_folder = os.path.join(data_path, "samples")
        else:
            final_samples_folder = "samples"
        samples_folder = "samples"
        out_label_folder = "label"

    else:
        list_data_prep = read_csv(csv_file)
        samples_folder = os.path.join(data_path, "samples")  # FIXME check that data_path exists!
        out_label_folder = os.path.join(data_path, "label")

    create_or_empty_folder(samples_folder)
    create_or_empty_folder(out_label_folder)

    number_samples = {'trn': 0, 'val': 0, 'tst': 0}
    number_classes = 0

    trn_hdf5, val_hdf5, tst_hdf5 = create_files_and_datasets(params, samples_folder)
    hdf5_by_dataset = {'trn': trn_hdf5, 'val': val_hdf5, 'tst': tst_hdf5}

    # The HDF5 files are closed in `finally` so they are released even when a sample fails.
    try:
        with tqdm(list_data_prep) as _tqdm:
            for info in _tqdm:

                if bucket_name:
                    # Download inputs and point `info` at the local copies. GeoPackage and
                    # metadata files are downloaded once and then referenced by local name.
                    bucket.download_file(info['tif'], "Images/" + info['tif'].split('/')[-1])
                    info['tif'] = "Images/" + info['tif'].split('/')[-1]
                    if info['gpkg'] not in bucket_file_cache:
                        bucket_file_cache.append(info['gpkg'])
                        bucket.download_file(info['gpkg'], info['gpkg'].split('/')[-1])
                    info['gpkg'] = info['gpkg'].split('/')[-1]
                    if info['meta']:
                        if info['meta'] not in bucket_file_cache:
                            bucket_file_cache.append(info['meta'])
                            bucket.download_file(info['meta'], info['meta'].split('/')[-1])
                        info['meta'] = info['meta'].split('/')[-1]

                _tqdm.set_postfix(
                    OrderedDict(file=f'{info["tif"]}',
                                sample_size=params['global']['samples_size']))

                # Validate the number of class in the vector file
                validate_num_classes(info['gpkg'], params['global']['num_classes'], info['attribute_name'])

                assert os.path.isfile(info['tif']), f"could not open raster file at {info['tif']}"
                with rasterio.open(info['tif'], 'r') as raster:

                    # Burn vector file in a raster file
                    np_label_raster = vector_to_raster(
                        vector_file=info['gpkg'],
                        input_image=raster,
                        attribute_name=info['attribute_name'],
                        fill=get_key_def('ignore_idx', get_key_def('training', params, {}), 0))

                    # Read the input raster image
                    np_input_image = image_reader_as_array(
                        input_image=raster,
                        scale=get_key_def('scale_data', params['global'], None),
                        aux_vector_file=get_key_def('aux_vector_file', params['global'], None),
                        aux_vector_attrib=get_key_def('aux_vector_attrib', params['global'], None),
                        aux_vector_ids=get_key_def('aux_vector_ids', params['global'], None),
                        aux_vector_dist_maps=get_key_def('aux_vector_dist_maps', params['global'], True),
                        aux_vector_dist_log=get_key_def('aux_vector_dist_log', params['global'], True),
                        aux_vector_scale=get_key_def('aux_vector_scale', params['global'], None))

                # Mask the zeros from input image into label raster.
                if params['sample']['mask_reference']:
                    np_label_raster = mask_image(np_input_image, np_label_raster)

                if info['dataset'] not in hdf5_by_dataset:
                    raise ValueError(
                        f"Dataset value must be trn or val or tst. Provided value is {info['dataset']}"
                    )
                out_file = hdf5_by_dataset[info['dataset']]

                meta_map, metadata = get_key_def("meta_map", params["global"], {}), None
                if info['meta'] is not None and isinstance(info['meta'], str) and os.path.isfile(info['meta']):
                    metadata = read_parameters(info['meta'])

                input_band_count = np_input_image.shape[2] + MetaSegmentationDataset.get_meta_layer_count(meta_map)
                assert input_band_count == params['global']['number_of_bands'], \
                    f"The number of bands in the input image ({input_band_count}) and the parameter" \
                    f"'number_of_bands' in the yaml file ({params['global']['number_of_bands']}) should be identical"

                # Give the label a trailing channel axis: (rows, cols) -> (rows, cols, 1).
                np_label_raster = np.reshape(
                    np_label_raster,
                    (np_label_raster.shape[0], np_label_raster.shape[1], 1))
                number_samples, number_classes = samples_preparation(
                    np_input_image, np_label_raster,
                    params['global']['samples_size'],
                    params['sample']['samples_dist'], number_samples,
                    number_classes, out_file, info['dataset'],
                    params['sample']['min_annotated_percent'], metadata)

                _tqdm.set_postfix(OrderedDict(number_samples=number_samples))
                out_file.flush()
    finally:
        trn_hdf5.close()
        val_hdf5.close()
        tst_hdf5.close()

    print("Number of samples created: ", number_samples)

    if bucket_name and final_samples_folder:
        print('Transfering Samples to the bucket')
        bucket.upload_file(samples_folder + "/trn_samples.hdf5",
                           final_samples_folder + '/trn_samples.hdf5')
        bucket.upload_file(samples_folder + "/val_samples.hdf5",
                           final_samples_folder + '/val_samples.hdf5')
        bucket.upload_file(samples_folder + "/tst_samples.hdf5",
                           final_samples_folder + '/tst_samples.hdf5')

    print("End of process")
Exemple #16
0
def main(params: dict) -> None:
    """
    Function to manage details about the inference on segmentation task.
    1. Read the parameters from the config given.
    2. Read and load the state dict from the previous training or the given one.
    3. Make the inference on the data specified in the config.
    -------
    :param params: (dict) Parameters found in the yaml config file.
    :raises FileNotFoundError: if no state dict directory, no 'checkpoint.pth.tar' file, no imagery
                               csv/directory or no raw data directory can be located.
    :raises ValueError: if the task is neither "classification" nor "segmentation", if a target id is not
                        an integer, or if the benchmark results are inconsistent.
    """
    # since = time.time()

    # PARAMETERS
    mode = get_key_def('mode', params, expected_type=str)
    task = get_key_def('task_name', params['task'], expected_type=str)
    model_name = get_key_def('model_name', params['model'], expected_type=str).lower()
    num_classes = len(get_key_def('classes_dict', params['dataset']).keys())
    modalities = read_modalities(get_key_def('modalities', params['dataset'], expected_type=str))
    BGR_to_RGB = get_key_def('BGR_to_RGB', params['dataset'], expected_type=bool)
    num_bands = len(modalities)
    debug = get_key_def('debug', params, default=False, expected_type=bool)
    # SETTING OUTPUT DIRECTORY
    try:
        state_dict = Path(params['inference']['state_dict_path']).resolve(strict=True)
    except FileNotFoundError:
        logging.info(
            f"\nThe state dict path directory '{params['inference']['state_dict_path']}' don't seem to be find," +
            f"we will try to locate a state dict path in the '{params['general']['save_weights_dir']}' " +
            f"specify during the training phase"
        )
        try:
            state_dict = Path(params['general']['save_weights_dir']).resolve(strict=True)
        except FileNotFoundError as e:
            # logging.critical() returns None, so it cannot be raised directly: log first, then raise.
            msg = (
                f"\nThe state dict path directory '{params['general']['save_weights_dir']}'" +
                f" don't seem to be find either, please specify the path to a state dict"
            )
            logging.critical(msg)
            raise FileNotFoundError(msg) from e
    # TODO add more detail in the parent folder
    working_folder = state_dict.parent.joinpath(f'inference_{num_bands}bands')
    logging.info("\nThe state dict path directory used '{}'".format(working_folder))
    Path.mkdir(working_folder, parents=True, exist_ok=True)

    # LOGGING PARAMETERS TODO put option not just mlflow
    experiment_name = get_key_def('project_name', params['general'], default='gdl-training')
    # Both stay None when no tracker is configured; they are used further down to decide whether
    # tracking calls can be made.
    tracker_uri = None
    run_name = None
    try:
        tracker_uri = get_key_def('uri', params['tracker'], default=None, expected_type=str)
        Path(tracker_uri).mkdir(exist_ok=True)
        run_name = get_key_def('run_name', params['tracker'], default='gdl')  # TODO change for something meaningful
        run_name = '{}_{}_{}'.format(run_name, mode, task)
        logging.info(f'\nInference and log files will be saved to: {working_folder}')
        # TODO change to fit whatever inport
        from mlflow import log_params, set_tracking_uri, set_experiment, start_run, log_artifact, log_metrics
        # tracking path + parameters logging
        set_tracking_uri(tracker_uri)
        set_experiment(experiment_name)
        start_run(run_name=run_name)
        log_params(dict_path(params, 'general'))
        log_params(dict_path(params, 'dataset'))
        log_params(dict_path(params, 'data'))
        log_params(dict_path(params, 'model'))
        log_params(dict_path(params, 'inference'))
    # meaning no logging tracker as been assigned or it doesnt exist in config/logging
    except ConfigKeyError:
        logging.info(
            "\nNo logging tracker as been assigned or the yaml config doesnt exist in 'config/tracker'."
            "\nNo tracker file will be save in that case."
        )

    # MANDATORY PARAMETERS
    img_dir_or_csv = get_key_def(
        'img_dir_or_csv_file', params['inference'], default=params['general']['raw_data_csv'], expected_type=str
    )
    if not (Path(img_dir_or_csv).is_dir() or Path(img_dir_or_csv).suffix == '.csv'):
        msg = f'\nCouldn\'t locate .csv file or directory "{img_dir_or_csv}" containing imagery for inference'
        logging.critical(msg)
        raise FileNotFoundError(msg)
    # load the checkpoint
    try:
        # Sort by modification time (mtime) descending
        sorted_by_mtime_descending = sorted(
            [os.path.join(state_dict, x) for x in os.listdir(state_dict)], key=lambda t: -os.stat(t).st_mtime
        )
        last_checkpoint_save = find_first_file('checkpoint.pth.tar', sorted_by_mtime_descending)
        if last_checkpoint_save is None:
            raise FileNotFoundError
        # change the state_dict
        state_dict = last_checkpoint_save
    except FileNotFoundError:
        logging.error(f"\nNo file name 'checkpoint.pth.tar' as been found at '{state_dict}'")
        raise

    # TODO change it next version for all task
    if task not in ['classification', 'segmentation']:
        msg = f'\nTask should be either "classification" or "segmentation". Got {task}'
        logging.critical(msg)
        raise ValueError(msg)

    # OPTIONAL PARAMETERS
    dontcare_val = get_key_def("ignore_index", params["training"], default=-1, expected_type=int)
    num_devices = get_key_def('num_gpus', params['training'], default=0, expected_type=int)
    default_max_used_ram = 25
    max_used_ram = get_key_def('max_used_ram', params['training'], default=default_max_used_ram, expected_type=int)
    max_used_perc = get_key_def('max_used_perc', params['training'], default=25, expected_type=int)
    scale = get_key_def('scale_data', params['augmentation'], default=[0, 1], expected_type=ListConfig)
    raster_to_vec = get_key_def('ras2vec', params['inference'], False) # FIXME not implemented with hydra

    # benchmark (ie when gkpgs are inputted along with imagery)
    dontcare = get_key_def("ignore_index", params["training"], -1)
    targ_ids = None  # TODO get_key_def('target_ids', params['sample'], None, expected_type=List)

    if debug:
        logging.warning(f'\nDebug mode activated. Some debug features may mobilize extra disk space and '
                        f'cause delays in execution.')

    # Assert that all items in target_ids are integers (ex.: to benchmark single-class model with multi-class labels)
    if targ_ids:
        for item in targ_ids:
            if not isinstance(item, int):
                raise ValueError(f'\nTarget id "{item}" in target_ids is {type(item)}, expected int.')

    logging.info(f'\nInferences will be saved to: {working_folder}\n\n')
    if not (0 <= max_used_ram <= 100):
        logging.warning(f'\nMax used ram parameter should be a percentage. Got {max_used_ram}. '
                        f'Will set default value of {default_max_used_ram} %')
        max_used_ram = default_max_used_ram

    # AWS
    # NOTE(review): `bucket` is never assigned anything but None in this function, so every
    # `if bucket:` branch below is currently unreachable.
    bucket = None
    bucket_file_cache = []
    bucket_name = get_key_def('bucket_name', params['AWS'])

    # list of GPU devices that are available and unused. If no GPUs, returns empty dict
    gpu_devices_dict = get_device_ids(num_devices,
                                      max_used_ram_perc=max_used_ram,
                                      max_used_perc=max_used_perc)
    if gpu_devices_dict:
        chunk_size = calc_inference_chunk_size(gpu_devices_dict=gpu_devices_dict, max_pix_per_mb_gpu=50)
        logging.info(f"\nNumber of cuda devices requested: {num_devices}. "
                     f"\nCuda devices available: {gpu_devices_dict}. "
                     f"\nUsing {list(gpu_devices_dict.keys())[0]}\n\n")
        device = torch.device(f'cuda:{list(range(len(gpu_devices_dict.keys())))[0]}')
    else:
        chunk_size = get_key_def('chunk_size', params['inference'], default=512, expected_type=int)
        logging.warning(f"\nNo Cuda device available. This process will only run on CPU")
        device = torch.device('cpu')

    # CONFIGURE MODEL
    num_classes_backgr = add_background_to_num_class(task, num_classes)
    model, loaded_checkpoint, model_name = net(model_name=model_name,
                                               num_bands=num_bands,
                                               num_channels=num_classes_backgr,
                                               dontcare_val=dontcare_val,
                                               num_devices=1,
                                               net_params=params,
                                               inference_state_dict=state_dict)
    try:
        model.to(device)
    except RuntimeError:
        logging.info(f"\nUnable to use device. Trying device 0")
        device = torch.device(f'cuda' if gpu_devices_dict else 'cpu')
        model.to(device)

    # CREATE LIST OF INPUT IMAGES FOR INFERENCE
    raw_data_dir = get_key_def('raw_data_dir', params['dataset'])
    try:
        # check if the data folder exist
        my_data_path = Path(raw_data_dir).resolve(strict=True)
        logging.info("\nImage directory used '{}'".format(my_data_path))
        data_path = Path(my_data_path)
    except FileNotFoundError as e:
        msg = "\nImage directory '{}' doesn't exist, please change the path".format(raw_data_dir)
        logging.critical(msg)
        raise FileNotFoundError(msg) from e
    list_img = list_input_images(
        img_dir_or_csv, bucket_name, glob_patterns=["*.tif", "*.TIF"], in_case_of_path=str(data_path)
    )

    # VALIDATION: anticipate problems with imagery and label (if provided) before entering main for loop
    valid_gpkg_set = set()
    for info in tqdm(list_img, desc='Validating imagery'):
        # validate_raster(info['tif'], num_bands, meta_map)
        if 'gpkg' in info.keys() and info['gpkg'] and info['gpkg'] not in valid_gpkg_set:
            validate_num_classes(vector_file=info['gpkg'],
                                 num_classes=num_classes,
                                 attribute_name=info['attribute_name'],
                                 ignore_index=dontcare,
                                 target_ids=targ_ids)
            assert_crs_match(info['tif'], info['gpkg'])
            valid_gpkg_set.add(info['gpkg'])

    logging.info('\nSuccessfully validated imagery')
    if valid_gpkg_set:
        logging.info('\nSuccessfully validated label data for benchmarking')

    if task == 'classification':
        classifier(params, list_img, model, device,
                   working_folder)  # FIXME: why don't we load from checkpoint in classification?

    elif task == 'segmentation':
        gdf_ = []
        gpkg_name_ = []

        # TODO: Add verifications?
        if bucket:
            bucket.download_file(loaded_checkpoint, "saved_model.pth.tar")  # TODO: is this still valid?
            model, _ = load_from_checkpoint("saved_model.pth.tar", model)
        else:
            model, _ = load_from_checkpoint(loaded_checkpoint, model)

        # Save tracking TODO put option not just mlflow
        # NOTE(review): a run was already started in the tracker setup above; this starts another one
        # and appends mode/task to run_name a second time — confirm this is intended.
        if tracker_uri is not None and run_name is not None:
            mode = get_key_def('mode', params, expected_type=str)
            task = get_key_def('task_name', params['task'], expected_type=str)
            run_name = '{}_{}_{}'.format(run_name, mode, task)
            # tracking path + parameters logging
            set_tracking_uri(tracker_uri)
            set_experiment(experiment_name)
            start_run(run_name=run_name)
            log_params(dict_path(params, 'inference'))
            log_params(dict_path(params, 'dataset'))
            log_params(dict_path(params, 'model'))

        # LOOP THROUGH LIST OF INPUT IMAGES
        for info in tqdm(list_img, desc='Inferring from images', position=0, leave=True):
            img_name = Path(info['tif']).name
            local_gpkg = Path(info['gpkg']) if 'gpkg' in info.keys() and info['gpkg'] else None
            gpkg_name = local_gpkg.stem if local_gpkg else None
            if bucket:
                local_img = f"Images/{img_name}"
                bucket.download_file(info['tif'], local_img)
                inference_image = f"Classified_Images/{img_name.split('.')[0]}_inference.tif"
                if info['meta']:
                    if info['meta'] not in bucket_file_cache:
                        bucket_file_cache.append(info['meta'])
                        bucket.download_file(info['meta'], info['meta'].split('/')[-1])
                    info['meta'] = info['meta'].split('/')[-1]
            else:  # FIXME: else statement should support img['meta'] integration as well.
                local_img = Path(info['tif'])
                Path.mkdir(working_folder.joinpath(local_img.parent.name), parents=True, exist_ok=True)
                inference_image = working_folder.joinpath(local_img.parent.name,
                                                          f"{img_name.split('.')[0]}_inference.tif")
            temp_file = working_folder.joinpath(local_img.parent.name, f"{img_name.split('.')[0]}.dat")
            raster = rasterio.open(local_img, 'r')
            # Make sure the dataset handle is released once the prediction is done (or fails),
            # instead of leaking one open raster per image.
            try:
                logging.info(f'\nReading original image: {raster.name}')
                inf_meta = raster.meta
                label = None
                if local_gpkg:
                    logging.info(f'\nBurning label as raster: {local_gpkg}')
                    local_img = clip_raster_with_gpkg(raster, local_gpkg)
                    raster.close()
                    raster = rasterio.open(local_img, 'r')
                    logging.info(f'\nReading clipped image: {raster.name}')
                    inf_meta = raster.meta
                    label = vector_to_raster(vector_file=local_gpkg,
                                             input_image=raster,
                                             out_shape=(inf_meta['height'], inf_meta['width']),
                                             attribute_name=info['attribute_name'],
                                             fill=0,  # background value in rasterized vector.
                                             target_ids=targ_ids)
                    if debug:
                        logging.debug(f'\nUnique values in loaded label as raster: {np.unique(label)}\n'
                                      f'Shape of label as raster: {label.shape}')
                pred, gdf = segmentation(param=params,
                                         input_image=raster,
                                         label_arr=label,
                                         num_classes=num_classes_backgr,
                                         gpkg_name=gpkg_name,
                                         model=model,
                                         chunk_size=chunk_size,
                                         device=device,
                                         scale=scale,
                                         BGR_to_RGB=BGR_to_RGB,
                                         tp_mem=temp_file,
                                         debug=debug)
            finally:
                raster.close()
            if gdf is not None:
                gdf_.append(gdf)
                gpkg_name_.append(gpkg_name)
            if local_gpkg and tracker_uri is not None:
                pixelMetrics = ComputePixelMetrics(label, pred, num_classes_backgr)
                log_metrics(pixelMetrics.update(pixelMetrics.iou))
                log_metrics(pixelMetrics.update(pixelMetrics.dice))
            # Add a leading band axis so the prediction can be written as a single-band raster.
            pred = pred[np.newaxis, :, :].astype(np.uint8)
            inf_meta.update({"driver": "GTiff",
                             "height": pred.shape[1],
                             "width": pred.shape[2],
                             "count": pred.shape[0],
                             "dtype": 'uint8',
                             "compress": 'lzw'})
            logging.info(f'\nSuccessfully inferred on {img_name}\nWriting to file: {inference_image}')
            with rasterio.open(inference_image, 'w+', **inf_meta) as dest:
                dest.write(pred)
            del pred
            try:
                temp_file.unlink()
            except OSError as e:
                logging.warning(f'File Error: {temp_file, e.strerror}')
            if raster_to_vec:
                start_vec = time.time()
                inference_vec = working_folder.joinpath(local_img.parent.name,
                                                        f"{img_name.split('.')[0]}_inference.gpkg")
                ras2vec(inference_image, inference_vec)
                end_vec = time.time() - start_vec
                logging.info('Vectorization completed in {:.0f}m {:.0f}s'.format(end_vec // 60, end_vec % 60))

        if len(gdf_) >= 1:
            if not len(gdf_) == len(gpkg_name_):
                msg = '\nbenchmarking unable to complete'
                logging.critical(msg)
                raise ValueError(msg)
            all_gdf = pd.concat(gdf_)  # Concatenate all geo data frame into one geo data frame
            all_gdf.reset_index(drop=True, inplace=True)
            gdf_x = gpd.GeoDataFrame(all_gdf)
            bench_gpkg = working_folder / "benchmark.gpkg"
            gdf_x.to_file(bench_gpkg, driver="GPKG", index=False)
            logging.info(f'\nSuccessfully wrote benchmark geopackage to: {bench_gpkg}')
        # log_artifact(working_folder)
def main(params, config_path):
    """
    Function to train and validate a model for semantic segmentation.

    Process
    -------
    1. Model is instantiated and checkpoint is loaded from path, if provided in
       `your_config.yaml`.
    2. GPUs are requested according to desired amount of `num_gpus` and
       available GPUs.
    3. If more than 1 GPU is requested, model is cast to DataParallel model
    4. Dataloaders are created with `create_dataloader()`
    5. Loss criterion, optimizer and learning rate are set with
       `set_hyperparameters()` as requested in `config.yaml`.
    6. Using these hyperparameters, the application will try to minimize the
       loss on the training data and evaluate every epoch on the validation
       data.
    7. For every epoch, the application shows and logs the loss on "trn" and
       "val" datasets.
    8. For every epoch (if `batch_metrics: 1`), the application shows and logs
       the accuracy, recall and f-score on "val" dataset. Those metrics are
       also computed on each class.
    9. At the end of the training process, the application shows and logs the
       accuracy, recall and f-score on "tst" dataset. Those metrics are also
       computed on each class.

    -------
    :param params: (dict) Parameters found in the yaml config file.
    :param config_path: (str) Path to the yaml config file.
    """
    params['global']['git_hash'] = get_git_hash()
    debug = get_key_def('debug_mode', params['global'], False)
    if debug:
        warnings.warn(
            f'Debug mode activated. Some debug features may mobilize extra disk space and cause delays in execution.'
        )

    now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    num_classes = params['global']['num_classes']
    task = params['global']['task']
    assert task == 'segmentation', f"The task should be segmentation. The provided value is {task}"
    num_classes_corrected = num_classes + 1  # + 1 for background # FIXME temporary patch for num_classes problem.

    data_path = Path(params['global']['data_path'])
    assert data_path.is_dir(), f'Could not locate data path {data_path}'
    samples_size = params["global"]["samples_size"]
    overlap = params["sample"]["overlap"]
    min_annot_perc = get_key_def('min_annotated_percent',
                                 params['sample']['sampling_method'],
                                 0,
                                 expected_type=int)
    num_bands = params['global']['number_of_bands']
    samples_folder_name = f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands'  # FIXME: won't check if folder has datetime suffix (if multiple folders)
    samples_folder = data_path.joinpath(samples_folder_name)
    batch_size = params['training']['batch_size']
    num_devices = params['global']['num_gpus']
    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    lst_device_ids = get_device_ids(
        num_devices) if torch.cuda.is_available() else []
    num_devices = len(lst_device_ids) if lst_device_ids else 0
    device = torch.device(f'cuda:{lst_device_ids[0]}' if torch.cuda.
                          is_available() and lst_device_ids else 'cpu')

    tqdm.write(f'Creating dataloaders from data in {samples_folder}...\n')
    trn_dataloader, val_dataloader, tst_dataloader = create_dataloader(
        samples_folder=samples_folder,
        batch_size=batch_size,
        num_devices=num_devices,
        params=params)
    # INSTANTIATE MODEL AND LOAD CHECKPOINT FROM PATH
    model, model_name, criterion, optimizer, lr_scheduler = net(
        params,
        num_classes_corrected)  # pretrained could become a yaml parameter.
    tqdm.write(
        f'Instantiated {model_name} model with {num_classes_corrected} output channels.\n'
    )
    bucket_name = get_key_def('bucket_name', params['global'])

    # mlflow tracking path + parameters logging
    set_tracking_uri(
        get_key_def('mlflow_uri', params['global'], default="./mlruns"))
    set_experiment('gdl-training')
    log_params(params['training'])
    log_params(params['global'])
    log_params(params['sample'])

    modelname = config_path.stem
    output_path = samples_folder.joinpath('model') / modelname
    if output_path.is_dir():
        output_path = output_path.joinpath(f"_{now}")
    output_path.mkdir(parents=True, exist_ok=False)
    shutil.copy(str(config_path), str(output_path))
    tqdm.write(f'Model and log files will be saved to: {output_path}\n\n')

    if bucket_name:
        from utils.aws import download_s3_files
        bucket, bucket_output_path, output_path, data_path = download_s3_files(
            bucket_name=bucket_name,
            data_path=data_path,
            output_path=output_path)

    since = time.time()
    # Start from +inf so the first epoch always produces a checkpoint, whatever the loss scale is;
    # the checkpoint file is reloaded after training for the test evaluation.
    best_loss = float('inf')
    last_vis_epoch = 0

    progress_log = output_path / 'progress.log'
    if not progress_log.exists():
        # Write the header through a context manager so the file handle is closed right away.
        with progress_log.open('w', buffering=1) as log_file:
            log_file.write(
                tsv_line('ep_idx', 'phase', 'iter', 'i_p_ep',
                         'time'))  # Add header

    trn_log = InformationLogger('trn')
    val_log = InformationLogger('val')
    tst_log = InformationLogger('tst')
    filename = output_path.joinpath('checkpoint.pth.tar')

    # VISUALIZATION: generate pngs of inputs, labels and outputs
    vis_batch_range = get_key_def('vis_batch_range', params['visualization'],
                                  None)
    if vis_batch_range is not None:
        # Make sure user-provided range is a list with 3 integers (start, finish, increment). Check once for all visualization tasks.
        assert isinstance(vis_batch_range,
                          list) and len(vis_batch_range) == 3 and all(
                              isinstance(x, int) for x in vis_batch_range)
        vis_at_init_dataset = get_key_def('vis_at_init_dataset',
                                          params['visualization'], 'val')

        # Visualization at initialization. Visualize batch range before first epoch.
        if get_key_def('vis_at_init', params['visualization'], False):
            tqdm.write(
                f'Visualizing initialized model on batch range {vis_batch_range} from {vis_at_init_dataset} dataset...\n'
            )
            vis_from_dataloader(
                params=params,
                eval_loader=val_dataloader
                if vis_at_init_dataset == 'val' else tst_dataloader,
                model=model,
                ep_num=0,
                output_path=output_path,
                dataset=vis_at_init_dataset,
                device=device,
                vis_batch_range=vis_batch_range)

    for epoch in range(0, params['training']['num_epochs']):
        print(
            f'\nEpoch {epoch}/{params["training"]["num_epochs"] - 1}\n{"-" * 20}'
        )

        trn_report = train(train_loader=trn_dataloader,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           scheduler=lr_scheduler,
                           num_classes=num_classes_corrected,
                           batch_size=batch_size,
                           ep_idx=epoch,
                           progress_log=progress_log,
                           vis_params=params,
                           device=device,
                           debug=debug)
        trn_log.add_values(trn_report,
                           epoch,
                           ignore=['precision', 'recall', 'fscore', 'iou'])

        val_report = evaluation(
            eval_loader=val_dataloader,
            model=model,
            criterion=criterion,
            num_classes=num_classes_corrected,
            batch_size=batch_size,
            ep_idx=epoch,
            progress_log=progress_log,
            vis_params=params,
            batch_metrics=params['training']['batch_metrics'],
            dataset='val',
            device=device,
            debug=debug)
        val_loss = val_report['loss'].avg
        if params['training']['batch_metrics'] is not None:
            val_log.add_values(val_report, epoch)
        else:
            val_log.add_values(val_report,
                               epoch,
                               ignore=['precision', 'recall', 'fscore', 'iou'])

        # Checkpoint only when the validation loss improves.
        if val_loss < best_loss:
            tqdm.write("save checkpoint\n")
            best_loss = val_loss
            # More info: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-torch-nn-dataparallel-models
            state_dict = model.module.state_dict(
            ) if num_devices > 1 else model.state_dict()
            torch.save(
                {
                    'epoch': epoch,
                    'params': params,
                    'model': state_dict,
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict()
                }, filename)
            if epoch == 0:
                log_artifact(filename)
            if bucket_name:
                bucket_filename = bucket_output_path.joinpath(
                    'checkpoint.pth.tar')
                bucket.upload_file(filename, bucket_filename)

            # VISUALIZATION: generate png of test samples, labels and outputs for visualisation to follow training performance
            vis_at_checkpoint = get_key_def('vis_at_checkpoint',
                                            params['visualization'], False)
            ep_vis_min_thresh = get_key_def('vis_at_ckpt_min_ep_diff',
                                            params['visualization'], 4)
            vis_at_ckpt_dataset = get_key_def('vis_at_ckpt_dataset',
                                              params['visualization'], 'val')
            if vis_batch_range is not None and vis_at_checkpoint and epoch - last_vis_epoch >= ep_vis_min_thresh:
                if last_vis_epoch == 0:
                    tqdm.write(
                        f'Visualizing with {vis_at_ckpt_dataset} dataset samples on checkpointed model for'
                        f'batches in range {vis_batch_range}')
                vis_from_dataloader(
                    params=params,
                    eval_loader=val_dataloader
                    if vis_at_ckpt_dataset == 'val' else tst_dataloader,
                    model=model,
                    ep_num=epoch + 1,
                    output_path=output_path,
                    dataset=vis_at_ckpt_dataset,
                    device=device,
                    vis_batch_range=vis_batch_range)
                last_vis_epoch = epoch

        if bucket_name:
            save_logs_to_bucket(bucket, bucket_output_path, output_path, now,
                                params['training']['batch_metrics'])

        cur_elapsed = time.time() - since
        print(
            f'Current elapsed time {cur_elapsed // 60:.0f}m {cur_elapsed % 60:.0f}s'
        )

    # load checkpoint model and evaluate it on test dataset.
    if int(
            params['training']['num_epochs']
    ) > 0:  # if num_epochs is set to 0, model is loaded to evaluate on test set
        checkpoint = load_checkpoint(filename)
        model, _ = load_from_checkpoint(checkpoint, model)

    if tst_dataloader:
        tst_report = evaluation(
            eval_loader=tst_dataloader,
            model=model,
            criterion=criterion,
            num_classes=num_classes_corrected,
            batch_size=batch_size,
            ep_idx=params['training']['num_epochs'],
            progress_log=progress_log,
            vis_params=params,
            batch_metrics=params['training']['batch_metrics'],
            dataset='tst',
            device=device)
        tst_log.add_values(tst_report, params['training']['num_epochs'])

        if bucket_name:
            bucket_filename = bucket_output_path.joinpath('last_epoch.pth.tar')
            bucket.upload_file(
                "output.txt",
                bucket_output_path.joinpath(f"Logs/{now}_output.txt"))
            bucket.upload_file(filename, bucket_filename)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
def evaluation(eval_loader, model, criterion, num_classes, batch_size, ep_idx, progress_log, batch_metrics=None, dataset='val', device=None, debug=False):
    """
    Evaluate the model and return the updated metrics
    :param eval_loader: data loader
    :param model: model to evaluate
    :param criterion: loss criterion
    :param num_classes: number of classes
    :param batch_size: number of samples to process simultaneously
    :param ep_idx: epoch index (for hypertrainer log)
    :param progress_log: progress log file (for hypertrainer log)
    :param batch_metrics: (int) Metrics computed every (int) batches. If left blank, will not perform metrics.
    :param dataset: (str) 'val or 'tst'
    :param device: device used by pytorch (cpu ou cuda)
    :param debug: (bool) if True and the device is cuda, GPU usage is shown in the progress bar
    :return: (dict) eval_metrics

    NOTE(review): `params` is not an argument of this function; it is read from the enclosing/module
    scope when classification metrics are computed — confirm it is defined there.
    """
    eval_metrics = create_metrics_dict(num_classes)
    model.eval()

    with tqdm(eval_loader, dynamic_ncols=True, desc=f'Iterating {dataset} batches with {device.type}') as _tqdm:
        for batch_index, data in enumerate(_tqdm):
            # Append one progress line per batch; the context manager closes the handle each time
            # instead of leaking one open file per batch.
            with progress_log.open('a', buffering=1) as log_file:
                log_file.write(tsv_line(ep_idx, dataset, batch_index, len(eval_loader), time.time()))

            with torch.no_grad():
                inputs, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)

                loss = criterion(outputs, labels)

                eval_metrics['loss'].update(loss.item(), batch_size)

                if (dataset == 'val') and (batch_metrics is not None):
                    # Compute metrics every n batches. Time consuming.
                    assert batch_metrics <= len(_tqdm), \
                        f"Batch_metrics ({batch_metrics}) is larger than the number of batches " \
                        f"({len(_tqdm)}). Metrics in validation loop won't be computed"
                    if (batch_index+1) % batch_metrics == 0:   # +1 to skip val loop at very beginning
                        # Index of the max along dim 1 is used as the predicted class.
                        _, segmentation = torch.max(outputs, dim=1)
                        eval_metrics = report_classification(segmentation, labels, batch_size, eval_metrics,
                                                             ignore_index=get_key_def("ignore_index", params["training"], None))
                elif dataset == 'tst':
                    _, segmentation = torch.max(outputs, dim=1)
                    eval_metrics = report_classification(segmentation, labels, batch_size, eval_metrics,
                                                         ignore_index=get_key_def("ignore_index", params["training"], None))

                _tqdm.set_postfix(OrderedDict(dataset=dataset, loss=f'{eval_metrics["loss"].avg:.4f}'))

                if debug and device.type == 'cuda':
                    res, mem = gpu_stats(device=device.index)
                    _tqdm.set_postfix(OrderedDict(device=device, gpu_perc=f'{res.gpu} %',
                                                  gpu_RAM=f'{mem.used/(1024**2):.0f}/{mem.total/(1024**2):.0f} MiB'))

    logging.info(f"{dataset} Loss: {eval_metrics['loss'].avg}")
    if batch_metrics is not None:
        logging.info(f"{dataset} precision: {eval_metrics['precision'].avg}")
        logging.info(f"{dataset} recall: {eval_metrics['recall'].avg}")
        logging.info(f"{dataset} fscore: {eval_metrics['fscore'].avg}")

    return eval_metrics
Exemple #19
0
def net(net_params, num_channels, inference=False):
    """
    Define the neural net and locate the checkpoint to load, if any.

    :param net_params: (dict) Parameters found in the yaml config file. The 'global' and 'training' sections are
                       always read; the 'inference' section is read only when `inference` is True.
    :param num_channels: (int) number of output channels of the model (passed as `num_classes` to torchvision
                         segmentation models)
    :param inference: (bool) if True, the checkpoint at net_params['inference']['state_dict_path'] is loaded and
                      no pretrained checkpoint is retrieved
    :return: model, checkpoint (None when there is nothing to load), model_name (lowercased)
    :raises ValueError: if the model name from the config is not one of the supported models
    """
    model_name = net_params['global']['model_name'].lower()
    num_bands = int(net_params['global']['number_of_bands'])
    msg = 'Number of bands specified incompatible with this model. Requires 3 band data.'
    train_state_dict_path = get_key_def('state_dict_path', net_params['training'], None)
    pretrained = get_key_def('pretrained', net_params['training'], True) if not inference else False
    dropout = get_key_def('dropout', net_params['training'], False)
    dropout_prob = get_key_def('dropout_prob', net_params['training'], 0.5)

    if model_name == 'unetsmall':
        model = unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'unet':
        model = unet.UNet(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'ternausnet':
        assert num_bands == 3, msg
        model = TernausNet.ternausnet(num_channels)
    elif model_name == 'checkpointed_unet':
        model = checkpointed_unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'inception':
        model = inception.Inception3(num_channels, num_bands)
    elif model_name == 'fcn_resnet101':
        assert num_bands == 3, msg
        model = models.segmentation.fcn_resnet101(pretrained=False, progress=True, num_classes=num_channels,
                                                  aux_loss=None)
    elif model_name == 'deeplabv3_resnet101':
        try:
            # `in_channels` is only accepted by a locally edited torchvision (see assert message below).
            model = models.segmentation.deeplabv3_resnet101(pretrained=False, progress=True, in_channels=num_bands,
                                                            num_classes=num_channels, aux_loss=None)
        except TypeError:
            # Unexpected keyword argument: fall back to the stock 3-band constructor.
            assert num_bands == 3, 'Edit torchvision scripts segmentation.py and resnet.py to build deeplabv3_resnet ' \
                                   'with more or less than 3 bands'
            model = models.segmentation.deeplabv3_resnet101(pretrained=False, progress=True,
                                                            num_classes=num_channels, aux_loss=None)
    else:
        raise ValueError(f'The model name {model_name} in the config.yaml is not defined.')

    coordconv_convert = get_key_def('coordconv_convert', net_params['global'], False)
    if coordconv_convert:
        centered = get_key_def('coordconv_centered', net_params['global'], True)
        normalized = get_key_def('coordconv_normalized', net_params['global'], True)
        noise = get_key_def('coordconv_noise', net_params['global'], None)
        radius_channel = get_key_def('coordconv_radius_channel', net_params['global'], False)
        scale = get_key_def('coordconv_scale', net_params['global'], 1.0)
        # note: this operation will not attempt to preserve already-loaded model parameters!
        model = coordconv.swap_coordconv_layers(model, centered=centered, normalized=normalized, noise=noise,
                                                radius_channel=radius_channel, scale=scale)

    # Checkpoint priority: inference checkpoint > training checkpoint > coco pretrained weights > nothing.
    if inference:
        state_dict_path = net_params['inference']['state_dict_path']
        assert Path(state_dict_path).is_file(), f"Could not locate {state_dict_path}"
        checkpoint = load_checkpoint(state_dict_path)
    elif train_state_dict_path is not None:
        assert Path(train_state_dict_path).is_file(), f'Could not locate checkpoint at {train_state_dict_path}'
        checkpoint = load_checkpoint(train_state_dict_path)
    elif pretrained and model_name in ('deeplabv3_resnet101', 'fcn_resnet101'):
        # Membership test: `name == ('a' or 'b')` only ever compares against 'a'.
        print(f'Retrieving coco checkpoint for {model_name}...\n')
        if model_name == 'deeplabv3_resnet101':  # default to pretrained on coco (21 classes)
            coco_model = models.segmentation.deeplabv3_resnet101(pretrained=True, progress=True, num_classes=21,
                                                                 aux_loss=None)
        else:
            coco_model = models.segmentation.fcn_resnet101(pretrained=True, progress=True, num_classes=21,
                                                           aux_loss=None)
        # Place entire state_dict inside 'model' key for compatibility with the rest of GDL workflow
        checkpoint = {'model': dict(coco_model.state_dict())}
        del coco_model
    elif pretrained:
        warnings.warn(f'No pretrained checkpoint found for {model_name}.')
        checkpoint = None
    else:
        checkpoint = None

    return model, checkpoint, model_name
Exemple #20
0
def train(train_loader,
          model,
          criterion,
          optimizer,
          scheduler,
          num_classes,
          batch_size,
          task,
          ep_idx,
          progress_log,
          vis_params,
          device,
          debug=False):
    """
    Train the model and return the metrics of the training epoch
    :param train_loader: training data loader
    :param model: model to train
    :param criterion: loss criterion
    :param optimizer: optimizer to use
    :param scheduler: learning rate scheduler (stepped once, at the end of the epoch)
    :param num_classes: number of classes
    :param batch_size: number of samples to process simultaneously
    :param task: segmentation or classification (not used inside this function)
    :param ep_idx: epoch index (for hypertrainer log)
    :param progress_log: progress log file (for hypertrainer log)
    :param vis_params: (dict) Parameters found in the yaml config file. Named vis_params because they are only used for
                        visualization functions.
    :param device: device used by pytorch (cpu or cuda)
    :param debug: (bool) if True and running on cuda, GPU statistics are shown in the progress bar
    :return: Updated training loss
    """
    model.train()
    train_metrics = create_metrics_dict(num_classes)
    vis_at_train = get_key_def('vis_at_train', vis_params['visualization'], False)
    vis_batch_range = get_key_def('vis_batch_range', vis_params['visualization'], None)
    # 'vis_batch_range' is optional: only unpack it when provided, otherwise no batch is visualized.
    if vis_batch_range is not None:
        min_vis_batch, max_vis_batch, increment = vis_batch_range
        vis_batches = range(min_vis_batch, max_vis_batch, increment)
    else:
        vis_batches = range(0)

    with tqdm(train_loader, desc=f'Iterating train batches with {device.type}') as _tqdm:
        for batch_index, data in enumerate(_tqdm):
            # Close the log after each line so file handles do not pile up over the epoch.
            with progress_log.open('a', buffering=1) as log_file:
                log_file.write(tsv_line(ep_idx, 'trn', batch_index, len(train_loader), time.time()))

            inputs = data['sat_img'].to(device)
            labels = data['map_img'].to(device)

            # forward
            optimizer.zero_grad()
            outputs = model(inputs)
            # added for torchvision models that output an OrderedDict with outputs in 'out' key.
            # More info: https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/
            if isinstance(outputs, OrderedDict):
                outputs = outputs['out']

            if vis_at_train and batch_index in vis_batches:
                vis_path = progress_log.parent.joinpath('visualization')
                if ep_idx == 0:
                    tqdm.write(
                        f'Visualizing on train outputs for batches in range {vis_batch_range}. All images will be saved to {vis_path}\n'
                    )
                # vis_params is the full config dict (see docstring); do not rely on a global `params`.
                vis_from_batch(vis_params,
                               inputs,
                               outputs,
                               batch_index=batch_index,
                               vis_path=vis_path,
                               labels=labels,
                               dataset='trn',
                               ep_num=ep_idx + 1)

            loss = criterion(outputs, labels)

            train_metrics['loss'].update(loss.item(), batch_size)

            if device.type == 'cuda' and debug:
                res, mem = gpu_stats(device=device.index)
                _tqdm.set_postfix(
                    OrderedDict(
                        trn_loss=f'{train_metrics["loss"].val:.2f}',
                        gpu_perc=f'{res.gpu} %',
                        gpu_RAM=f'{mem.used / (1024 ** 2):.0f}/{mem.total / (1024 ** 2):.0f} MiB',
                        lr=optimizer.param_groups[0]['lr'],
                        img=data['sat_img'].numpy().shape[1:],
                        smpl=data['map_img'].numpy().shape,
                        bs=batch_size,
                        out_vals=np.unique(outputs[0].argmax(dim=0).detach().cpu().numpy())))

            loss.backward()
            optimizer.step()

    scheduler.step()
    print(f'Training Loss: {train_metrics["loss"].avg:.4f}')
    return train_metrics
def net(net_params, num_channels, inference=False):
    """
    Define the neural net and locate the checkpoint to load, if any.

    :param net_params: (dict) Parameters found in the yaml config file. The 'global' and 'training' sections are
                       always read; the 'inference' section is read only when `inference` is True.
                       The optional 'concatenate_depth' key of the 'global' section is required to build
                       deeplabv3_resnet101 with 4 bands.
    :param num_channels: (int) number of output channels of the model
    :param inference: (bool) if True, the checkpoint at net_params['inference']['state_dict_path'] is loaded and
                      torchvision models are built without pretrained weights
    :return: model, checkpoint (None when there is nothing to load), model_name (lowercased)
    :raises ValueError: if the model name is not defined, or if deeplabv3_resnet101 is requested with 4 bands
                        without a 'concatenate_depth' value
    """
    model_name = net_params['global']['model_name'].lower()
    num_bands = int(net_params['global']['number_of_bands'])
    msg = 'Number of bands specified incompatible with this model. Requires 3 band data.'
    train_state_dict_path = get_key_def('state_dict_path', net_params['training'], None)
    pretrained = get_key_def('pretrained', net_params['training'], True) if not inference else False
    dropout = get_key_def('dropout', net_params['training'], False)
    dropout_prob = get_key_def('dropout_prob', net_params['training'], 0.5)

    # TODO: find a way to maybe implement it in classification one day
    # Read the concatenation point; stays None when the key is absent from the config.
    conc_point = None
    if 'concatenate_depth' in net_params['global']:
        conc_point = net_params['global']['concatenate_depth']

    if model_name == 'unetsmall':
        model = unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'unet':
        model = unet.UNet(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'ternausnet':
        assert num_bands == 3, msg
        model = TernausNet.ternausnet(num_channels)
    elif model_name == 'checkpointed_unet':
        model = checkpointed_unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'inception':
        model = inception.Inception3(num_channels, num_bands)
    elif model_name == 'fcn_resnet101':
        assert num_bands == 3, msg
        model = models.segmentation.fcn_resnet101(pretrained=pretrained,
                                                  progress=True,
                                                  num_classes=num_channels,
                                                  aux_loss=None)
    elif model_name == 'deeplabv3_resnet101':
        assert num_bands in (3, 4), \
            'Number of bands specified incompatible with this model. Requires 3 or 4 band data.'
        if num_bands == 4 and conc_point is None:
            raise ValueError("'concatenate_depth' must be set in the 'global' section of the config to build "
                             "deeplabv3_resnet101 with 4 bands.")

        print(f'Finetuning pretrained deeplabv3 with {num_bands} bands')
        if num_bands == 4:
            print('Testing with 4 bands, concatenating at {}.'.format(conc_point))

        model = models.segmentation.deeplabv3_resnet101(pretrained=pretrained, progress=True)

        # Swap the last classifier layer for a 1x1 convolution with the requested number of output channels.
        classifier = list(model.classifier.children())
        model.classifier = nn.Sequential(*classifier[:-1])
        model.classifier.add_module('4',
                                    nn.Conv2d(classifier[-1].in_channels,
                                              num_channels,
                                              kernel_size=(1, 1)))

        if num_bands == 4:
            if conc_point == 'baseline':
                # Widen the first convolution to 4 input channels.
                conv1 = model.backbone._modules['conv1'].weight.detach().numpy()
                depth = np.expand_dims(conv1[:, 1, ...], axis=1)  # reuse green weights for infrared.
                conv1 = np.append(conv1, depth, axis=1)
                conv1 = torch.from_numpy(conv1).float()
                model.backbone._modules['conv1'].weight = nn.Parameter(conv1, requires_grad=True)
            else:
                model = LayersEnsemble(model, conc_point=conc_point)

    elif model_name in lm_smp:
        lsmp = lm_smp[model_name]
        # TODO: add possibility of our own weights
        # Encoder weights are selected from the model name itself ('pretrained' token between underscores).
        lsmp['params']['encoder_weights'] = "imagenet" if 'pretrained' in model_name.split("_") else None
        lsmp['params']['in_channels'] = num_bands
        lsmp['params']['classes'] = num_channels
        lsmp['params']['activation'] = None

        model = lsmp['fct'](**lsmp['params'])

    else:
        raise ValueError(f'The model name {model_name} in the config.yaml is not defined.')

    coordconv_convert = get_key_def('coordconv_convert', net_params['global'], False)
    if coordconv_convert:
        centered = get_key_def('coordconv_centered', net_params['global'], True)
        normalized = get_key_def('coordconv_normalized', net_params['global'], True)
        noise = get_key_def('coordconv_noise', net_params['global'], None)
        radius_channel = get_key_def('coordconv_radius_channel', net_params['global'], False)
        scale = get_key_def('coordconv_scale', net_params['global'], 1.0)
        # note: this operation will not attempt to preserve already-loaded model parameters!
        model = coordconv.swap_coordconv_layers(model,
                                                centered=centered,
                                                normalized=normalized,
                                                noise=noise,
                                                radius_channel=radius_channel,
                                                scale=scale)

    # Checkpoint priority: inference checkpoint > training checkpoint > nothing.
    if inference:
        state_dict_path = net_params['inference']['state_dict_path']
        assert Path(state_dict_path).is_file(), f"Could not locate {state_dict_path}"
        checkpoint = load_checkpoint(state_dict_path)
    elif train_state_dict_path is not None:
        assert Path(train_state_dict_path).is_file(), f'Could not locate checkpoint at {train_state_dict_path}'
        checkpoint = load_checkpoint(train_state_dict_path)
    else:
        checkpoint = None

    return model, checkpoint, model_name
def main(params):
    """
    Training and validation datasets preparation.

    Process
    -------
    1. The csv file listing the input rasters and vector files is read (downloaded first if a bucket is used).
    2. A new samples folder is created; a timestamp suffix is added if the folder already exists.
    3. For each row of the csv, the vector file is burned to a raster, the image is read and samples are
       written to the trn/val/tst hdf5 files. A row that fails is reported with a warning and skipped.
    4. The proportion of pixels per class and the number of samples are printed.
    5. If a bucket is used, the hdf5 files are uploaded to it.

    :param params: (dict) Parameters found in the yaml config file.
    """
    now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    bucket_file_cache = []

    assert params['global']['task'] == 'segmentation', \
        "images_to_samples.py isn't necessary when performing classification tasks"

    # SET BASIC VARIABLES AND PATHS. CREATE OUTPUT FOLDERS.
    bucket_name = params['global']['bucket_name']
    data_path = Path(params['global']['data_path'])
    Path.mkdir(data_path, exist_ok=True, parents=True)
    csv_file = params['sample']['prep_csv_file']
    val_percent = params['sample']['val_percent']
    samples_size = params["global"]["samples_size"]
    overlap = params["sample"]["overlap"]
    min_annot_perc = params['sample']['sampling']['map']
    num_bands = params['global']['number_of_bands']
    debug = get_key_def('debug_mode', params['global'], False)
    if debug:
        warnings.warn('Debug mode activate. Execution may take longer...')

    # TODO: validate this is preferred name structure
    samples_dir_name = f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands'

    final_samples_folder = None
    if bucket_name:
        s3 = boto3.resource('s3')
        bucket = s3.Bucket(bucket_name)
        bucket.download_file(csv_file, 'samples_prep.csv')
        list_data_prep = read_csv('samples_prep.csv')
        if data_path:
            final_samples_folder = os.path.join(data_path, "samples")
        else:
            final_samples_folder = "samples"
        # Samples are written locally (relative to the working directory) before being uploaded.
        # Kept as a Path so the directory checks below work in both branches.
        samples_folder = Path(samples_dir_name)
    else:
        list_data_prep = read_csv(csv_file)
        samples_folder = data_path.joinpath(samples_dir_name)

    if samples_folder.is_dir():
        warnings.warn(
            f'Data path exists: {samples_folder}. Suffix will be added to directory name.'
        )
        samples_folder = Path(str(samples_folder) + '_' + now)
    else:
        tqdm.write(f'Writing samples to {samples_folder}')
    samples_folder.mkdir(exist_ok=False)  # FIXME: what if we want to append samples to existing hdf5?
    tqdm.write(f'Samples will be written to {samples_folder}\n\n')

    tqdm.write(
        f'\nSuccessfully read csv file: {Path(csv_file).stem}\nNumber of rows: {len(list_data_prep)}\nCopying first entry:\n{list_data_prep[0]}\n'
    )
    ignore_index = get_key_def('ignore_index', params['training'], -1)

    for info in tqdm(list_data_prep,
                     position=0,
                     desc='Asserting existence of tif and gpkg files in csv'):
        assert Path(info['tif']).is_file(
        ), f'Could not locate "{info["tif"]}". Make sure file exists in this directory.'
        assert Path(info['gpkg']).is_file(
        ), f'Could not locate "{info["gpkg"]}". Make sure file exists in this directory.'
    if debug:
        # NOTE: the description below is built once, from the last row visited by the loop above.
        for info in tqdm(
                list_data_prep,
                position=0,
                desc=f"Validating presence of {params['global']['num_classes']} "
                f"classes in attribute \"{info['attribute_name']}\" for vector "
                f"file \"{Path(info['gpkg']).stem}\""):
            validate_num_classes(info['gpkg'], params['global']['num_classes'],
                                 info['attribute_name'], ignore_index)
        with tqdm(list_data_prep,
                  position=0,
                  desc="Checking validity of features in vector files"
                  ) as _tqdm:
            invalid_features = {}
            for info in _tqdm:
                # Extract vector features to burn in the raster image
                with fiona.open(info['gpkg'], 'r') as src:  # TODO: refactor as independent function
                    lst_vector = [vector for vector in src]
                shapes = lst_ids(list_vector=lst_vector,
                                 attr_name=info['attribute_name'])
                # NOTE(review): `index` counts the flattened shapes; it is assumed to line up with
                # `lst_vector` positions -- confirm against lst_ids().
                for index, item in enumerate(
                        tqdm([v for vecs in shapes.values() for v in vecs],
                             leave=False,
                             position=1)):
                    # geom must be a valid GeoJSON geometry type and non-empty
                    geom, value = item
                    geom = getattr(geom, '__geo_interface__', None) or geom
                    if not is_valid_geom(geom):
                        gpkg_stem = str(Path(info['gpkg']).stem)
                        # one list of invalid feature ids per gpkg, without duplicates
                        invalid_ids = invalid_features.setdefault(gpkg_stem, [])
                        if lst_vector[index]["id"] not in invalid_ids:
                            invalid_ids.append(lst_vector[index]["id"])
            assert not invalid_features, \
                f'Invalid geometry object(s) for "gpkg:ids": \"{invalid_features}\"'

    number_samples = {'trn': 0, 'val': 0, 'tst': 0}
    number_classes = 0

    # 'sampling' ordereddict validation
    check_sampling_dict()

    # creates pixel_classes dict and keys
    pixel_classes = {i: 0 for i in range(params['global']['num_classes'] + 1)}
    # FIXME: pixel_classes dict needs to be populated with classes obtained from target
    pixel_classes[ignore_index] = 0

    trn_hdf5, val_hdf5, tst_hdf5 = create_files_and_datasets(
        params, samples_folder)
    # Always defined, so that a 'tst' row processed before any 'trn' row does not hit an unbound name.
    val_file = val_hdf5

    # For each row in csv: (1) burn vector file to raster, (2) read input raster image, (3) prepare samples
    with tqdm(list_data_prep,
              position=0,
              leave=False,
              desc='Preparing samples') as _tqdm:
        for info in _tqdm:
            _tqdm.set_postfix(
                OrderedDict(tif=f'{Path(info["tif"]).stem}',
                            sample_size=params['global']['samples_size']))
            try:
                if bucket_name:
                    bucket.download_file(
                        info['tif'], "Images/" + info['tif'].split('/')[-1])
                    info['tif'] = "Images/" + info['tif'].split('/')[-1]
                    if info['gpkg'] not in bucket_file_cache:
                        bucket_file_cache.append(info['gpkg'])
                        bucket.download_file(info['gpkg'],
                                             info['gpkg'].split('/')[-1])
                    info['gpkg'] = info['gpkg'].split('/')[-1]
                    if info['meta']:
                        if info['meta'] not in bucket_file_cache:
                            bucket_file_cache.append(info['meta'])
                            bucket.download_file(info['meta'],
                                                 info['meta'].split('/')[-1])
                        info['meta'] = info['meta'].split('/')[-1]

                with rasterio.open(info['tif'], 'r') as raster:
                    # Burn vector file in a raster file
                    # NOTE(review): the key read here is 'ignore_idx' while the rest of this function reads
                    # 'ignore_index' -- confirm which one is intended before changing it.
                    np_label_raster = vector_to_raster(
                        vector_file=info['gpkg'],
                        input_image=raster,
                        attribute_name=info['attribute_name'],
                        fill=get_key_def('ignore_idx',
                                         get_key_def('training', params, {}),
                                         0))
                    # Read the input raster image
                    np_input_image = image_reader_as_array(
                        input_image=raster,
                        scale=get_key_def('scale_data', params['global'],
                                          None),
                        aux_vector_file=get_key_def('aux_vector_file',
                                                    params['global'], None),
                        aux_vector_attrib=get_key_def('aux_vector_attrib',
                                                      params['global'], None),
                        aux_vector_ids=get_key_def('aux_vector_ids',
                                                   params['global'], None),
                        aux_vector_dist_maps=get_key_def(
                            'aux_vector_dist_maps', params['global'], True),
                        aux_vector_dist_log=get_key_def(
                            'aux_vector_dist_log', params['global'], True),
                        aux_vector_scale=get_key_def('aux_vector_scale',
                                                     params['global'], None))

                # Mask the zeros from input image into label raster.
                if params['sample']['mask_reference']:
                    np_label_raster = mask_image(np_input_image,
                                                 np_label_raster)

                if info['dataset'] == 'trn':
                    out_file = trn_hdf5
                elif info['dataset'] == 'tst':
                    out_file = tst_hdf5
                else:
                    raise ValueError(
                        f"Dataset value must be trn or val or tst. Provided value is {info['dataset']}"
                    )

                meta_map, metadata = get_key_def("meta_map", params["global"],
                                                 {}), None
                if info['meta'] is not None and isinstance(
                        info['meta'], str) and Path(info['meta']).is_file():
                    metadata = read_parameters(info['meta'])

                # FIXME: think this through. User will have to calculate the total number of bands including meta layers and
                #  specify it in yaml. Is this the best approach? What if metalayers are added on the fly ?
                input_band_count = np_input_image.shape[
                    2] + MetaSegmentationDataset.get_meta_layer_count(meta_map)
                # FIXME: could this assert be done before getting into this big for loop?
                assert input_band_count == num_bands, \
                    f"The number of bands in the input image ({input_band_count}) and the parameter" \
                    f"'number_of_bands' in the yaml file ({params['global']['number_of_bands']}) should be identical"

                np_label_raster = np.reshape(
                    np_label_raster,
                    (np_label_raster.shape[0], np_label_raster.shape[1], 1))
                number_samples, number_classes = samples_preparation(
                    np_input_image, np_label_raster, samples_size, overlap,
                    number_samples, number_classes, out_file, val_percent,
                    val_file, info['dataset'], pixel_classes, metadata)

                _tqdm.set_postfix(OrderedDict(number_samples=number_samples))
                out_file.flush()
            except Exception as e:
                # Best effort: a row that fails is reported and skipped so the other rows are still processed.
                warnings.warn(
                    f'An error occurred while preparing samples with "{Path(info["tif"]).stem}" (tiff) and '
                    f'{Path(info["gpkg"]).stem} (gpkg). Error: "{e}"')
                continue

    trn_hdf5.close()
    val_hdf5.close()
    tst_hdf5.close()

    # adds up the number of pixels for each class in pixel_classes dict
    pixel_total = sum(pixel_classes.values())

    # prints the proportion of pixels of each class for the samples created
    for i in pixel_classes:
        print('Pixels from class', i, ':',
              round((pixel_classes[i] / pixel_total) * 100, 1), '%')

    print("Number of samples created: ", number_samples)

    if bucket_name and final_samples_folder:
        print('Transfering Samples to the bucket')
        for hdf5_name in ('trn_samples.hdf5', 'val_samples.hdf5', 'tst_samples.hdf5'):
            # samples_folder is a Path: build the local file name with joinpath, not string concatenation.
            bucket.upload_file(str(samples_folder.joinpath(hdf5_name)),
                               final_samples_folder + '/' + hdf5_name)

    print("End of process")
Exemple #23
0
def create_dataloader(data_path,
                      batch_size,
                      task,
                      num_devices,
                      params,
                      samples_folder=None):
    """
    Build the DataLoader objects for the training, validation and test datasets.
    :param data_path: (str) path to the samples folder (not used inside this function)
    :param batch_size: (int) batch size
    :param task: (str) classification or segmentation (not used inside this function)
    :param num_devices: (int) number of GPUs used; sets the number of workers of the dataloaders
    :param params: (dict) Parameters found in the yaml config file. params["training"]["ignore_index"] is
                   overwritten with -1 when it is 0.
    :param samples_folder: path to folder containing .hdf5 files if task is segmentation
    :return: trn_dataloader, val_dataloader, tst_dataloader (None when there are no test samples)
    """
    samples_dir = Path(samples_folder)
    assert samples_dir.is_dir(), f'Could not locate: {samples_folder}'
    hdf5_files = list(Path(samples_folder).glob('**/*.hdf5'))
    assert hdf5_files, f"Couldn't locate .hdf5 files in {samples_folder}"

    num_samples = get_num_samples(samples_path=samples_folder, params=params)
    print(f"Number of samples : {num_samples}\n")

    # Metadata-aware dataset only when a meta map is configured.
    meta_map = get_key_def("meta_map", params["global"], {})
    if meta_map:
        dataset_constr = functools.partial(CreateDataset.MetaSegmentationDataset, meta_map=meta_map)
    else:
        dataset_constr = CreateDataset.SegmentationDataset

    dontcare = get_key_def("ignore_index", params["training"], None)
    if dontcare == 0:
        warnings.warn(
            "The 'dontcare' value (or 'ignore_index') used in the loss function cannot be zero;"
            " all valid class indices should be consecutive, and start at 0. The 'dontcare' value"
            " will be remapped to -1 while loading the dataset, and inside the config from now on."
        )
        params["training"]["ignore_index"] = -1

    # One dataset per subset, built in trn / val / tst order.
    datasets = {}
    for subset in ("trn", "val", "tst"):
        subset_transform = aug.compose_transforms(params, subset)
        datasets[subset] = dataset_constr(samples_folder,
                                          subset,
                                          max_sample_count=num_samples[subset],
                                          dontcare=dontcare,
                                          transform=subset_transform)

    # https://discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader/813/5
    if num_devices > 1:
        num_workers = num_devices * 4
    else:
        num_workers = 4

    # Shuffle must be set to True.
    trn_dataloader = DataLoader(datasets["trn"],
                                batch_size=batch_size,
                                num_workers=num_workers,
                                shuffle=True,
                                drop_last=True)
    val_dataloader = DataLoader(datasets["val"],
                                batch_size=batch_size,
                                num_workers=num_workers,
                                shuffle=False)
    # The test dataloader is only created when test samples exist.
    tst_dataloader = None
    if num_samples['tst'] > 0:
        tst_dataloader = DataLoader(datasets["tst"],
                                    batch_size=batch_size,
                                    num_workers=num_workers,
                                    shuffle=False)

    return trn_dataloader, val_dataloader, tst_dataloader
Exemple #24
0
def main(params, config_path):
    """
    Function to train and validate a model for semantic segmentation.

    Process
    -------
    1. Model is instantiated and checkpoint is loaded from path, if provided in
       `your_config.yaml`.
    2. GPUs are requested according to desired amount of `num_gpus` and
       available GPUs.
    3. If more than 1 GPU is requested, model is cast to DataParallel model
    4. Dataloaders are created with `create_dataloader()`
    5. Loss criterion, optimizer and learning rate are set with
       `set_hyperparameters()` as requested in `config.yaml`.
    6. Using these hyperparameters, the application will try to minimize the
       loss on the training data and evaluate every epoch on the validation
       data.
    7. For every epoch, the application shows and logs the loss on "trn" and
       "val" datasets.
    8. For every epoch (if `batch_metrics: 1`), the application shows and logs
       the accuracy, recall and f-score on "val" dataset. Those metrics are
       also computed on each class.
    9. At the end of the training process, the application shows and logs the
       accuracy, recall and f-score on "tst" dataset. Those metrics are also
       computed on each class.

    -------
    :param params: (dict) Parameters found in the yaml config file.
    :param config_path: (pathlib.Path) Path to the yaml config file. `.stem` is read from it to name the
                        output folder and log files, so a plain str will not work.
    :raises ValueError: if the task is not 'segmentation', if `num_gpus` is negative, or if
                        `vis_batch_range` is not a list of three integers.
    :raises FileNotFoundError: if `meta_map`, the training `state_dict_path` or `data_path` cannot be located.
    """
    now = datetime.now().strftime("%Y-%m-%d_%H-%M")

    # MANDATORY PARAMETERS
    num_classes = get_key_def('num_classes', params['global'], expected_type=int)
    num_classes_corrected = num_classes + 1  # + 1 for background # FIXME temporary patch for num_classes problem.
    num_bands = get_key_def('number_of_bands', params['global'], expected_type=int)
    batch_size = get_key_def('batch_size', params['training'], expected_type=int)
    eval_batch_size = get_key_def('eval_batch_size', params['training'], expected_type=int, default=batch_size)
    num_epochs = get_key_def('num_epochs', params['training'], expected_type=int)
    model_name = get_key_def('model_name', params['global'], expected_type=str).lower()
    BGR_to_RGB = get_key_def('BGR_to_RGB', params['global'], expected_type=bool)

    # OPTIONAL PARAMETERS
    # basics
    debug = get_key_def('debug_mode', params['global'], default=False, expected_type=bool)
    task = get_key_def('task', params['global'], default='segmentation', expected_type=str)
    if not task == 'segmentation':
        raise ValueError(f"The task should be segmentation. The provided value is {task}")
    dontcare_val = get_key_def("ignore_index", params["training"], default=-1, expected_type=int)
    crop_size = get_key_def('target_size', params['training'], default=None, expected_type=int)
    batch_metrics = get_key_def('batch_metrics', params['training'], default=None, expected_type=int)
    meta_map = get_key_def("meta_map", params["global"], default=None)
    if meta_map and not Path(meta_map).is_file():
        raise FileNotFoundError(f'Couldn\'t locate {meta_map}')
    bucket_name = get_key_def('bucket_name', params['global'])  # AWS
    scale = get_key_def('scale_data', params['global'], default=[0, 1], expected_type=List)

    # model params
    loss_fn = get_key_def('loss_fn', params['training'], default='CrossEntropy', expected_type=str)
    class_weights = get_key_def('class_weights', params['training'], default=None, expected_type=Sequence)
    if class_weights:
        verify_weights(num_classes_corrected, class_weights)
    optimizer = get_key_def('optimizer', params['training'], default='adam', expected_type=str)
    pretrained = get_key_def('pretrained', params['training'], default=True, expected_type=bool)
    train_state_dict_path = get_key_def('state_dict_path', params['training'], default=None, expected_type=str)
    if train_state_dict_path and not Path(train_state_dict_path).is_file():
        raise FileNotFoundError(f'Could not locate pretrained checkpoint for training: {train_state_dict_path}')
    dropout_prob = get_key_def('dropout_prob', params['training'], default=None, expected_type=float)
    # Read the concatenation point
    # TODO: find a way to maybe implement it in classification one day
    conc_point = get_key_def('concatenate_depth', params['global'], None)

    # gpu parameters
    num_devices = get_key_def('num_gpus', params['global'], default=0, expected_type=int)
    if num_devices and not num_devices >= 0:
        raise ValueError("missing mandatory num gpus parameter")

    # mlflow logging
    mlflow_uri = get_key_def('mlflow_uri', params['global'], default="./mlruns")
    Path(mlflow_uri).mkdir(exist_ok=True)
    experiment_name = get_key_def('mlflow_experiment_name', params['global'], default='gdl-training', expected_type=str)
    run_name = get_key_def('mlflow_run_name', params['global'], default='gdl', expected_type=str)

    # parameters to find hdf5 samples
    data_path = Path(get_key_def('data_path', params['global'], './data', expected_type=str))
    samples_size = get_key_def("samples_size", params["global"], default=1024, expected_type=int)
    overlap = get_key_def("overlap", params["sample"], default=5, expected_type=int)
    min_annot_perc = get_key_def('min_annotated_percent', params['sample']['sampling_method'], default=0,
                                 expected_type=int)
    if not data_path.is_dir():
        raise FileNotFoundError(f'Could not locate data path {data_path}')
    samples_folder_name = (f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands'
                           f'_{experiment_name}')
    samples_folder = data_path.joinpath(samples_folder_name)

    # visualization parameters
    vis_at_train = get_key_def('vis_at_train', params['visualization'], default=False)
    vis_at_eval = get_key_def('vis_at_evaluation', params['visualization'], default=False)
    vis_batch_range = get_key_def('vis_batch_range', params['visualization'], default=None)
    vis_at_checkpoint = get_key_def('vis_at_checkpoint', params['visualization'], default=False)
    ep_vis_min_thresh = get_key_def('vis_at_ckpt_min_ep_diff', params['visualization'], default=1, expected_type=int)
    vis_at_ckpt_dataset = get_key_def('vis_at_ckpt_dataset', params['visualization'], 'val')
    colormap_file = get_key_def('colormap_file', params['visualization'], None)
    heatmaps = get_key_def('heatmaps', params['visualization'], False)
    heatmaps_inf = get_key_def('heatmaps', params['inference'], False)
    grid = get_key_def('grid', params['visualization'], False)
    mean = get_key_def('mean', params['training']['normalization'])
    std = get_key_def('std', params['training']['normalization'])
    vis_params = {'colormap_file': colormap_file, 'heatmaps': heatmaps, 'heatmaps_inf': heatmaps_inf, 'grid': grid,
                  'mean': mean, 'std': std, 'vis_batch_range': vis_batch_range, 'vis_at_train': vis_at_train,
                  'vis_at_eval': vis_at_eval, 'ignore_index': dontcare_val, 'inference_input_path': None}

    # coordconv parameters: every key of the 'global' section whose name contains 'coordconv'
    coordconv_params = {}
    for param, val in params['global'].items():
        if 'coordconv' in param:
            coordconv_params[param] = val

    # add git hash from current commit to parameters if available. Parameters will be saved to model's .pth.tar
    params['global']['git_hash'] = get_git_hash()

    # automatic model naming with unique id for each training
    model_id = config_path.stem
    output_path = samples_folder.joinpath('model') / model_id
    if output_path.is_dir():
        # Archive the previous output folder under a name suffixed with its last modification time,
        # so the new run always starts from a fresh, empty folder.
        last_mod_time_suffix = datetime.fromtimestamp(output_path.stat().st_mtime).strftime('%Y%m%d-%H%M%S')
        archive_output_path = samples_folder.joinpath('model') / f"{model_id}_{last_mod_time_suffix}"
        shutil.move(output_path, archive_output_path)
    output_path.mkdir(parents=True, exist_ok=False)
    shutil.copy(str(config_path), str(output_path))  # copy yaml to output path where model will be saved

    # NOTE: this function-scope import makes `logging` a local name; it must stay above any use of `logging`.
    import logging.config  # See: https://docs.python.org/2.4/lib/logging-config-fileformat.html
    log_config_path = Path('utils/logging.conf').absolute()
    logfile = f'{output_path}/{model_id}.log'
    logfile_debug = f'{output_path}/{model_id}_debug.log'
    console_level_logging = 'INFO' if not debug else 'DEBUG'
    logging.config.fileConfig(log_config_path, defaults={'logfilename': logfile,
                                                         'logfilename_debug': logfile_debug,
                                                         'console_level': console_level_logging})

    logging.info(f'Model and log files will be saved to: {output_path}\n\n')
    if debug:
        logging.warning(f'Debug mode activated. Some debug features may mobilize extra disk space and '
                        f'cause delays in execution.')
    if dontcare_val < 0 and vis_batch_range:
        logging.warning(f'Visualization: expected positive value for ignore_index, got {dontcare_val}. '
                        f'Will be overridden to 255 during visualization only. Problems may occur.')

    # overwrite dontcare values in label if loss is not lovasz or crossentropy. FIXME: hacky fix.
    dontcare2backgr = False
    if loss_fn not in ['Lovasz', 'CrossEntropy', 'OhemCrossEntropy']:
        dontcare2backgr = True
        logging.warning(f'Dontcare is not implemented for loss function "{loss_fn}". '
                        f'Dontcare values ({dontcare_val}) in label will be replaced with background value (0)')

    # Will check if batch size needs to be a lower value only if cropping samples during training
    calc_eval_bs = True if crop_size else False

    # INSTANTIATE MODEL AND LOAD CHECKPOINT FROM PATH
    model, model_name, criterion, optimizer, lr_scheduler, device, gpu_devices_dict = \
        net(model_name=model_name,
            num_bands=num_bands,
            num_channels=num_classes_corrected,
            dontcare_val=dontcare_val,
            num_devices=num_devices,
            train_state_dict_path=train_state_dict_path,
            pretrained=pretrained,
            dropout_prob=dropout_prob,
            loss_fn=loss_fn,
            class_weights=class_weights,
            optimizer=optimizer,
            net_params=params,
            conc_point=conc_point,
            coordconv_params=coordconv_params)

    logging.info(f'Instantiated {model_name} model with {num_classes_corrected} output channels.\n')

    logging.info(f'Creating dataloaders from data in {samples_folder}...\n')
    trn_dataloader, val_dataloader, tst_dataloader = create_dataloader(samples_folder=samples_folder,
                                                                       batch_size=batch_size,
                                                                       eval_batch_size=eval_batch_size,
                                                                       gpu_devices_dict=gpu_devices_dict,
                                                                       sample_size=samples_size,
                                                                       dontcare_val=dontcare_val,
                                                                       crop_size=crop_size,
                                                                       meta_map=meta_map,
                                                                       num_bands=num_bands,
                                                                       BGR_to_RGB=BGR_to_RGB,
                                                                       scale=scale,
                                                                       params=params,
                                                                       dontcare2backgr=dontcare2backgr,
                                                                       calc_eval_bs=calc_eval_bs,
                                                                       debug=debug)

    # mlflow tracking path + parameters logging
    set_tracking_uri(mlflow_uri)
    set_experiment(experiment_name)
    start_run(run_name=run_name)
    log_params(params['training'])
    log_params(params['global'])
    log_params(params['sample'])

    if bucket_name:
        from utils.aws import download_s3_files
        bucket, bucket_output_path, output_path, data_path = download_s3_files(bucket_name=bucket_name,
                                                                               data_path=data_path,
                                                                               output_path=output_path)

    since = time.time()
    best_loss = 999
    last_vis_epoch = 0

    progress_log = output_path / 'progress.log'
    if not progress_log.exists():
        # Context manager guarantees the handle is closed once the header is written.
        with progress_log.open('w', buffering=1) as progress_file:
            progress_file.write(tsv_line('ep_idx', 'phase', 'iter', 'i_p_ep', 'time'))  # Add header

    trn_log = InformationLogger('trn')
    val_log = InformationLogger('val')
    tst_log = InformationLogger('tst')
    filename = output_path.joinpath('checkpoint.pth.tar')

    # VISUALIZATION: generate pngs of inputs, labels and outputs
    if vis_batch_range is not None:
        # Make sure user-provided range is a list with 3 integers (start, finish, increment).
        # Check once for all visualization tasks.
        # The whole conjunction is negated: any one failing condition must reject the value.
        is_valid_range = (isinstance(vis_batch_range, list)
                          and len(vis_batch_range) == 3
                          and all(isinstance(x, int) for x in vis_batch_range))
        if not is_valid_range:
            raise ValueError(f'Vis_batch_range expects three integers in a list: start batch, end batch, increment. '
                             f'Got {vis_batch_range}')
        vis_at_init_dataset = get_key_def('vis_at_init_dataset', params['visualization'], 'val')

        # Visualization at initialization. Visualize batch range before first epoch.
        if get_key_def('vis_at_init', params['visualization'], False):
            logging.info(f'Visualizing initialized model on batch range {vis_batch_range} '
                         f'from {vis_at_init_dataset} dataset...\n')
            vis_from_dataloader(vis_params=vis_params,
                                eval_loader=val_dataloader if vis_at_init_dataset == 'val' else tst_dataloader,
                                model=model,
                                ep_num=0,
                                output_path=output_path,
                                dataset=vis_at_init_dataset,
                                scale=scale,
                                device=device,
                                vis_batch_range=vis_batch_range)

    for epoch in range(0, num_epochs):
        logging.info(f'\nEpoch {epoch}/{num_epochs - 1}\n{"-" * 20}')

        trn_report = train(train_loader=trn_dataloader,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           scheduler=lr_scheduler,
                           num_classes=num_classes_corrected,
                           batch_size=batch_size,
                           ep_idx=epoch,
                           progress_log=progress_log,
                           device=device,
                           scale=scale,
                           vis_params=vis_params,
                           debug=debug)
        trn_log.add_values(trn_report, epoch, ignore=['precision', 'recall', 'fscore', 'iou'])

        val_report = evaluation(eval_loader=val_dataloader,
                                model=model,
                                criterion=criterion,
                                num_classes=num_classes_corrected,
                                batch_size=batch_size,
                                ep_idx=epoch,
                                progress_log=progress_log,
                                batch_metrics=batch_metrics,
                                dataset='val',
                                device=device,
                                scale=scale,
                                vis_params=vis_params,
                                debug=debug)
        val_loss = val_report['loss'].avg
        if batch_metrics is not None:
            val_log.add_values(val_report, epoch)
        else:
            val_log.add_values(val_report, epoch, ignore=['precision', 'recall', 'fscore', 'iou'])

        # A checkpoint is written only when the validation loss improves on the best one seen so far.
        if val_loss < best_loss:
            logging.info("save checkpoint\n")
            best_loss = val_loss
            # More info: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-torch-nn-dataparallel-models
            state_dict = model.module.state_dict() if num_devices > 1 else model.state_dict()
            torch.save({'epoch': epoch,
                        'params': params,
                        'model': state_dict,
                        'best_loss': best_loss,
                        'optimizer': optimizer.state_dict()}, filename)
            if bucket_name:
                bucket_filename = bucket_output_path.joinpath('checkpoint.pth.tar')
                bucket.upload_file(filename, bucket_filename)

            # VISUALIZATION: generate pngs of img samples, labels and outputs as alternative to follow training
            if vis_batch_range is not None and vis_at_checkpoint and epoch - last_vis_epoch >= ep_vis_min_thresh:
                if last_vis_epoch == 0:
                    logging.info(f'Visualizing with {vis_at_ckpt_dataset} dataset samples on checkpointed model for '
                                 f'batches in range {vis_batch_range}')
                vis_from_dataloader(vis_params=vis_params,
                                    eval_loader=val_dataloader if vis_at_ckpt_dataset == 'val' else tst_dataloader,
                                    model=model,
                                    ep_num=epoch+1,
                                    output_path=output_path,
                                    dataset=vis_at_ckpt_dataset,
                                    scale=scale,
                                    device=device,
                                    vis_batch_range=vis_batch_range)
                last_vis_epoch = epoch

        if bucket_name:
            save_logs_to_bucket(bucket, bucket_output_path, output_path, now, params['training']['batch_metrics'])

        cur_elapsed = time.time() - since
        logging.info(f'Current elapsed time {cur_elapsed // 60:.0f}m {cur_elapsed % 60:.0f}s')

    # load checkpoint model and evaluate it on test dataset.
    if num_epochs > 0:  # if num_epochs is set to 0, model is loaded to evaluate on test set
        checkpoint = load_checkpoint(filename)
        model, _ = load_from_checkpoint(checkpoint, model)

    if tst_dataloader:
        tst_report = evaluation(eval_loader=tst_dataloader,
                                model=model,
                                criterion=criterion,
                                num_classes=num_classes_corrected,
                                batch_size=batch_size,
                                ep_idx=num_epochs,
                                progress_log=progress_log,
                                batch_metrics=batch_metrics,
                                dataset='tst',
                                scale=scale,
                                vis_params=vis_params,
                                device=device)
        tst_log.add_values(tst_report, num_epochs)

        if bucket_name:
            bucket_filename = bucket_output_path.joinpath('last_epoch.pth.tar')
            bucket.upload_file("output.txt", bucket_output_path.joinpath(f"Logs/{now}_output.txt"))
            bucket.upload_file(filename, bucket_filename)

    time_elapsed = time.time() - since
    log_params({'checkpoint path': filename})
    logging.info('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
Exemple #25
0
def main(params, config_path):
    """
    Function to train and validate a model for semantic segmentation.

    The task read from the config must be 'segmentation'; any other value fails the assertion below.
    The model is trained for `params['training']['num_epochs']` epochs, a checkpoint is written each time
    the validation loss improves, and the checkpointed model is finally evaluated on the "tst" dataloader
    when one is available.

    :param params: (dict) Parameters found in the yaml config file.
    :param config_path: (pathlib.Path) Path to the yaml config file. `.stem` is read from it to name the
                        output folder, so a plain str will not work.
    :raises AssertionError: if the task is not 'segmentation', the data path is not a directory,
                            `num_gpus` is missing or negative, or `vis_batch_range` is not a list of 3 ints.
    """
    debug = get_key_def('debug_mode', params['global'], False)
    if debug:
        warnings.warn(
            f'Debug mode activated. Some debug features may mobilize extra disk space and cause delays in execution.'
        )

    now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    num_classes = params['global']['num_classes']
    task = params['global']['task']
    assert task == 'segmentation', f"The task should be segmentation. The provided value is {task}"
    num_classes_corrected = num_classes + 1  # + 1 for background # FIXME temporary patch for num_classes problem.

    # INSTANTIATE MODEL AND LOAD CHECKPOINT FROM PATH
    model, checkpoint, model_name = net(
        params,
        num_classes_corrected)  # pretrained could become a yaml parameter.
    tqdm.write(
        f'Instantiated {model_name} model with {num_classes_corrected} output channels.\n'
    )
    bucket_name = params['global']['bucket_name']
    data_path = params['global']['data_path']
    assert Path(data_path).is_dir(), f'Could not locate data path {data_path}'

    samples_size = params["global"]["samples_size"]
    overlap = params["sample"]["overlap"]
    min_annot_perc = params['sample']['sampling']['map']
    num_bands = params['global']['number_of_bands']
    samples_folder_name = f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands'  # FIXME: preferred name structure? document!
    samples_folder = Path(data_path).joinpath(
        samples_folder_name) if task == 'segmentation' else Path(data_path)

    modelname = config_path.stem
    output_path = Path(samples_folder).joinpath('model') / modelname
    if output_path.is_dir():
        # An output folder from a previous run exists: write to a timestamp-suffixed sibling instead.
        output_path = Path(str(output_path) + '_' + now)
    output_path.mkdir(parents=True, exist_ok=False)
    shutil.copy(str(config_path), str(output_path))
    tqdm.write(f'Model and log files will be saved to: {output_path}\n\n')
    batch_size = params['training']['batch_size']

    if bucket_name:
        bucket, bucket_output_path, output_path, data_path = download_s3_files(
            bucket_name=bucket_name,
            data_path=data_path,
            output_path=output_path)

    since = time.time()
    best_loss = 999
    last_vis_epoch = 0

    progress_log = Path(output_path) / 'progress.log'
    if not progress_log.exists():
        # Context manager guarantees the handle is closed once the header is written.
        with progress_log.open('w', buffering=1) as progress_file:
            progress_file.write(
                tsv_line('ep_idx', 'phase', 'iter', 'i_p_ep',
                         'time'))  # Add header

    trn_log = InformationLogger(output_path, 'trn')
    val_log = InformationLogger(output_path, 'val')
    tst_log = InformationLogger(output_path, 'tst')

    num_devices = params['global']['num_gpus']
    assert num_devices is not None and num_devices >= 0, "missing mandatory num gpus parameter"
    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    lst_device_ids = get_device_ids(
        num_devices) if torch.cuda.is_available() else []
    num_devices = len(lst_device_ids) if lst_device_ids else 0
    device = torch.device(f'cuda:{lst_device_ids[0]}' if torch.cuda.
                          is_available() and lst_device_ids else 'cpu')
    print(
        f"Number of cuda devices requested: {params['global']['num_gpus']}. Cuda devices available: {lst_device_ids}\n"
    )
    if num_devices == 1:
        print(f"Using Cuda device {lst_device_ids[0]}\n")
    elif num_devices > 1:
        # str(list)[1:-1] strips the surrounding square brackets from the list's text representation.
        print(
            f"Using data parallel on devices: {str(lst_device_ids)[1:-1]}. Main device: {lst_device_ids[0]}\n"
        )
        try:  # FIXME: For HPC when device 0 not available. Error: Invalid device id (in torch/cuda/__init__.py).
            model = nn.DataParallel(
                model, device_ids=lst_device_ids
            )  # DataParallel adds prefix 'module.' to state_dict keys
        except AssertionError:
            warnings.warn(
                f"Unable to use devices {lst_device_ids}. Trying devices {list(range(len(lst_device_ids)))}"
            )
            device = torch.device('cuda:0')
            lst_device_ids = range(len(lst_device_ids))
            model = nn.DataParallel(
                model, device_ids=lst_device_ids
            )  # DataParallel adds prefix 'module.' to state_dict keys

    else:
        warnings.warn(
            f"No Cuda device available. This process will only run on CPU\n")

    tqdm.write(f'Creating dataloaders from data in {samples_folder}...\n')
    trn_dataloader, val_dataloader, tst_dataloader = create_dataloader(
        data_path=data_path,
        batch_size=batch_size,
        task=task,
        num_devices=num_devices,
        params=params,
        samples_folder=samples_folder)

    tqdm.write(
        f'Setting model, criterion, optimizer and learning rate scheduler...\n'
    )
    model, criterion, optimizer, lr_scheduler = set_hyperparameters(
        params, num_classes_corrected, model, checkpoint)

    criterion = criterion.to(device)
    try:  # For HPC when device 0 not available. Error: Cuda invalid device ordinal.
        model.to(device)
    except RuntimeError:
        warnings.warn(f"Unable to use device. Trying device 0...\n")
        device = torch.device(f'cuda:0' if torch.cuda.is_available()
                              and lst_device_ids else 'cpu')
        model.to(device)

    filename = os.path.join(output_path, 'checkpoint.pth.tar')

    # VISUALIZATION: generate pngs of inputs, labels and outputs
    vis_batch_range = get_key_def('vis_batch_range', params['visualization'],
                                  None)
    if vis_batch_range is not None:
        # Make sure user-provided range is a list with 3 integers (start, finish, increment). Check once for all visualization tasks.
        assert isinstance(vis_batch_range,
                          list) and len(vis_batch_range) == 3 and all(
                              isinstance(x, int) for x in vis_batch_range)
        vis_at_init = get_key_def('vis_at_init', params['visualization'],
                                  False)
        vis_at_init_dataset = get_key_def('vis_at_init_dataset',
                                          params['visualization'], 'val')
        if vis_at_init:
            tqdm.write(
                f'Visualizing initialized model on batch range {vis_batch_range} from {vis_at_init_dataset} dataset...\n'
            )
            vis_from_dataloader(
                params=params,
                eval_loader=val_dataloader
                if vis_at_init_dataset == 'val' else tst_dataloader,
                model=model,
                ep_num=0,
                output_path=output_path,
                dataset=vis_at_init_dataset,
                device=device,
                vis_batch_range=vis_batch_range)

    for epoch in range(0, params['training']['num_epochs']):
        print(
            f'\nEpoch {epoch}/{params["training"]["num_epochs"] - 1}\n{"-" * 20}'
        )

        trn_report = train(train_loader=trn_dataloader,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           scheduler=lr_scheduler,
                           num_classes=num_classes_corrected,
                           batch_size=batch_size,
                           task=task,
                           ep_idx=epoch,
                           progress_log=progress_log,
                           vis_params=params,
                           device=device,
                           debug=debug)
        trn_log.add_values(trn_report,
                           epoch,
                           ignore=['precision', 'recall', 'fscore', 'iou'])

        val_report = evaluation(
            eval_loader=val_dataloader,
            model=model,
            criterion=criterion,
            num_classes=num_classes_corrected,
            batch_size=batch_size,
            task=task,
            ep_idx=epoch,
            progress_log=progress_log,
            vis_params=params,
            batch_metrics=params['training']['batch_metrics'],
            dataset='val',
            device=device,
            debug=debug)
        val_loss = val_report['loss'].avg
        if params['training']['batch_metrics'] is not None:
            val_log.add_values(val_report, epoch)
        else:
            val_log.add_values(val_report,
                               epoch,
                               ignore=['precision', 'recall', 'fscore', 'iou'])

        # A checkpoint is written only when the validation loss improves on the best one seen so far.
        if val_loss < best_loss:
            tqdm.write("save checkpoint\n")
            best_loss = val_loss
            # More info: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-torch-nn-dataparallel-models
            state_dict = model.module.state_dict(
            ) if num_devices > 1 else model.state_dict()
            torch.save(
                {
                    'epoch': epoch,
                    'arch': model_name,
                    'model': state_dict,
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict()
                }, filename)

            if bucket_name:
                bucket_filename = os.path.join(bucket_output_path,
                                               'checkpoint.pth.tar')
                bucket.upload_file(filename, bucket_filename)

            # VISUALIZATION: generate png of test samples, labels and outputs for visualisation to follow training performance
            vis_at_checkpoint = get_key_def('vis_at_checkpoint',
                                            params['visualization'], False)
            ep_vis_min_thresh = get_key_def('vis_at_ckpt_min_ep_diff',
                                            params['visualization'], 4)
            vis_at_ckpt_dataset = get_key_def('vis_at_ckpt_dataset',
                                              params['visualization'], 'val')
            if vis_batch_range is not None and vis_at_checkpoint and epoch - last_vis_epoch >= ep_vis_min_thresh:
                if last_vis_epoch == 0:
                    tqdm.write(
                        f'Visualizing with {vis_at_ckpt_dataset} dataset samples on checkpointed model for batches {vis_batch_range}'
                    )
                vis_from_dataloader(
                    params=params,
                    eval_loader=val_dataloader
                    if vis_at_ckpt_dataset == 'val' else tst_dataloader,
                    model=model,
                    ep_num=epoch + 1,
                    output_path=output_path,
                    dataset=vis_at_ckpt_dataset,
                    device=device,
                    vis_batch_range=vis_batch_range)
                last_vis_epoch = epoch

        if bucket_name:
            save_logs_to_bucket(bucket, bucket_output_path, output_path, now,
                                params['training']['batch_metrics'])

        cur_elapsed = time.time() - since
        print(
            f'Current elapsed time {cur_elapsed // 60:.0f}m {cur_elapsed % 60:.0f}s'
        )

    # load checkpoint model and evaluate it on test dataset.
    if int(
            params['training']['num_epochs']
    ) > 0:  # if num_epochs is set to 0, model is loaded to evaluate on test set
        checkpoint = load_checkpoint(filename)
        model, _ = load_from_checkpoint(checkpoint, model)

    if tst_dataloader:
        tst_report = evaluation(
            eval_loader=tst_dataloader,
            model=model,
            criterion=criterion,
            num_classes=num_classes_corrected,
            batch_size=batch_size,
            task=task,
            ep_idx=params['training']['num_epochs'],
            progress_log=progress_log,
            vis_params=params,
            batch_metrics=params['training']['batch_metrics'],
            dataset='tst',
            device=device)
        tst_log.add_values(tst_report, params['training']['num_epochs'])

        if bucket_name:
            bucket_filename = os.path.join(bucket_output_path,
                                           'last_epoch.pth.tar')
            bucket.upload_file(
                "output.txt",
                os.path.join(bucket_output_path, f"Logs/{now}_output.txt"))
            bucket.upload_file(filename, bucket_filename)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
Exemple #26
0
def main(params, config_path):
    """
    Function to train, validate and test a model for classification.

    The model is instantiated, the output directory is created, devices are selected and
    dataloaders are built. Then, for each epoch, the model is trained and validated, and a
    checkpoint is saved every time the validation loss improves. Finally, if at least one
    epoch was run, the saved checkpoint is reloaded and, if a test dataloader exists, the
    model is evaluated on it.
    :param params: (dict) Parameters found in the yaml config file.
    :param config_path: (Path) Path to the yaml config file. Must be a pathlib.Path-like object, since its
                        `stem` attribute is used to name the output directory.
    :raises AssertionError: if the task is not 'classification' or if the 'num_gpus' parameter is
                            missing or negative.
    """
    debug = get_key_def('debug_mode', params['global'], False)
    if debug:
        warnings.warn(
            'Debug mode activated. Some debug functions may cause delays in execution.'
        )

    now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    num_classes = params['global']['num_classes']
    task = params['global']['task']
    batch_size = params['training']['batch_size']
    assert task == 'classification', f"The task should be classification. The provided value is {task}"

    # INSTANTIATE MODEL AND LOAD CHECKPOINT FROM PATH
    model, checkpoint, model_name = net(
        params, num_classes)  # pretrained could become a yaml parameter.
    tqdm.write(
        f'Instantiated {model_name} model with {num_classes} output channels.\n'
    )
    bucket_name = params['global']['bucket_name']
    data_path = params['global']['data_path']

    # Output directory is <data_path>/model/<config file stem>. If it already exists,
    # a timestamp is appended so that a previous run is never overwritten.
    modelname = config_path.stem
    output_path = Path(data_path).joinpath('model') / modelname
    if output_path.is_dir():
        output_path = Path(str(output_path) + '_' + now)
    output_path.mkdir(parents=True, exist_ok=False)
    shutil.copy(str(config_path), str(output_path))
    tqdm.write(f'Model and log files will be saved to: {output_path}\n\n')

    if bucket_name:
        bucket, bucket_output_path, output_path, data_path = download_s3_files(
            bucket_name=bucket_name,
            data_path=data_path,
            output_path=output_path,
            num_classes=num_classes)
    else:
        get_local_classes(num_classes, data_path, output_path)

    since = time.time()
    # Start from +inf so that the first finite validation loss always produces a checkpoint,
    # whatever the scale of the loss (a fixed value such as 999 could prevent any save).
    best_loss = float('inf')

    progress_log = Path(output_path) / 'progress.log'
    if not progress_log.exists():
        # Context manager guarantees the handle is flushed and closed once the header is written.
        with progress_log.open('w', buffering=1) as log_file:
            log_file.write(
                tsv_line('ep_idx', 'phase', 'iter', 'i_p_ep',
                         'time'))  # Add header

    trn_log = InformationLogger('trn')
    val_log = InformationLogger('val')
    tst_log = InformationLogger('tst')

    num_devices = params['global']['num_gpus']
    assert num_devices is not None and num_devices >= 0, "missing mandatory num gpus parameter"
    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    lst_device_ids = get_device_ids(
        num_devices) if torch.cuda.is_available() else []
    num_devices = len(lst_device_ids) if lst_device_ids else 0
    device = torch.device(f'cuda:{lst_device_ids[0]}' if torch.cuda.
                          is_available() and lst_device_ids else 'cpu')
    print(
        f"Number of cuda devices requested: {params['global']['num_gpus']}. Cuda devices available: {lst_device_ids}\n"
    )
    if num_devices == 1:
        print(f"Using Cuda device {lst_device_ids[0]}\n")
    elif num_devices > 1:
        # [1:-1] drops the first and last characters of the string representation of
        # lst_device_ids (the brackets, assuming it is a list).
        print(
            f"Using data parallel on devices: {str(lst_device_ids)[1:-1]}. Main device: {lst_device_ids[0]}\n"
        )
        try:  # TODO: For HPC when device 0 not available. Error: Invalid device id (in torch/cuda/__init__.py).
            model = nn.DataParallel(
                model, device_ids=lst_device_ids
            )  # DataParallel adds prefix 'module.' to state_dict keys
        except AssertionError:
            warnings.warn(
                f"Unable to use devices {lst_device_ids}. Trying devices {list(range(len(lst_device_ids)))}"
            )
            device = torch.device('cuda:0')
            lst_device_ids = range(len(lst_device_ids))
            model = nn.DataParallel(
                model, device_ids=lst_device_ids
            )  # DataParallel adds prefix 'module.' to state_dict keys

    else:
        warnings.warn(
            "No Cuda device available. This process will only run on CPU\n")

    tqdm.write(f'Creating dataloaders from data in {Path(data_path)}...\n')
    trn_dataloader, val_dataloader, tst_dataloader = create_classif_dataloader(
        data_path=data_path,
        batch_size=batch_size,
        num_devices=num_devices,
    )

    tqdm.write(
        'Setting model, criterion, optimizer and learning rate scheduler...\n'
    )
    model, criterion, optimizer, lr_scheduler = set_hyperparameters(
        params, num_classes, model, checkpoint)

    try:  # For HPC when device 0 not available. Error: Cuda invalid device ordinal.
        model.to(device)
    except RuntimeError:
        warnings.warn("Unable to use device. Trying device 0...\n")
        device = torch.device('cuda:0' if torch.cuda.is_available()
                              and lst_device_ids else 'cpu')
        model.to(device)
    # The criterion is moved only once the final device is known, so that it always ends up
    # on the same device as the model, even when the fallback above was taken.
    criterion = criterion.to(device)

    filename = os.path.join(output_path, 'checkpoint.pth.tar')

    num_epochs = params['training']['num_epochs']
    for epoch in range(0, num_epochs):
        print(f'\nEpoch {epoch}/{num_epochs - 1}\n{"-" * 20}')

        trn_report = train(train_loader=trn_dataloader,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           scheduler=lr_scheduler,
                           num_classes=num_classes,
                           batch_size=batch_size,
                           ep_idx=epoch,
                           progress_log=progress_log,
                           device=device,
                           debug=debug)
        trn_log.add_values(trn_report,
                           epoch,
                           ignore=['precision', 'recall', 'fscore', 'iou'])

        val_report = evaluation(
            eval_loader=val_dataloader,
            model=model,
            criterion=criterion,
            num_classes=num_classes,
            batch_size=batch_size,
            ep_idx=epoch,
            progress_log=progress_log,
            batch_metrics=params['training']['batch_metrics'],
            dataset='val',
            device=device,
            debug=debug)
        val_loss = val_report['loss'].avg
        if params['training']['batch_metrics'] is not None:
            val_log.add_values(val_report, epoch, ignore=['iou'])
        else:
            val_log.add_values(val_report,
                               epoch,
                               ignore=['precision', 'recall', 'fscore', 'iou'])

        if val_loss < best_loss:
            tqdm.write("save checkpoint\n")
            best_loss = val_loss
            # The wrapped module's state dict is saved when DataParallel is used (no 'module.' prefix).
            # More info: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-torch-nn-dataparallel-models
            state_dict = model.module.state_dict(
            ) if num_devices > 1 else model.state_dict()
            torch.save(
                {
                    'epoch': epoch,
                    'arch': model_name,
                    'model': state_dict,
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict()
                }, filename)

            if bucket_name:
                bucket_filename = os.path.join(bucket_output_path,
                                               'checkpoint.pth.tar')
                bucket.upload_file(filename, bucket_filename)

        if bucket_name:
            save_logs_to_bucket(bucket, bucket_output_path, output_path, now,
                                params['training']['batch_metrics'])

        cur_elapsed = time.time() - since
        print(
            f'Current elapsed time {cur_elapsed // 60:.0f}m {cur_elapsed % 60:.0f}s'
        )

    # load checkpoint model and evaluate it on test dataset.
    if int(
            params['training']['num_epochs']
    ) > 0:  # if num_epochs is set to 0, model is loaded to evaluate on test set
        checkpoint = load_checkpoint(filename)
        model, _ = load_from_checkpoint(checkpoint, model)

    if tst_dataloader:
        tst_report = evaluation(
            eval_loader=tst_dataloader,
            model=model,
            criterion=criterion,
            num_classes=num_classes,
            batch_size=batch_size,
            ep_idx=params['training']['num_epochs'],
            progress_log=progress_log,
            batch_metrics=params['training']['batch_metrics'],
            dataset='tst',
            device=device)
        tst_log.add_values(tst_report,
                           params['training']['num_epochs'],
                           ignore=['iou'])

        if bucket_name:
            bucket_filename = os.path.join(bucket_output_path,
                                           'last_epoch.pth.tar')
            bucket.upload_file(
                "output.txt",
                os.path.join(bucket_output_path, f"Logs/{now}_output.txt"))
            bucket.upload_file(filename, bucket_filename)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
Exemple #27
0
def evaluation(eval_loader,
               model,
               criterion,
               num_classes,
               batch_size,
               task,
               ep_idx,
               progress_log,
               vis_params,
               batch_metrics=None,
               dataset='val',
               device=None,
               debug=False):
    """
    Evaluate the model and return the updated metrics
    :param eval_loader: data loader
    :param model: model to evaluate
    :param criterion: loss criterion
    :param num_classes: number of classes
    :param batch_size: number of samples to process simultaneously
    :param task: segmentation or classification (currently not used by this function)
    :param ep_idx: epoch index (for hypertrainer log)
    :param progress_log: (Path) progress log file (for hypertrainer log). Visualizations are saved in a
                         'visualization' directory next to it.
    :param vis_params: (dict) Parameters found in the yaml config file. The 'visualization' and 'training'
                       sections are read from it.
    :param batch_metrics: (int) Metrics computed every (int) batches. If left blank, will not perform metrics.
    :param dataset: (str) 'val or 'tst'
    :param device: device used by pytorch (cpu or cuda)
    :param debug: (bool) if True and the device is a cuda device, gpu statistics are shown in the progress bar
    :return: (dict) eval_metrics
    """
    eval_metrics = create_metrics_dict(num_classes)
    model.eval()
    vis_at_eval = get_key_def('vis_at_evaluation', vis_params['visualization'],
                              False)
    vis_batch_range = get_key_def('vis_batch_range',
                                  vis_params['visualization'], None)
    # 'vis_batch_range' is optional: only unpack it when provided, otherwise no batch is visualized.
    if vis_batch_range is not None:
        min_vis_batch, max_vis_batch, increment = vis_batch_range
        vis_batches = range(min_vis_batch, max_vis_batch, increment)
    else:
        min_vis_batch = None
        vis_batches = range(0)
    # Same value for every batch: read it once from the config received as 'vis_params'.
    ignore_index = get_key_def("ignore_index", vis_params["training"], None)

    with tqdm(eval_loader,
              dynamic_ncols=True,
              desc=f'Iterating {dataset} batches with {device.type}') as _tqdm:
        for batch_index, data in enumerate(_tqdm):
            # Context manager closes the log file handle after each line is appended.
            with progress_log.open('a', buffering=1) as log_file:
                log_file.write(
                    tsv_line(ep_idx, dataset, batch_index, len(eval_loader),
                             time.time()))

            with torch.no_grad():
                inputs = data['sat_img'].to(device)
                labels = data['map_img'].to(device)
                labels_flatten = flatten_labels(labels)

                outputs = model(inputs)
                # Some models return an OrderedDict; the prediction is then under the 'out' key.
                if isinstance(outputs, OrderedDict):
                    outputs = outputs['out']

                if vis_at_eval and batch_index in vis_batches:
                    vis_path = progress_log.parent.joinpath('visualization')
                    if ep_idx == 0 and batch_index == min_vis_batch:
                        tqdm.write(
                            f'Visualizing on {dataset} outputs for batches in range {vis_batch_range}. All '
                            f'images will be saved to {vis_path}\n')
                    vis_from_batch(vis_params,
                                   inputs,
                                   outputs,
                                   batch_index=batch_index,
                                   vis_path=vis_path,
                                   labels=labels,
                                   dataset=dataset,
                                   ep_num=ep_idx + 1)

                outputs_flatten = flatten_outputs(outputs, num_classes)

                loss = criterion(outputs, labels)

                eval_metrics['loss'].update(loss.item(), batch_size)

                # Metrics are computed on every batch of the 'tst' dataset, and every
                # 'batch_metrics' batches of the 'val' dataset (if 'batch_metrics' is set).
                compute_metrics = dataset == 'tst'
                if (dataset == 'val') and (batch_metrics is not None):
                    # Compute metrics every n batches. Time consuming.
                    assert batch_metrics <= len(_tqdm), \
                        f"Batch_metrics ({batch_metrics}) is larger than the number of batches " \
                        f"({len(_tqdm)}). Metrics in validation loop won't be computed"
                    # +1 to skip val loop at very beginning
                    compute_metrics = (batch_index + 1) % batch_metrics == 0

                if compute_metrics:
                    # torch.max over dim 1 returns (values, indices); only the indices (predicted class) are used.
                    _, segmentation = torch.max(outputs_flatten, dim=1)
                    eval_metrics = report_classification(
                        segmentation,
                        labels_flatten,
                        batch_size,
                        eval_metrics,
                        ignore_index=ignore_index)

                _tqdm.set_postfix(
                    OrderedDict(dataset=dataset,
                                loss=f'{eval_metrics["loss"].avg:.4f}'))

                if debug and device.type == 'cuda':
                    res, mem = gpu_stats(device=device.index)
                    _tqdm.set_postfix(
                        OrderedDict(
                            device=device,
                            gpu_perc=f'{res.gpu} %',
                            gpu_RAM=
                            f'{mem.used/(1024**2):.0f}/{mem.total/(1024**2):.0f} MiB'
                        ))

    print(f"{dataset} Loss: {eval_metrics['loss'].avg}")
    if batch_metrics is not None:
        print(f"{dataset} precision: {eval_metrics['precision'].avg}")
        print(f"{dataset} recall: {eval_metrics['recall'].avg}")
        print(f"{dataset} fscore: {eval_metrics['fscore'].avg}")

    return eval_metrics
Exemple #28
0
def net(net_params, num_channels, inference=False):
    """
    Define the neural net. When training, also place it on the device(s) and set its hyperparameters.
    :param net_params: (dict) Parameters found in the yaml config file. The 'global' and 'training'
                       sections are read, as well as the 'inference' section if inference is True.
    :param num_channels: (int) number of output channels (classes) of the model
    :param inference: (bool) if True, the checkpoint at net_params['inference']['state_dict_path'] is
                      loaded and no device or hyperparameter is set
    :return: if inference is True: model, checkpoint, model_name
             otherwise: model, model_name, criterion, optimizer, lr_scheduler
    :raises ValueError: if the model name is not defined, or if 'concatenate_depth' is missing from the
                        'global' section while it is required (deeplabv3_resnet101 with 4 bands)
    :raises AssertionError: if the number of bands is incompatible with the model, if a checkpoint
                            path is not a file or if the 'num_gpus' parameter is missing or negative
    """
    model_name = net_params['global']['model_name'].lower()
    num_bands = int(net_params['global']['number_of_bands'])
    msg = 'Number of bands specified incompatible with this model. Requires 3 band data.'
    msg_3_or_4 = 'Number of bands specified incompatible with this model. Requires 3 or 4 band data.'
    train_state_dict_path = get_key_def('state_dict_path', net_params['training'], None)
    pretrained = get_key_def('pretrained', net_params['training'], True) if not inference else False
    dropout = get_key_def('dropout', net_params['training'], False)
    dropout_prob = get_key_def('dropout_prob', net_params['training'], 0.5)
    dontcare_val = get_key_def("ignore_index", net_params["training"], -1)
    num_devices = net_params['global']['num_gpus']

    if dontcare_val == 0:
        warnings.warn("The 'dontcare' value (or 'ignore_index') used in the loss function cannot be zero;"
                      " all valid class indices should be consecutive, and start at 0. The 'dontcare' value"
                      " will be remapped to -1 while loading the dataset, and inside the config from now on.")
        net_params["training"]["ignore_index"] = -1
        # Keep the local value in sync with the config: it is handed to set_hyperparameters() below.
        dontcare_val = -1

    # TODO: find a way to maybe implement it in classification one day
    # Read the concatenation point. Always bound (None if absent from the config); it is
    # only required by the deeplabv3_resnet101 model with 4 bands.
    has_conc_point = 'concatenate_depth' in net_params['global']
    conc_point = net_params['global']['concatenate_depth'] if has_conc_point else None

    if model_name == 'unetsmall':
        model = unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'unet':
        model = unet.UNet(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'ternausnet':
        assert num_bands == 3, msg
        model = TernausNet.ternausnet(num_channels)
    elif model_name == 'checkpointed_unet':
        model = checkpointed_unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'inception':
        model = inception.Inception3(num_channels, num_bands)
    elif model_name == 'fcn_resnet101':
        assert num_bands == 3, msg
        model = models.segmentation.fcn_resnet101(pretrained=False, progress=True, num_classes=num_channels,
                                                  aux_loss=None)
    elif model_name == 'deeplabv3_resnet101':
        assert (num_bands == 3 or num_bands == 4), msg_3_or_4
        if num_bands == 3:
            print('Finetuning pretrained deeplabv3 with 3 bands')
            model = models.segmentation.deeplabv3_resnet101(pretrained=pretrained, progress=True)
            # Replace the last layer of the classifier by a 1x1 convolution with num_channels outputs.
            classifier = list(model.classifier.children())
            model.classifier = nn.Sequential(*classifier[:-1])
            model.classifier.add_module('4', nn.Conv2d(classifier[-1].in_channels, num_channels, kernel_size=(1, 1)))
        elif num_bands == 4:
            if not has_conc_point:
                raise ValueError("The 'concatenate_depth' parameter is missing from the 'global' section of the "
                                 "config; it is required by deeplabv3_resnet101 with 4 bands.")
            print('Finetuning pretrained deeplabv3 with 4 bands')
            print('Testing with 4 bands, concatenating at {}.'.format(conc_point))

            model = models.segmentation.deeplabv3_resnet101(pretrained=pretrained, progress=True)

            if conc_point=='baseline':
                # Add a 4th input channel to the first convolution of the backbone.
                conv1 = model.backbone._modules['conv1'].weight.detach().numpy()
                depth = np.expand_dims(conv1[:, 1, ...], axis=1)  # reuse green weights for infrared.
                conv1 = np.append(conv1, depth, axis=1)
                conv1 = torch.from_numpy(conv1).float()
                model.backbone._modules['conv1'].weight = nn.Parameter(conv1, requires_grad=True)
                classifier = list(model.classifier.children())
                model.classifier = nn.Sequential(*classifier[:-1])
                model.classifier.add_module(
                    '4', nn.Conv2d(classifier[-1].in_channels, num_channels, kernel_size=(1, 1))
                )
            else:
                classifier = list(model.classifier.children())
                model.classifier = nn.Sequential(*classifier[:-1])
                model.classifier.add_module(
                        '4', nn.Conv2d(classifier[-1].in_channels, num_channels, kernel_size=(1, 1))
                )
                # The first convolution is left untouched here; the model is wrapped with
                # LayersEnsemble, which receives the concatenation point.
                model = LayersEnsemble(model, conc_point=conc_point)

    elif model_name in lm_smp.keys():
        # NOTE: 'lsmp' is the entry stored in lm_smp, so the assignments below modify
        # that shared entry's 'params' in place.
        lsmp = lm_smp[model_name]
        # TODO: add possibility of our own weights
        lsmp['params']['encoder_weights'] = "imagenet" if 'pretrained' in model_name.split("_") else None
        lsmp['params']['in_channels'] = num_bands
        lsmp['params']['classes'] = num_channels
        lsmp['params']['activation'] = None

        model = lsmp['fct'](**lsmp['params'])

    else:
        raise ValueError(f'The model name {model_name} in the config.yaml is not defined.')

    coordconv_convert = get_key_def('coordconv_convert', net_params['global'], False)
    if coordconv_convert:
        centered = get_key_def('coordconv_centered', net_params['global'], True)
        normalized = get_key_def('coordconv_normalized', net_params['global'], True)
        noise = get_key_def('coordconv_noise', net_params['global'], None)
        radius_channel = get_key_def('coordconv_radius_channel', net_params['global'], False)
        scale = get_key_def('coordconv_scale', net_params['global'], 1.0)
        # note: this operation will not attempt to preserve already-loaded model parameters!
        model = coordconv.swap_coordconv_layers(model, centered=centered, normalized=normalized, noise=noise,
                                                radius_channel=radius_channel, scale=scale)

    if inference:
        state_dict_path = net_params['inference']['state_dict_path']
        assert Path(net_params['inference']['state_dict_path']).is_file(), f"Could not locate {net_params['inference']['state_dict_path']}"
        checkpoint = load_checkpoint(state_dict_path)

        return model, checkpoint, model_name

    else:

        if train_state_dict_path is not None:
            assert Path(train_state_dict_path).is_file(), f'Could not locate checkpoint at {train_state_dict_path}'
            checkpoint = load_checkpoint(train_state_dict_path)
        else:
            checkpoint = None
        assert num_devices is not None and num_devices >= 0, "missing mandatory num gpus parameter"
        # list of GPU devices that are available and unused. If no GPUs, returns empty list
        lst_device_ids = get_device_ids(num_devices) if torch.cuda.is_available() else []
        num_devices = len(lst_device_ids) if lst_device_ids else 0
        device = torch.device(f'cuda:{lst_device_ids[0]}' if torch.cuda.is_available() and lst_device_ids else 'cpu')
        print(f"Number of cuda devices requested: {net_params['global']['num_gpus']}. Cuda devices available: {lst_device_ids}\n")
        if num_devices == 1:
            print(f"Using Cuda device {lst_device_ids[0]}\n")
        elif num_devices > 1:
            # [1:-1] drops the first and last characters of the string representation of
            # lst_device_ids (the brackets, assuming it is a list).
            print(f"Using data parallel on devices: {str(lst_device_ids)[1:-1]}. Main device: {lst_device_ids[0]}\n")
            try:  # For HPC when device 0 not available. Error: Invalid device id (in torch/cuda/__init__.py).
                model = nn.DataParallel(model,
                                        device_ids=lst_device_ids)  # DataParallel adds prefix 'module.' to state_dict keys
            except AssertionError:
                warnings.warn(f"Unable to use devices {lst_device_ids}. Trying devices {list(range(len(lst_device_ids)))}")
                device = torch.device('cuda:0')
                lst_device_ids = range(len(lst_device_ids))
                model = nn.DataParallel(model,
                                        device_ids=lst_device_ids)  # DataParallel adds prefix 'module.' to state_dict keys
        else:
            warnings.warn("No Cuda device available. This process will only run on CPU\n")
        tqdm.write('Setting model, criterion, optimizer and learning rate scheduler...\n')
        try:  # For HPC when device 0 not available. Error: Cuda invalid device ordinal.
            model.to(device)
        except RuntimeError:
            warnings.warn("Unable to use device. Trying device 0...\n")
            device = torch.device('cuda:0' if torch.cuda.is_available() and lst_device_ids else 'cpu')
            model.to(device)

        model, criterion, optimizer, lr_scheduler = set_hyperparameters(net_params, num_channels, model, checkpoint, dontcare_val)
        criterion = criterion.to(device)

        return model, model_name, criterion, optimizer, lr_scheduler
Exemple #29
0
def net(net_params, num_channels, inference=False):
    """
    Define the neural net and load the checkpoint, if any.
    :param net_params: (dict) Parameters found in the yaml config file. The 'global' and 'training'
                       sections are read, as well as the 'inference' section if inference is True.
    :param num_channels: (int) number of output channels (classes) of the model
    :param inference: (bool) if True, the checkpoint at net_params['inference']['state_dict_path'] is
                      loaded and pretrained weights are not requested for deeplabv3
    :return: model, checkpoint (None if no checkpoint path was provided), model_name
    :raises ValueError: if the model name is not defined
    :raises AssertionError: if the number of bands is incompatible with the model or if a
                            checkpoint path is not a file
    """
    model_name = net_params['global']['model_name'].lower()
    num_bands = int(net_params['global']['number_of_bands'])
    msg = 'Number of bands specified incompatible with this model. Requires 3 band data.'
    msg_3_or_4 = 'Number of bands specified incompatible with this model. Requires 3 or 4 band data.'
    train_state_dict_path = get_key_def('state_dict_path',
                                        net_params['training'], None)
    pretrained = get_key_def('pretrained', net_params['training'],
                             True) if not inference else False
    dropout = get_key_def('dropout', net_params['training'], False)
    dropout_prob = get_key_def('dropout_prob', net_params['training'], 0.5)

    if model_name == 'unetsmall':
        model = unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'unet':
        model = unet.UNet(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'ternausnet':
        assert num_bands == 3, msg
        model = TernausNet.ternausnet(num_channels)
    elif model_name == 'checkpointed_unet':
        model = checkpointed_unet.UNetSmall(num_channels, num_bands, dropout,
                                            dropout_prob)
    elif model_name == 'inception':
        model = inception.Inception3(num_channels, num_bands)
    elif model_name == 'fcn_resnet101':
        assert num_bands == 3, msg
        model = models.segmentation.fcn_resnet101(pretrained=False,
                                                  progress=True,
                                                  num_classes=num_channels,
                                                  aux_loss=None)
    elif model_name == 'deeplabv3_resnet101':
        assert (num_bands == 3 or num_bands == 4), msg_3_or_4
        # The 'pretrained' parameter (False for inference) is honored instead of a hard-coded True.
        # aux_loss=True keeps the auxiliary classifier that torchvision adds when pretrained
        # weights are requested, so the architecture (and state_dict keys) does not depend on
        # the value of 'pretrained'.
        if num_bands == 3:
            print('Finetuning pretrained deeplabv3 with 3 bands')
            model = models.segmentation.deeplabv3_resnet101(pretrained=pretrained,
                                                            progress=True,
                                                            aux_loss=True)
            model.classifier = common.DeepLabHead(2048, num_channels)
        elif num_bands == 4:
            print('Finetuning pretrained deeplabv3 with 4 bands')
            model = models.segmentation.deeplabv3_resnet101(pretrained=pretrained,
                                                            progress=True,
                                                            aux_loss=True)
            # Add a 4th input channel, with random weights, to the first convolution of the backbone.
            conv1 = model.backbone._modules['conv1'].weight.detach().numpy()
            depth = np.random.uniform(low=-1, high=1, size=(64, 1, 7, 7))
            conv1 = np.append(conv1, depth, axis=1)
            conv1 = torch.from_numpy(conv1).float()
            model.backbone._modules['conv1'].weight = nn.Parameter(
                conv1, requires_grad=True)
            model.classifier = common.DeepLabHead(2048, num_channels)
    else:
        raise ValueError(
            f'The model name {model_name} in the config.yaml is not defined.')

    coordconv_convert = get_key_def('coordconv_convert', net_params['global'],
                                    False)
    if coordconv_convert:
        centered = get_key_def('coordconv_centered', net_params['global'],
                               True)
        normalized = get_key_def('coordconv_normalized', net_params['global'],
                                 True)
        noise = get_key_def('coordconv_noise', net_params['global'], None)
        radius_channel = get_key_def('coordconv_radius_channel',
                                     net_params['global'], False)
        scale = get_key_def('coordconv_scale', net_params['global'], 1.0)
        # note: this operation will not attempt to preserve already-loaded model parameters!
        model = coordconv.swap_coordconv_layers(model,
                                                centered=centered,
                                                normalized=normalized,
                                                noise=noise,
                                                radius_channel=radius_channel,
                                                scale=scale)

    # Checkpoint: from the 'inference' section when inferring, otherwise from the
    # optional 'state_dict_path' of the 'training' section.
    if inference:
        state_dict_path = net_params['inference']['state_dict_path']
        assert Path(net_params['inference']['state_dict_path']).is_file(
        ), f"Could not locate {net_params['inference']['state_dict_path']}"
        checkpoint = load_checkpoint(state_dict_path)
    elif train_state_dict_path is not None:
        assert Path(train_state_dict_path).is_file(
        ), f'Could not locate checkpoint at {train_state_dict_path}'
        checkpoint = load_checkpoint(train_state_dict_path)
    else:
        checkpoint = None

    return model, checkpoint, model_name
Exemple #30
0
def compose_transforms(params, dataset, type='', ignore_index=None):
    """
    Function to compose the transformations to be applied on every batches.

    Only the transforms matching the requested ``type`` are added. When no
    branch matches (e.g. the default ``type=''``), the returned Compose is empty.
    :param params: (dict) Parameters found in the yaml config file
    :param dataset: (str) One of 'trn', 'val', 'tst'
    :param type: (str) One of 'geometric', 'radiometric', 'totensor'.
                 'geometric' and 'radiometric' only add transforms when dataset is 'trn'.
    :param ignore_index: value forwarded to RandomRotationTarget and RandomCrop
                         (presumably the label value to ignore - confirm against those classes)
    :return: (obj) PyTorch's compose object of the transformations to be applied.
    :raises NotImplementedError: if a 'noise' augmentation is configured while composing
                                 radiometric transforms for the 'trn' dataset
    """
    lst_trans = []
    scale = get_key_def('scale_data', params['global'], None)
    norm_mean = get_key_def('mean', params['training']['normalization'])
    norm_std = get_key_def('std', params['training']['normalization'])
    augmentation = params['training']['augmentation']
    random_radiom_trim_range = get_key_def('random_radiom_trim_range',
                                           augmentation, None)

    if dataset == 'trn':

        if type == 'radiometric':
            noise = get_key_def('noise', augmentation, None)

            if random_radiom_trim_range:  # Contrast stretching
                lst_trans.append(
                    RadiometricTrim(random_range=random_radiom_trim_range)
                )  # FIXME: test this. Assure compatibility with CRIM devs (don't trim metadata)

            if noise:
                raise NotImplementedError

        elif type == 'geometric':
            geom_scale_range = get_key_def('geom_scale_range', augmentation,
                                           None)
            hflip = get_key_def('hflip_prob', augmentation, None)
            rotate_prob = get_key_def('rotate_prob', augmentation, None)
            rotate_limit = get_key_def('rotate_limit', augmentation, None)
            crop_size = get_key_def('target_size', params['training'], None)

            if geom_scale_range:  # TODO: test this.
                lst_trans.append(GeometricScale(range=geom_scale_range))

            if hflip:
                # Reuse the value already fetched instead of indexing the config again
                lst_trans.append(HorizontalFlip(prob=hflip))

            # Rotation needs both a limit and a probability to be configured
            if rotate_limit and rotate_prob:
                lst_trans.append(
                    RandomRotationTarget(limit=rotate_limit,
                                         prob=rotate_prob,
                                         ignore_index=ignore_index))

            if crop_size:
                lst_trans.append(
                    RandomCrop(sample_size=crop_size,
                               ignore_index=ignore_index))

    if type == 'totensor':
        # Contrast stretching at eval. Use mean of provided range so that the
        # trim is deterministic outside of training.
        if dataset != 'trn' and random_radiom_trim_range:
            RadiometricTrim.input_checker(
                random_radiom_trim_range
            )  # Assert range is number or 2 element sequence
            if isinstance(random_radiom_trim_range, numbers.Number):
                trim_at_eval = random_radiom_trim_range
            else:
                # Midpoint of the range: (low + high) / 2. The half-width
                # (high - low) / 2 only matches the mean when low == 0.
                low = random_radiom_trim_range[0]
                high = random_radiom_trim_range[-1]
                trim_at_eval = round((low + high) / 2, 1)
            lst_trans.append(
                RadiometricTrim(random_range=[trim_at_eval, trim_at_eval]))

        if scale:
            lst_trans.append(Scale(
                scale))  # TODO: assert coherence with below normalization

        # Normalization is applied only when both statistics are provided
        if norm_mean and norm_std:
            lst_trans.append(Normalize(mean=norm_mean, std=norm_std))

        lst_trans.append(
            ToTensorTarget(get_key_def('BGR_to_RGB', params['global'], False))
        )  # Send channels first, convert numpy array to torch tensor

    return transforms.Compose(lst_trans)