Esempio n. 1
0
def main(params, config_path):
    """
    Train and validate a model for semantic segmentation.

    Steps:
    1. Instantiate the model with `net()` (which also returns a checkpoint, if any).
    2. Create the output folder `<data_path>/<samples folder>/model/<config stem>` (suffixed with
       a timestamp if it already exists) and copy the yaml config into it.
    3. Select CUDA devices from `num_gpus`; wrap the model in `nn.DataParallel` when more than
       one device is available.
    4. Build dataloaders with `create_dataloader()` and criterion/optimizer/scheduler with
       `set_hyperparameters()`.
    5. Train for `num_epochs`, evaluating on "val" every epoch and saving a checkpoint whenever
       the validation loss improves (optionally uploading it to an S3 bucket).
    6. If at least one epoch ran, reload the saved checkpoint; evaluate on "tst" when a test
       dataloader exists.

    :param params: (dict) Parameters found in the yaml config file.
    :param config_path: (pathlib.Path) Path to the yaml config file; its stem names the model folder.
    :raises ValueError: if `task` is not 'segmentation', `num_gpus` is missing or negative, or
        `vis_batch_range` is not a list of 3 integers.
    :raises FileNotFoundError: if `data_path` is not an existing directory.
    """
    debug = get_key_def('debug_mode', params['global'], False)
    if debug:
        warnings.warn(
            'Debug mode activated. Some debug features may mobilize extra disk space and cause delays in execution.'
        )

    now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    num_classes = params['global']['num_classes']
    task = params['global']['task']
    # Explicit raise instead of `assert`: assertions are stripped when Python runs with -O.
    if task != 'segmentation':
        raise ValueError(f"The task should be segmentation. The provided value is {task}")
    num_classes_corrected = num_classes + 1  # + 1 for background # FIXME temporary patch for num_classes problem.

    # INSTANTIATE MODEL AND LOAD CHECKPOINT FROM PATH
    model, checkpoint, model_name = net(params, num_classes_corrected)  # pretrained could become a yaml parameter.
    tqdm.write(f'Instantiated {model_name} model with {num_classes_corrected} output channels.\n')
    bucket_name = params['global']['bucket_name']
    data_path = params['global']['data_path']
    if not Path(data_path).is_dir():
        raise FileNotFoundError(f'Could not locate data path {data_path}')

    samples_size = params["global"]["samples_size"]
    overlap = params["sample"]["overlap"]
    min_annot_perc = params['sample']['sampling']['map']
    num_bands = params['global']['number_of_bands']
    samples_folder_name = f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands'  # FIXME: preferred name structure? document!
    # task was validated above, so samples always live in the dedicated segmentation sub-folder.
    samples_folder = Path(data_path).joinpath(samples_folder_name)

    modelname = config_path.stem
    output_path = samples_folder.joinpath('model') / modelname
    if output_path.is_dir():
        # Do not overwrite a previous run: append the current timestamp to the folder name.
        output_path = output_path.with_name(f'{output_path.name}_{now}')
    output_path.mkdir(parents=True, exist_ok=False)
    shutil.copy(str(config_path), str(output_path))
    tqdm.write(f'Model and log files will be saved to: {output_path}\n\n')
    batch_size = params['training']['batch_size']

    if bucket_name:
        bucket, bucket_output_path, output_path, data_path = download_s3_files(
            bucket_name=bucket_name,
            data_path=data_path,
            output_path=output_path)

    since = time.time()
    # Start from +inf so the first finite validation loss always produces a checkpoint. A finite
    # sentinel (previously 999) meant no checkpoint was written when losses stayed above it, and
    # the final load_checkpoint(filename) then failed.
    best_loss = float('inf')
    last_vis_epoch = 0

    progress_log = Path(output_path) / 'progress.log'
    if not progress_log.exists():
        # Context manager closes the handle right after writing the header (it used to leak).
        with progress_log.open('w', buffering=1) as log_file:
            log_file.write(tsv_line('ep_idx', 'phase', 'iter', 'i_p_ep', 'time'))  # Add header

    trn_log = InformationLogger(output_path, 'trn')
    val_log = InformationLogger(output_path, 'val')
    tst_log = InformationLogger(output_path, 'tst')

    num_devices = params['global']['num_gpus']
    if num_devices is None or num_devices < 0:
        raise ValueError("missing mandatory num gpus parameter")
    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    lst_device_ids = get_device_ids(num_devices) if torch.cuda.is_available() else []
    num_devices = len(lst_device_ids) if lst_device_ids else 0
    device = torch.device(f'cuda:{lst_device_ids[0]}' if torch.cuda.is_available() and lst_device_ids else 'cpu')
    print(f"Number of cuda devices requested: {params['global']['num_gpus']}. Cuda devices available: {lst_device_ids}\n")
    if num_devices == 1:
        print(f"Using Cuda device {lst_device_ids[0]}\n")
    elif num_devices > 1:
        # str(list)[1:-1] strips the surrounding brackets, e.g. "[0, 1]" -> "0, 1".
        print(f"Using data parallel on devices: {str(lst_device_ids)[1:-1]}. Main device: {lst_device_ids[0]}\n")
        try:  # FIXME: For HPC when device 0 not available. Error: Invalid device id (in torch/cuda/__init__.py).
            model = nn.DataParallel(model, device_ids=lst_device_ids)  # DataParallel adds prefix 'module.' to state_dict keys
        except AssertionError:
            warnings.warn(f"Unable to use devices {lst_device_ids}. Trying devices {list(range(len(lst_device_ids)))}")
            device = torch.device('cuda:0')
            lst_device_ids = range(len(lst_device_ids))
            model = nn.DataParallel(model, device_ids=lst_device_ids)  # DataParallel adds prefix 'module.' to state_dict keys
    else:
        warnings.warn("No Cuda device available. This process will only run on CPU\n")

    tqdm.write(f'Creating dataloaders from data in {samples_folder}...\n')
    trn_dataloader, val_dataloader, tst_dataloader = create_dataloader(
        data_path=data_path,
        batch_size=batch_size,
        task=task,
        num_devices=num_devices,
        params=params,
        samples_folder=samples_folder)

    tqdm.write('Setting model, criterion, optimizer and learning rate scheduler...\n')
    model, criterion, optimizer, lr_scheduler = set_hyperparameters(
        params, num_classes_corrected, model, checkpoint)

    try:  # For HPC when device 0 not available. Error: Cuda invalid device ordinal.
        model.to(device)
    except RuntimeError:
        warnings.warn("Unable to use device. Trying device 0...\n")
        device = torch.device('cuda:0' if torch.cuda.is_available() and lst_device_ids else 'cpu')
        model.to(device)
    # Move the criterion only after the final device is settled, so it always matches the model
    # (previously it stayed on the original device when the fallback above kicked in).
    criterion = criterion.to(device)

    filename = os.path.join(output_path, 'checkpoint.pth.tar')

    # VISUALIZATION: generate pngs of inputs, labels and outputs
    vis_batch_range = get_key_def('vis_batch_range', params['visualization'], None)
    if vis_batch_range is not None:
        # Make sure user-provided range is a list of 3 integers (start, finish, increment).
        # Check once for all visualization tasks.
        if not (isinstance(vis_batch_range, list) and len(vis_batch_range) == 3
                and all(isinstance(x, int) for x in vis_batch_range)):
            raise ValueError(f'vis_batch_range should be a list of 3 integers (start, finish, increment). '
                             f'Got {vis_batch_range}')
        vis_at_init = get_key_def('vis_at_init', params['visualization'], False)
        vis_at_init_dataset = get_key_def('vis_at_init_dataset', params['visualization'], 'val')
        if vis_at_init:
            tqdm.write(
                f'Visualizing initialized model on batch range {vis_batch_range} from {vis_at_init_dataset} dataset...\n'
            )
            vis_from_dataloader(
                params=params,
                eval_loader=val_dataloader if vis_at_init_dataset == 'val' else tst_dataloader,
                model=model,
                ep_num=0,
                output_path=output_path,
                dataset=vis_at_init_dataset,
                device=device,
                vis_batch_range=vis_batch_range)

    # Checkpoint-visualization settings do not change between epochs: read them once.
    vis_at_checkpoint = get_key_def('vis_at_checkpoint', params['visualization'], False)
    ep_vis_min_thresh = get_key_def('vis_at_ckpt_min_ep_diff', params['visualization'], 4)
    vis_at_ckpt_dataset = get_key_def('vis_at_ckpt_dataset', params['visualization'], 'val')

    num_epochs = params['training']['num_epochs']
    for epoch in range(0, num_epochs):
        print(f'\nEpoch {epoch}/{num_epochs - 1}\n{"-" * 20}')

        trn_report = train(train_loader=trn_dataloader,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           scheduler=lr_scheduler,
                           num_classes=num_classes_corrected,
                           batch_size=batch_size,
                           task=task,
                           ep_idx=epoch,
                           progress_log=progress_log,
                           vis_params=params,
                           device=device,
                           debug=debug)
        trn_log.add_values(trn_report, epoch, ignore=['precision', 'recall', 'fscore', 'iou'])

        val_report = evaluation(
            eval_loader=val_dataloader,
            model=model,
            criterion=criterion,
            num_classes=num_classes_corrected,
            batch_size=batch_size,
            task=task,
            ep_idx=epoch,
            progress_log=progress_log,
            vis_params=params,
            batch_metrics=params['training']['batch_metrics'],
            dataset='val',
            device=device,
            debug=debug)
        val_loss = val_report['loss'].avg
        if params['training']['batch_metrics'] is not None:
            val_log.add_values(val_report, epoch)
        else:
            val_log.add_values(val_report, epoch, ignore=['precision', 'recall', 'fscore', 'iou'])

        if val_loss < best_loss:
            tqdm.write("save checkpoint\n")
            best_loss = val_loss
            # Unwrap DataParallel (num_devices > 1) so saved keys carry no 'module.' prefix.
            # More info: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-torch-nn-dataparallel-models
            state_dict = model.module.state_dict() if num_devices > 1 else model.state_dict()
            torch.save(
                {
                    'epoch': epoch,
                    'arch': model_name,
                    'model': state_dict,
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict()
                }, filename)

            if bucket_name:
                bucket_filename = os.path.join(bucket_output_path, 'checkpoint.pth.tar')
                bucket.upload_file(filename, bucket_filename)

            # VISUALIZATION: generate png of test samples, labels and outputs for visualisation to follow training performance
            if vis_batch_range is not None and vis_at_checkpoint and epoch - last_vis_epoch >= ep_vis_min_thresh:
                if last_vis_epoch == 0:
                    tqdm.write(
                        f'Visualizing with {vis_at_ckpt_dataset} dataset samples on checkpointed model for batches {vis_batch_range}'
                    )
                vis_from_dataloader(
                    params=params,
                    eval_loader=val_dataloader if vis_at_ckpt_dataset == 'val' else tst_dataloader,
                    model=model,
                    ep_num=epoch + 1,
                    output_path=output_path,
                    dataset=vis_at_ckpt_dataset,
                    device=device,
                    vis_batch_range=vis_batch_range)
                last_vis_epoch = epoch

        if bucket_name:
            save_logs_to_bucket(bucket, bucket_output_path, output_path, now,
                                params['training']['batch_metrics'])

        cur_elapsed = time.time() - since
        print(f'Current elapsed time {cur_elapsed // 60:.0f}m {cur_elapsed % 60:.0f}s')

    # load checkpoint model and evaluate it on test dataset.
    # With num_epochs == 0 no checkpoint is written in this run, so the test set is evaluated with
    # the model as returned by net()/set_hyperparameters().
    if int(num_epochs) > 0:
        checkpoint = load_checkpoint(filename)
        model, _ = load_from_checkpoint(checkpoint, model)

    if tst_dataloader:
        tst_report = evaluation(
            eval_loader=tst_dataloader,
            model=model,
            criterion=criterion,
            num_classes=num_classes_corrected,
            batch_size=batch_size,
            task=task,
            ep_idx=num_epochs,
            progress_log=progress_log,
            vis_params=params,
            batch_metrics=params['training']['batch_metrics'],
            dataset='tst',
            device=device)
        tst_log.add_values(tst_report, num_epochs)

        if bucket_name:
            bucket_filename = os.path.join(bucket_output_path, 'last_epoch.pth.tar')
            bucket.upload_file("output.txt", os.path.join(bucket_output_path, f"Logs/{now}_output.txt"))
            bucket.upload_file(filename, bucket_filename)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
Esempio n. 2
0
def main(params, config_path):
    """
    Train and validate a model for image classification.

    Steps:
    1. Instantiate the model with `net()` (which also returns a checkpoint, if any).
    2. Create the output folder `<data_path>/model/<config stem>` (suffixed with a timestamp if it
       already exists) and copy the yaml config into it. Class files are fetched from the S3
       bucket when `bucket_name` is set, otherwise via `get_local_classes()`.
    3. Select CUDA devices from `num_gpus`; wrap the model in `nn.DataParallel` when more than
       one device is available.
    4. Build dataloaders with `create_classif_dataloader()` and criterion/optimizer/scheduler
       with `set_hyperparameters()`.
    5. Train for `num_epochs`, evaluating on "val" every epoch and saving a checkpoint whenever
       the validation loss improves (optionally uploading it to an S3 bucket).
    6. If at least one epoch ran, reload the saved checkpoint; evaluate on "tst" when a test
       dataloader exists.

    :param params: (dict) Parameters found in the yaml config file.
    :param config_path: (pathlib.Path) Path to the yaml config file; its stem names the model folder.
    :raises ValueError: if `task` is not 'classification' or `num_gpus` is missing or negative.
    """
    debug = get_key_def('debug_mode', params['global'], False)
    if debug:
        warnings.warn('Debug mode activated. Some debug functions may cause delays in execution.')

    now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    num_classes = params['global']['num_classes']
    task = params['global']['task']
    batch_size = params['training']['batch_size']
    # Explicit raise instead of `assert`: assertions are stripped when Python runs with -O.
    if task != 'classification':
        raise ValueError(f"The task should be classification. The provided value is {task}")

    # INSTANTIATE MODEL AND LOAD CHECKPOINT FROM PATH
    model, checkpoint, model_name = net(params, num_classes)  # pretrained could become a yaml parameter.
    tqdm.write(f'Instantiated {model_name} model with {num_classes} output channels.\n')
    bucket_name = params['global']['bucket_name']
    data_path = params['global']['data_path']

    modelname = config_path.stem
    output_path = Path(data_path).joinpath('model') / modelname
    if output_path.is_dir():
        # Do not overwrite a previous run: append the current timestamp to the folder name.
        output_path = output_path.with_name(f'{output_path.name}_{now}')
    output_path.mkdir(parents=True, exist_ok=False)
    shutil.copy(str(config_path), str(output_path))
    tqdm.write(f'Model and log files will be saved to: {output_path}\n\n')

    if bucket_name:
        bucket, bucket_output_path, output_path, data_path = download_s3_files(
            bucket_name=bucket_name,
            data_path=data_path,
            output_path=output_path,
            num_classes=num_classes)
    else:
        get_local_classes(num_classes, data_path, output_path)

    since = time.time()
    # Start from +inf so the first finite validation loss always produces a checkpoint. A finite
    # sentinel (previously 999) meant no checkpoint was written when losses stayed above it, and
    # the final load_checkpoint(filename) then failed.
    best_loss = float('inf')

    progress_log = Path(output_path) / 'progress.log'
    if not progress_log.exists():
        # Context manager closes the handle right after writing the header (it used to leak).
        with progress_log.open('w', buffering=1) as log_file:
            log_file.write(tsv_line('ep_idx', 'phase', 'iter', 'i_p_ep', 'time'))  # Add header

    trn_log = InformationLogger('trn')
    val_log = InformationLogger('val')
    tst_log = InformationLogger('tst')

    num_devices = params['global']['num_gpus']
    if num_devices is None or num_devices < 0:
        raise ValueError("missing mandatory num gpus parameter")
    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    lst_device_ids = get_device_ids(num_devices) if torch.cuda.is_available() else []
    num_devices = len(lst_device_ids) if lst_device_ids else 0
    device = torch.device(f'cuda:{lst_device_ids[0]}' if torch.cuda.is_available() and lst_device_ids else 'cpu')
    print(f"Number of cuda devices requested: {params['global']['num_gpus']}. Cuda devices available: {lst_device_ids}\n")
    if num_devices == 1:
        print(f"Using Cuda device {lst_device_ids[0]}\n")
    elif num_devices > 1:
        # str(list)[1:-1] strips the surrounding brackets, e.g. "[0, 1]" -> "0, 1".
        print(f"Using data parallel on devices: {str(lst_device_ids)[1:-1]}. Main device: {lst_device_ids[0]}\n")
        try:  # TODO: For HPC when device 0 not available. Error: Invalid device id (in torch/cuda/__init__.py).
            model = nn.DataParallel(model, device_ids=lst_device_ids)  # DataParallel adds prefix 'module.' to state_dict keys
        except AssertionError:
            warnings.warn(f"Unable to use devices {lst_device_ids}. Trying devices {list(range(len(lst_device_ids)))}")
            device = torch.device('cuda:0')
            lst_device_ids = range(len(lst_device_ids))
            model = nn.DataParallel(model, device_ids=lst_device_ids)  # DataParallel adds prefix 'module.' to state_dict keys
    else:
        warnings.warn("No Cuda device available. This process will only run on CPU\n")

    tqdm.write(f'Creating dataloaders from data in {Path(data_path)}...\n')
    trn_dataloader, val_dataloader, tst_dataloader = create_classif_dataloader(
        data_path=data_path,
        batch_size=batch_size,
        num_devices=num_devices,
    )

    tqdm.write('Setting model, criterion, optimizer and learning rate scheduler...\n')
    model, criterion, optimizer, lr_scheduler = set_hyperparameters(
        params, num_classes, model, checkpoint)

    try:  # For HPC when device 0 not available. Error: Cuda invalid device ordinal.
        model.to(device)
    except RuntimeError:
        warnings.warn("Unable to use device. Trying device 0...\n")
        device = torch.device('cuda:0' if torch.cuda.is_available() and lst_device_ids else 'cpu')
        model.to(device)
    # Move the criterion only after the final device is settled, so it always matches the model
    # (previously it stayed on the original device when the fallback above kicked in).
    criterion = criterion.to(device)

    filename = os.path.join(output_path, 'checkpoint.pth.tar')

    num_epochs = params['training']['num_epochs']
    for epoch in range(0, num_epochs):
        print(f'\nEpoch {epoch}/{num_epochs - 1}\n{"-" * 20}')

        trn_report = train(train_loader=trn_dataloader,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           scheduler=lr_scheduler,
                           num_classes=num_classes,
                           batch_size=batch_size,
                           ep_idx=epoch,
                           progress_log=progress_log,
                           device=device,
                           debug=debug)
        trn_log.add_values(trn_report, epoch, ignore=['precision', 'recall', 'fscore', 'iou'])

        val_report = evaluation(
            eval_loader=val_dataloader,
            model=model,
            criterion=criterion,
            num_classes=num_classes,
            batch_size=batch_size,
            ep_idx=epoch,
            progress_log=progress_log,
            batch_metrics=params['training']['batch_metrics'],
            dataset='val',
            device=device,
            debug=debug)
        val_loss = val_report['loss'].avg
        if params['training']['batch_metrics'] is not None:
            val_log.add_values(val_report, epoch, ignore=['iou'])
        else:
            val_log.add_values(val_report, epoch, ignore=['precision', 'recall', 'fscore', 'iou'])

        if val_loss < best_loss:
            tqdm.write("save checkpoint\n")
            best_loss = val_loss
            # Unwrap DataParallel (num_devices > 1) so saved keys carry no 'module.' prefix.
            # More info: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-torch-nn-dataparallel-models
            state_dict = model.module.state_dict() if num_devices > 1 else model.state_dict()
            torch.save(
                {
                    'epoch': epoch,
                    'arch': model_name,
                    'model': state_dict,
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict()
                }, filename)

            if bucket_name:
                bucket_filename = os.path.join(bucket_output_path, 'checkpoint.pth.tar')
                bucket.upload_file(filename, bucket_filename)

        if bucket_name:
            save_logs_to_bucket(bucket, bucket_output_path, output_path, now,
                                params['training']['batch_metrics'])

        cur_elapsed = time.time() - since
        print(f'Current elapsed time {cur_elapsed // 60:.0f}m {cur_elapsed % 60:.0f}s')

    # load checkpoint model and evaluate it on test dataset.
    # With num_epochs == 0 no checkpoint is written in this run, so the test set is evaluated with
    # the model as returned by net()/set_hyperparameters().
    if int(num_epochs) > 0:
        checkpoint = load_checkpoint(filename)
        model, _ = load_from_checkpoint(checkpoint, model)

    if tst_dataloader:
        tst_report = evaluation(
            eval_loader=tst_dataloader,
            model=model,
            criterion=criterion,
            num_classes=num_classes,
            batch_size=batch_size,
            ep_idx=num_epochs,
            progress_log=progress_log,
            batch_metrics=params['training']['batch_metrics'],
            dataset='tst',
            device=device)
        tst_log.add_values(tst_report, num_epochs, ignore=['iou'])

        if bucket_name:
            bucket_filename = os.path.join(bucket_output_path, 'last_epoch.pth.tar')
            bucket.upload_file("output.txt", os.path.join(bucket_output_path, f"Logs/{now}_output.txt"))
            bucket.upload_file(filename, bucket_filename)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
Esempio n. 3
0
                        help='Path to parameters stored in yaml')
    parser.add_argument('-i',
                        '--input',
                        metavar='model_pth img_dir',
                        nargs=2,
                        help='model_path and image_dir')
    args = parser.parse_args()

    # if a yaml is inputted, get those parameters and get model state_dict to overwrite global parameters afterwards
    if args.param:
        input_params = read_parameters(args.param[0])
        model_ckpt = get_key_def('state_dict_path',
                                 input_params['inference'],
                                 expected_type=str)
        # load checkpoint
        checkpoint = load_checkpoint(model_ckpt)
        if 'params' in checkpoint.keys():
            params = checkpoint['params']
            # overwrite with inputted parameters
            compare_config_yamls(yaml1=params,
                                 yaml2=input_params,
                                 update_yaml1=True)
        else:
            warnings.warn(
                'No parameters found in checkpoint. Defaulting to parameters from inputted yaml.'
                'Use GDL version 1.3 or more.')
            params = input_params
        del checkpoint
        del input_params

    # elif input is a model checkpoint and an image directory, we'll rely on the yaml saved inside the model (pth.tar)
Esempio n. 4
0
def main(params, config_path):
    """
    Function to train and validate a model for semantic segmentation.

    Process
    -------
    1. Model is instantiated and checkpoint is loaded from path, if provided in
       `your_config.yaml`.
    2. GPUs are requested according to desired amount of `num_gpus` and
       available GPUs.
    3. If more than 1 GPU is requested, model is cast to DataParallel model
    4. Dataloaders are created with `create_dataloader()`
    5. Loss criterion, optimizer and learning rate are set with
       `set_hyperparameters()` as requested in `config.yaml`.
    5. Using these hyperparameters, the application will try to minimize the
       loss on the training data and evaluate every epoch on the validation
       data.
    6. For every epoch, the application shows and logs the loss on "trn" and
       "val" datasets.
    7. For every epoch (if `batch_metrics: 1`), the application shows and logs
       the accuracy, recall and f-score on "val" dataset. Those metrics are
       also computed on each class.
    8. At the end of the training process, the application shows and logs the
       accuracy, recall and f-score on "tst" dataset. Those metrics are also
       computed on each class.

    -------
    :param params: (dict) Parameters found in the yaml config file.
    :param config_path: (str) Path to the yaml config file.
    """
    now = datetime.now().strftime("%Y-%m-%d_%H-%M")

    # MANDATORY PARAMETERS
    num_classes = get_key_def('num_classes', params['global'], expected_type=int)
    num_classes_corrected = num_classes + 1  # + 1 for background # FIXME temporary patch for num_classes problem.
    num_bands = get_key_def('number_of_bands', params['global'], expected_type=int)
    batch_size = get_key_def('batch_size', params['training'], expected_type=int)
    eval_batch_size = get_key_def('eval_batch_size', params['training'], expected_type=int, default=batch_size)
    num_epochs = get_key_def('num_epochs', params['training'], expected_type=int)
    model_name = get_key_def('model_name', params['global'], expected_type=str).lower()
    BGR_to_RGB = get_key_def('BGR_to_RGB', params['global'], expected_type=bool)

    # OPTIONAL PARAMETERS
    # basics
    debug = get_key_def('debug_mode', params['global'], default=False, expected_type=bool)
    task = get_key_def('task', params['global'], default='segmentation', expected_type=str)
    if not task == 'segmentation':
        raise ValueError(f"The task should be segmentation. The provided value is {task}")
    dontcare_val = get_key_def("ignore_index", params["training"], default=-1, expected_type=int)
    crop_size = get_key_def('target_size', params['training'], default=None, expected_type=int)
    batch_metrics = get_key_def('batch_metrics', params['training'], default=None, expected_type=int)
    meta_map = get_key_def("meta_map", params["global"], default=None)
    if meta_map and not Path(meta_map).is_file():
        raise FileNotFoundError(f'Couldn\'t locate {meta_map}')
    bucket_name = get_key_def('bucket_name', params['global'])  # AWS
    scale = get_key_def('scale_data', params['global'], default=[0, 1], expected_type=List)

    # model params
    loss_fn = get_key_def('loss_fn', params['training'], default='CrossEntropy', expected_type=str)
    class_weights = get_key_def('class_weights', params['training'], default=None, expected_type=Sequence)
    if class_weights:
        verify_weights(num_classes_corrected, class_weights)
    optimizer = get_key_def('optimizer', params['training'], default='adam', expected_type=str)
    pretrained = get_key_def('pretrained', params['training'], default=True, expected_type=bool)
    train_state_dict_path = get_key_def('state_dict_path', params['training'], default=None, expected_type=str)
    if train_state_dict_path and not Path(train_state_dict_path).is_file():
        raise FileNotFoundError(f'Could not locate pretrained checkpoint for training: {train_state_dict_path}')
    dropout_prob = get_key_def('dropout_prob', params['training'], default=None, expected_type=float)
    # Read the concatenation point
    # TODO: find a way to maybe implement it in classification one day
    conc_point = get_key_def('concatenate_depth', params['global'], None)

    # gpu parameters
    num_devices = get_key_def('num_gpus', params['global'], default=0, expected_type=int)
    if num_devices and not num_devices >= 0:
        raise ValueError("missing mandatory num gpus parameter")

    # mlflow logging
    mlflow_uri = get_key_def('mlflow_uri', params['global'], default="./mlruns")
    Path(mlflow_uri).mkdir(exist_ok=True)
    experiment_name = get_key_def('mlflow_experiment_name', params['global'], default='gdl-training', expected_type=str)
    run_name = get_key_def('mlflow_run_name', params['global'], default='gdl', expected_type=str)

    # parameters to find hdf5 samples
    data_path = Path(get_key_def('data_path', params['global'], './data', expected_type=str))
    samples_size = get_key_def("samples_size", params["global"], default=1024, expected_type=int)
    overlap = get_key_def("overlap", params["sample"], default=5, expected_type=int)
    min_annot_perc = get_key_def('min_annotated_percent', params['sample']['sampling_method'], default=0,
                                 expected_type=int)
    if not data_path.is_dir():
        raise FileNotFoundError(f'Could not locate data path {data_path}')
    samples_folder_name = (f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands'
                           f'_{experiment_name}')
    samples_folder = data_path.joinpath(samples_folder_name)

    # visualization parameters
    vis_at_train = get_key_def('vis_at_train', params['visualization'], default=False)
    vis_at_eval = get_key_def('vis_at_evaluation', params['visualization'], default=False)
    vis_batch_range = get_key_def('vis_batch_range', params['visualization'], default=None)
    vis_at_checkpoint = get_key_def('vis_at_checkpoint', params['visualization'], default=False)
    ep_vis_min_thresh = get_key_def('vis_at_ckpt_min_ep_diff', params['visualization'], default=1, expected_type=int)
    vis_at_ckpt_dataset = get_key_def('vis_at_ckpt_dataset', params['visualization'], 'val')
    colormap_file = get_key_def('colormap_file', params['visualization'], None)
    heatmaps = get_key_def('heatmaps', params['visualization'], False)
    heatmaps_inf = get_key_def('heatmaps', params['inference'], False)
    grid = get_key_def('grid', params['visualization'], False)
    mean = get_key_def('mean', params['training']['normalization'])
    std = get_key_def('std', params['training']['normalization'])
    vis_params = {'colormap_file': colormap_file, 'heatmaps': heatmaps, 'heatmaps_inf': heatmaps_inf, 'grid': grid,
                  'mean': mean, 'std': std, 'vis_batch_range': vis_batch_range, 'vis_at_train': vis_at_train,
                  'vis_at_eval': vis_at_eval, 'ignore_index': dontcare_val, 'inference_input_path': None}

    # coordconv parameters
    coordconv_params = {}
    for param, val in params['global'].items():
        if 'coordconv' in param:
            coordconv_params[param] = val

    # add git hash from current commit to parameters if available. Parameters will be saved to model's .pth.tar
    params['global']['git_hash'] = get_git_hash()

    # automatic model naming with unique id for each training
    model_id = config_path.stem
    output_path = samples_folder.joinpath('model') / model_id
    if output_path.is_dir():
        last_mod_time_suffix = datetime.fromtimestamp(output_path.stat().st_mtime).strftime('%Y%m%d-%H%M%S')
        archive_output_path = samples_folder.joinpath('model') / f"{model_id}_{last_mod_time_suffix}"
        shutil.move(output_path, archive_output_path)
    output_path.mkdir(parents=True, exist_ok=False)
    shutil.copy(str(config_path), str(output_path))  # copy yaml to output path where model will be saved

    import logging.config  # See: https://docs.python.org/2.4/lib/logging-config-fileformat.html
    log_config_path = Path('utils/logging.conf').absolute()
    logfile = f'{output_path}/{model_id}.log'
    logfile_debug = f'{output_path}/{model_id}_debug.log'
    console_level_logging = 'INFO' if not debug else 'DEBUG'
    logging.config.fileConfig(log_config_path, defaults={'logfilename': logfile,
                                                         'logfilename_debug': logfile_debug,
                                                         'console_level': console_level_logging})

    logging.info(f'Model and log files will be saved to: {output_path}\n\n')
    if debug:
        logging.warning(f'Debug mode activated. Some debug features may mobilize extra disk space and '
                        f'cause delays in execution.')
    if dontcare_val < 0 and vis_batch_range:
        logging.warning(f'Visualization: expected positive value for ignore_index, got {dontcare_val}.'
                        f'Will be overridden to 255 during visualization only. Problems may occur.')

    # overwrite dontcare values in label if loss is not lovasz or crossentropy. FIXME: hacky fix.
    dontcare2backgr = False
    if loss_fn not in ['Lovasz', 'CrossEntropy', 'OhemCrossEntropy']:
        dontcare2backgr = True
        logging.warning(f'Dontcare is not implemented for loss function "{loss_fn}". '
                        f'Dontcare values ({dontcare_val}) in label will be replaced with background value (0)')

    # Will check if batch size needs to be a lower value only if cropping samples during training
    calc_eval_bs = True if crop_size else False

    # INSTANTIATE MODEL AND LOAD CHECKPOINT FROM PATH
    model, model_name, criterion, optimizer, lr_scheduler, device, gpu_devices_dict = \
        net(model_name=model_name,
            num_bands=num_bands,
            num_channels=num_classes_corrected,
            dontcare_val=dontcare_val,
            num_devices=num_devices,
            train_state_dict_path=train_state_dict_path,
            pretrained=pretrained,
            dropout_prob=dropout_prob,
            loss_fn=loss_fn,
            class_weights=class_weights,
            optimizer=optimizer,
            net_params=params,
            conc_point=conc_point,
            coordconv_params=coordconv_params)

    logging.info(f'Instantiated {model_name} model with {num_classes_corrected} output channels.\n')

    logging.info(f'Creating dataloaders from data in {samples_folder}...\n')
    trn_dataloader, val_dataloader, tst_dataloader = create_dataloader(samples_folder=samples_folder,
                                                                       batch_size=batch_size,
                                                                       eval_batch_size=eval_batch_size,
                                                                       gpu_devices_dict=gpu_devices_dict,
                                                                       sample_size=samples_size,
                                                                       dontcare_val=dontcare_val,
                                                                       crop_size=crop_size,
                                                                       meta_map=meta_map,
                                                                       num_bands=num_bands,
                                                                       BGR_to_RGB=BGR_to_RGB,
                                                                       scale=scale,
                                                                       params=params,
                                                                       dontcare2backgr=dontcare2backgr,
                                                                       calc_eval_bs=calc_eval_bs,
                                                                       debug=debug)


    # mlflow tracking path + parameters logging
    set_tracking_uri(mlflow_uri)
    set_experiment(experiment_name)
    start_run(run_name=run_name)
    log_params(params['training'])
    log_params(params['global'])
    log_params(params['sample'])

    if bucket_name:
        from utils.aws import download_s3_files
        bucket, bucket_output_path, output_path, data_path = download_s3_files(bucket_name=bucket_name,
                                                                               data_path=data_path,
                                                                               output_path=output_path)

    since = time.time()
    best_loss = 999
    last_vis_epoch = 0

    progress_log = output_path / 'progress.log'
    if not progress_log.exists():
        progress_log.open('w', buffering=1).write(tsv_line('ep_idx', 'phase', 'iter', 'i_p_ep', 'time'))  # Add header

    trn_log = InformationLogger('trn')
    val_log = InformationLogger('val')
    tst_log = InformationLogger('tst')
    filename = output_path.joinpath('checkpoint.pth.tar')

    # VISUALIZATION: generate pngs of inputs, labels and outputs
    if vis_batch_range is not None:
        # Make sure user-provided range is a tuple with 3 integers (start, finish, increment).
        # Check once for all visualization tasks.
        if not isinstance(vis_batch_range, list) and len(vis_batch_range) == 3 and all(isinstance(x, int)
                                                                                       for x in vis_batch_range):
            raise ValueError(f'Vis_batch_range expects three integers in a list: start batch, end batch, increment.'
                             f'Got {vis_batch_range}')
        vis_at_init_dataset = get_key_def('vis_at_init_dataset', params['visualization'], 'val')

        # Visualization at initialization. Visualize batch range before first eopch.
        if get_key_def('vis_at_init', params['visualization'], False):
            logging.info(f'Visualizing initialized model on batch range {vis_batch_range} '
                         f'from {vis_at_init_dataset} dataset...\n')
            vis_from_dataloader(vis_params=vis_params,
                                eval_loader=val_dataloader if vis_at_init_dataset == 'val' else tst_dataloader,
                                model=model,
                                ep_num=0,
                                output_path=output_path,
                                dataset=vis_at_init_dataset,
                                scale=scale,
                                device=device,
                                vis_batch_range=vis_batch_range)

    for epoch in range(0, num_epochs):
        logging.info(f'\nEpoch {epoch}/{num_epochs - 1}\n{"-" * 20}')

        trn_report = train(train_loader=trn_dataloader,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           scheduler=lr_scheduler,
                           num_classes=num_classes_corrected,
                           batch_size=batch_size,
                           ep_idx=epoch,
                           progress_log=progress_log,
                           device=device,
                           scale=scale,
                           vis_params=vis_params,
                           debug=debug)
        trn_log.add_values(trn_report, epoch, ignore=['precision', 'recall', 'fscore', 'iou'])

        val_report = evaluation(eval_loader=val_dataloader,
                                model=model,
                                criterion=criterion,
                                num_classes=num_classes_corrected,
                                batch_size=batch_size,
                                ep_idx=epoch,
                                progress_log=progress_log,
                                batch_metrics=batch_metrics,
                                dataset='val',
                                device=device,
                                scale=scale,
                                vis_params=vis_params,
                                debug=debug)
        val_loss = val_report['loss'].avg
        if batch_metrics is not None:
            val_log.add_values(val_report, epoch)
        else:
            val_log.add_values(val_report, epoch, ignore=['precision', 'recall', 'fscore', 'iou'])

        if val_loss < best_loss:
            logging.info("save checkpoint\n")
            best_loss = val_loss
            # More info: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-torch-nn-dataparallel-models
            state_dict = model.module.state_dict() if num_devices > 1 else model.state_dict()
            torch.save({'epoch': epoch,
                        'params': params,
                        'model': state_dict,
                        'best_loss': best_loss,
                        'optimizer': optimizer.state_dict()}, filename)
            if bucket_name:
                bucket_filename = bucket_output_path.joinpath('checkpoint.pth.tar')
                bucket.upload_file(filename, bucket_filename)

            # VISUALIZATION: generate pngs of img samples, labels and outputs as alternative to follow training
            if vis_batch_range is not None and vis_at_checkpoint and epoch - last_vis_epoch >= ep_vis_min_thresh:
                if last_vis_epoch == 0:
                    logging.info(f'Visualizing with {vis_at_ckpt_dataset} dataset samples on checkpointed model for'
                                 f'batches in range {vis_batch_range}')
                vis_from_dataloader(vis_params=vis_params,
                                    eval_loader=val_dataloader if vis_at_ckpt_dataset == 'val' else tst_dataloader,
                                    model=model,
                                    ep_num=epoch+1,
                                    output_path=output_path,
                                    dataset=vis_at_ckpt_dataset,
                                    scale=scale,
                                    device=device,
                                    vis_batch_range=vis_batch_range)
                last_vis_epoch = epoch

        if bucket_name:
            save_logs_to_bucket(bucket, bucket_output_path, output_path, now, params['training']['batch_metrics'])

        cur_elapsed = time.time() - since
        logging.info(f'Current elapsed time {cur_elapsed // 60:.0f}m {cur_elapsed % 60:.0f}s')

    # load checkpoint model and evaluate it on test dataset.
    if num_epochs > 0:  # if num_epochs is set to 0, model is loaded to evaluate on test set
        checkpoint = load_checkpoint(filename)
        model, _ = load_from_checkpoint(checkpoint, model)

    if tst_dataloader:
        tst_report = evaluation(eval_loader=tst_dataloader,
                                model=model,
                                criterion=criterion,
                                num_classes=num_classes_corrected,
                                batch_size=batch_size,
                                ep_idx=num_epochs,
                                progress_log=progress_log,
                                batch_metrics=batch_metrics,
                                dataset='tst',
                                scale=scale,
                                vis_params=vis_params,
                                device=device)
        tst_log.add_values(tst_report, num_epochs)

        if bucket_name:
            bucket_filename = bucket_output_path.joinpath('last_epoch.pth.tar')
            bucket.upload_file("output.txt", bucket_output_path.joinpath(f"Logs/{now}_output.txt"))
            bucket.upload_file(filename, bucket_filename)

    time_elapsed = time.time() - since
    log_params({'checkpoint path': filename})
    logging.info('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
def main(params, config_path):
    """
    Function to train and validate a model for semantic segmentation.

    Process
    -------
    1. Model is instantiated and checkpoint is loaded from path, if provided in
       `your_config.yaml`.
    2. GPUs are requested according to desired amount of `num_gpus` and
       available GPUs.
    3. If more than 1 GPU is requested, model is cast to DataParallel model
    4. Dataloaders are created with `create_dataloader()`
    5. Loss criterion, optimizer and learning rate are set with
       `set_hyperparameters()` as requested in `config.yaml`.
    6. Using these hyperparameters, the application will try to minimize the
       loss on the training data and evaluate every epoch on the validation
       data.
    7. For every epoch, the application shows and logs the loss on "trn" and
       "val" datasets.
    8. For every epoch (if `batch_metrics: 1`), the application shows and logs
       the accuracy, recall and f-score on "val" dataset. Those metrics are
       also computed on each class.
    9. At the end of the training process, the application shows and logs the
       accuracy, recall and f-score on "tst" dataset. Those metrics are also
       computed on each class.

    -------
    :param params: (dict) Parameters found in the yaml config file.
    :param config_path: (pathlib.Path) Path to the yaml config file. Its stem is
        used as the model name and the file is copied into the output folder.
    :raises AssertionError: if the task is not 'segmentation' or the data path
        is not a directory.
    :raises ValueError: if `visualization.vis_batch_range` is set but is not a
        list of three integers.
    """
    params['global']['git_hash'] = get_git_hash()
    debug = get_key_def('debug_mode', params['global'], False)
    if debug:
        warnings.warn(
            f'Debug mode activated. Some debug features may mobilize extra disk space and cause delays in execution.'
        )

    now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    num_classes = params['global']['num_classes']
    task = params['global']['task']
    assert task == 'segmentation', f"The task should be segmentation. The provided value is {task}"
    num_classes_corrected = num_classes + 1  # + 1 for background # FIXME temporary patch for num_classes problem.

    data_path = Path(params['global']['data_path'])
    assert data_path.is_dir(), f'Could not locate data path {data_path}'
    samples_size = params["global"]["samples_size"]
    overlap = params["sample"]["overlap"]
    min_annot_perc = get_key_def('min_annotated_percent',
                                 params['sample']['sampling_method'],
                                 0,
                                 expected_type=int)
    num_bands = params['global']['number_of_bands']
    samples_folder_name = f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands'  # FIXME: won't check if folder has datetime suffix (if multiple folders)
    samples_folder = data_path.joinpath(samples_folder_name)
    batch_size = params['training']['batch_size']
    num_epochs = params['training']['num_epochs']
    num_devices = params['global']['num_gpus']
    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    lst_device_ids = get_device_ids(num_devices) if torch.cuda.is_available() else []
    num_devices = len(lst_device_ids) if lst_device_ids else 0
    device = torch.device(f'cuda:{lst_device_ids[0]}'
                          if torch.cuda.is_available() and lst_device_ids else 'cpu')

    tqdm.write(f'Creating dataloaders from data in {samples_folder}...\n')
    trn_dataloader, val_dataloader, tst_dataloader = create_dataloader(
        samples_folder=samples_folder,
        batch_size=batch_size,
        num_devices=num_devices,
        params=params)
    # INSTANTIATE MODEL AND LOAD CHECKPOINT FROM PATH
    model, model_name, criterion, optimizer, lr_scheduler = net(
        params,
        num_classes_corrected)  # pretrained could become a yaml parameter.
    tqdm.write(
        f'Instantiated {model_name} model with {num_classes_corrected} output channels.\n'
    )
    bucket_name = get_key_def('bucket_name', params['global'])

    # mlflow tracking path + parameters logging
    set_tracking_uri(
        get_key_def('mlflow_uri', params['global'], default="./mlruns"))
    set_experiment('gdl-training')
    log_params(params['training'])
    log_params(params['global'])
    log_params(params['sample'])

    modelname = config_path.stem
    output_path = samples_folder.joinpath('model') / modelname
    if output_path.is_dir():
        # Leave the previous run untouched and write this run to a sibling
        # folder "<modelname>_<now>". The old code created a "_<now>"
        # sub-folder inside the previous run's folder.
        output_path = output_path.with_name(f"{output_path.name}_{now}")
    output_path.mkdir(parents=True, exist_ok=False)
    shutil.copy(str(config_path), str(output_path))
    tqdm.write(f'Model and log files will be saved to: {output_path}\n\n')

    if bucket_name:
        from utils.aws import download_s3_files
        bucket, bucket_output_path, output_path, data_path = download_s3_files(
            bucket_name=bucket_name,
            data_path=data_path,
            output_path=output_path)

    since = time.time()
    best_loss = 999
    last_vis_epoch = 0

    progress_log = output_path / 'progress.log'
    if not progress_log.exists():
        # Write the header and close the handle right away instead of leaking it.
        with progress_log.open('w', buffering=1) as progress_file:
            progress_file.write(
                tsv_line('ep_idx', 'phase', 'iter', 'i_p_ep', 'time'))  # Add header

    trn_log = InformationLogger('trn')
    val_log = InformationLogger('val')
    tst_log = InformationLogger('tst')
    filename = output_path.joinpath('checkpoint.pth.tar')

    # VISUALIZATION: generate pngs of inputs, labels and outputs
    vis_batch_range = get_key_def('vis_batch_range', params['visualization'], None)
    if vis_batch_range is not None:
        # Make sure the user-provided range is a list of 3 integers (start, finish, increment).
        # Check once for all visualization tasks. Raise instead of assert so the check
        # survives `python -O`.
        if not (isinstance(vis_batch_range, list)
                and len(vis_batch_range) == 3
                and all(isinstance(x, int) for x in vis_batch_range)):
            raise ValueError(f'Vis_batch_range expects three integers in a list: start batch, end batch, increment. '
                             f'Got {vis_batch_range}')
        vis_at_init_dataset = get_key_def('vis_at_init_dataset',
                                          params['visualization'], 'val')

        # Visualization at initialization. Visualize batch range before first epoch.
        if get_key_def('vis_at_init', params['visualization'], False):
            tqdm.write(
                f'Visualizing initialized model on batch range {vis_batch_range} from {vis_at_init_dataset} dataset...\n'
            )
            vis_from_dataloader(
                params=params,
                eval_loader=val_dataloader
                if vis_at_init_dataset == 'val' else tst_dataloader,
                model=model,
                ep_num=0,
                output_path=output_path,
                dataset=vis_at_init_dataset,
                device=device,
                vis_batch_range=vis_batch_range)

    # Checkpoint-time visualization settings. They don't change between epochs,
    # so read them once, before the training loop.
    vis_at_checkpoint = get_key_def('vis_at_checkpoint', params['visualization'], False)
    ep_vis_min_thresh = get_key_def('vis_at_ckpt_min_ep_diff', params['visualization'], 4)
    vis_at_ckpt_dataset = get_key_def('vis_at_ckpt_dataset', params['visualization'], 'val')

    for epoch in range(0, num_epochs):
        print(f'\nEpoch {epoch}/{num_epochs - 1}\n{"-" * 20}')

        trn_report = train(train_loader=trn_dataloader,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           scheduler=lr_scheduler,
                           num_classes=num_classes_corrected,
                           batch_size=batch_size,
                           ep_idx=epoch,
                           progress_log=progress_log,
                           vis_params=params,
                           device=device,
                           debug=debug)
        trn_log.add_values(trn_report,
                           epoch,
                           ignore=['precision', 'recall', 'fscore', 'iou'])

        val_report = evaluation(
            eval_loader=val_dataloader,
            model=model,
            criterion=criterion,
            num_classes=num_classes_corrected,
            batch_size=batch_size,
            ep_idx=epoch,
            progress_log=progress_log,
            vis_params=params,
            batch_metrics=params['training']['batch_metrics'],
            dataset='val',
            device=device,
            debug=debug)
        val_loss = val_report['loss'].avg
        if params['training']['batch_metrics'] is not None:
            val_log.add_values(val_report, epoch)
        else:
            val_log.add_values(val_report,
                               epoch,
                               ignore=['precision', 'recall', 'fscore', 'iou'])

        if val_loss < best_loss:
            tqdm.write("save checkpoint\n")
            best_loss = val_loss
            # More info: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-torch-nn-dataparallel-models
            state_dict = model.module.state_dict() if num_devices > 1 else model.state_dict()
            torch.save(
                {
                    'epoch': epoch,
                    'params': params,
                    'model': state_dict,
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict()
                }, filename)
            if epoch == 0:
                log_artifact(filename)
            if bucket_name:
                bucket_filename = bucket_output_path.joinpath('checkpoint.pth.tar')
                bucket.upload_file(filename, bucket_filename)

            # VISUALIZATION: generate png of test samples, labels and outputs for visualisation to follow training performance
            if vis_batch_range is not None and vis_at_checkpoint and epoch - last_vis_epoch >= ep_vis_min_thresh:
                if last_vis_epoch == 0:
                    tqdm.write(
                        f'Visualizing with {vis_at_ckpt_dataset} dataset samples on checkpointed model for'
                        f'batches in range {vis_batch_range}')
                vis_from_dataloader(
                    params=params,
                    eval_loader=val_dataloader
                    if vis_at_ckpt_dataset == 'val' else tst_dataloader,
                    model=model,
                    ep_num=epoch + 1,
                    output_path=output_path,
                    dataset=vis_at_ckpt_dataset,
                    device=device,
                    vis_batch_range=vis_batch_range)
                last_vis_epoch = epoch

        if bucket_name:
            save_logs_to_bucket(bucket, bucket_output_path, output_path, now,
                                params['training']['batch_metrics'])

        cur_elapsed = time.time() - since
        print(
            f'Current elapsed time {cur_elapsed // 60:.0f}m {cur_elapsed % 60:.0f}s'
        )

    # load checkpoint model and evaluate it on test dataset.
    if int(num_epochs) > 0:  # if num_epochs is set to 0, model is loaded to evaluate on test set
        checkpoint = load_checkpoint(filename)
        model, _ = load_from_checkpoint(checkpoint, model)

    if tst_dataloader:
        tst_report = evaluation(
            eval_loader=tst_dataloader,
            model=model,
            criterion=criterion,
            num_classes=num_classes_corrected,
            batch_size=batch_size,
            ep_idx=num_epochs,
            progress_log=progress_log,
            vis_params=params,
            batch_metrics=params['training']['batch_metrics'],
            dataset='tst',
            device=device)
        tst_log.add_values(tst_report, num_epochs)

        if bucket_name:
            bucket_filename = bucket_output_path.joinpath('last_epoch.pth.tar')
            bucket.upload_file(
                "output.txt",
                bucket_output_path.joinpath(f"Logs/{now}_output.txt"))
            bucket.upload_file(filename, bucket_filename)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
def main(params, config_path):
    """
    Function to train and validate a model for classification.

    Reads mandatory and optional parameters from the config, archives any
    previous run folder with the same model id, trains for `num_epochs`
    epochs while checkpointing on the best validation loss, then reloads the
    best checkpoint and evaluates it on the test set (if a test loader exists).

    :param params: (dict) Parameters found in the yaml config file.
    :param config_path: (pathlib.Path) Path to the yaml config file. Its stem is
        used as the model id and the file is copied into the output folder.
    :raises AssertionError: if the data path is not a directory or the task is
        not 'classification'.
    """
    # MANDATORY PARAMETERS
    num_classes = get_key_def('num_classes', params['global'], expected_type=int)
    num_bands = get_key_def('number_of_bands', params['global'], expected_type=int)
    batch_size = get_key_def('batch_size', params['training'], expected_type=int)
    num_epochs = get_key_def('num_epochs', params['training'], expected_type=int)
    model_name = get_key_def('model_name', params['global'], expected_type=str).lower()
    BGR_to_RGB = get_key_def('BGR_to_RGB', params['global'], expected_type=bool)

    # parameters to find hdf5 samples
    data_path = Path(get_key_def('data_path', params['global'], './data', expected_type=str))
    assert data_path.is_dir(), f'Could not locate data path {data_path}'

    # OPTIONAL PARAMETERS
    # basics
    debug = get_key_def('debug_mode', params['global'], default=False, expected_type=bool)
    task = get_key_def('task', params['global'], default='classification', expected_type=str)
    assert task == 'classification', f"The task should be classification. The provided value is {task}"
    dontcare_val = get_key_def("ignore_index", params["training"], default=-1, expected_type=int)
    batch_metrics = get_key_def('batch_metrics', params['training'], default=1, expected_type=int)
    meta_map = get_key_def("meta_map", params["global"], default={})
    bucket_name = get_key_def('bucket_name', params['global'])  # AWS

    # model params
    loss_fn = get_key_def('loss_fn', params['training'], default='CrossEntropy', expected_type=str)
    optimizer = get_key_def('optimizer', params['training'], default='adam', expected_type=str)
    pretrained = get_key_def('pretrained', params['training'], default=True, expected_type=bool)
    train_state_dict_path = get_key_def('state_dict_path', params['training'], default=None, expected_type=str)
    dropout_prob = get_key_def('dropout_prob', params['training'], default=None, expected_type=float)

    # gpu parameters
    num_devices = get_key_def('num_gpus', params['global'], default=0, expected_type=int)
    max_used_ram = get_key_def('max_used_ram', params['global'], default=2000, expected_type=int)
    max_used_perc = get_key_def('max_used_perc', params['global'], default=15, expected_type=int)

    # automatic model naming with unique id for each training
    model_id = config_path.stem
    output_path = data_path.joinpath('model') / model_id
    if output_path.is_dir():
        # Archive the previous run as "<model_id>_<its last modification time>" so this run starts clean.
        last_mod_time_suffix = datetime.fromtimestamp(output_path.stat().st_mtime).strftime('%Y%m%d-%H%M%S')
        archive_output_path = data_path.joinpath('model') / f"{model_id}_{last_mod_time_suffix}"
        shutil.move(output_path, archive_output_path)
    output_path.mkdir(parents=True, exist_ok=False)
    shutil.copy(str(config_path), str(output_path))  # copy yaml to output path where model will be saved
    tqdm.write(f'Model and log files will be saved to: {output_path}\n\n')

    if debug:
        warnings.warn(f'Debug mode activated. Some debug functions may cause delays in execution.')

    if bucket_name:
        bucket, bucket_output_path, output_path, data_path = download_s3_files(bucket_name=bucket_name,
                                                                               data_path=data_path,
                                                                               output_path=output_path,
                                                                               num_classes=num_classes)
    else:
        get_local_classes(num_classes, data_path, output_path)

    since = time.time()
    now = datetime.now().strftime("%Y-%m-%d_%H-%M")
    best_loss = 999

    progress_log = Path(output_path) / 'progress.log'
    if not progress_log.exists():
        # Write the header and close the handle right away instead of leaking it.
        with progress_log.open('w', buffering=1) as progress_file:
            progress_file.write(tsv_line('ep_idx', 'phase', 'iter', 'i_p_ep', 'time'))  # Add header

    trn_log = InformationLogger('trn')
    val_log = InformationLogger('val')
    tst_log = InformationLogger('tst')

    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    # NOTE(review): the value read from 'max_used_ram' is passed as `max_used_ram_perc`;
    # confirm the expected unit against get_device_ids.
    gpu_devices_dict = get_device_ids(num_devices,
                                      max_used_ram_perc=max_used_ram,
                                      max_used_perc=max_used_perc)
    num_devices = len(gpu_devices_dict.keys())
    device = torch.device(f'cuda:0' if gpu_devices_dict else 'cpu')

    tqdm.write(f'Creating dataloaders from data in {Path(data_path)}...\n')
    trn_dataloader, val_dataloader, tst_dataloader = create_classif_dataloader(data_path=data_path,
                                                                               batch_size=batch_size,
                                                                               num_devices=num_devices,)

    # INSTANTIATE MODEL AND LOAD CHECKPOINT FROM PATH
    model, model_name, criterion, optimizer, lr_scheduler = net(model_name=model_name,
                                                                num_bands=num_bands,
                                                                num_channels=num_classes,
                                                                dontcare_val=dontcare_val,
                                                                num_devices=num_devices,
                                                                train_state_dict_path=train_state_dict_path,
                                                                pretrained=pretrained,
                                                                dropout_prob=dropout_prob,
                                                                loss_fn=loss_fn,
                                                                optimizer=optimizer,
                                                                net_params=params)
    tqdm.write(f'Instantiated {model_name} model with {num_classes} output channels.\n')

    filename = os.path.join(output_path, 'checkpoint.pth.tar')

    for epoch in range(0, num_epochs):
        logging.info(f'\nEpoch {epoch}/{num_epochs - 1}\n{"-" * 20}')

        trn_report = train(train_loader=trn_dataloader,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           scheduler=lr_scheduler,
                           num_classes=num_classes,
                           batch_size=batch_size,
                           ep_idx=epoch,
                           progress_log=progress_log,
                           device=device,
                           debug=debug)
        trn_log.add_values(trn_report, epoch, ignore=['precision', 'recall', 'fscore', 'iou'])

        # Use the resolved `batch_metrics` (with its default), consistent with the test
        # phase, instead of indexing the raw config, which raises KeyError if the key is absent.
        val_report = evaluation(eval_loader=val_dataloader,
                                model=model,
                                criterion=criterion,
                                num_classes=num_classes,
                                batch_size=batch_size,
                                ep_idx=epoch,
                                progress_log=progress_log,
                                batch_metrics=batch_metrics,
                                dataset='val',
                                device=device,
                                debug=debug)
        val_loss = val_report['loss'].avg
        if batch_metrics is not None:
            val_log.add_values(val_report, epoch, ignore=['iou'])
        else:
            val_log.add_values(val_report, epoch, ignore=['precision', 'recall', 'fscore', 'iou'])

        if val_loss < best_loss:
            tqdm.write("save checkpoint\n")
            best_loss = val_loss
            # More info: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-torch-nn-dataparallel-models
            state_dict = model.module.state_dict() if num_devices > 1 else model.state_dict()
            torch.save({'epoch': epoch,
                        'params': params,
                        'model': state_dict,
                        'best_loss': best_loss,
                        'optimizer': optimizer.state_dict()}, filename)

            if bucket_name:
                bucket_filename = os.path.join(bucket_output_path, 'checkpoint.pth.tar')
                bucket.upload_file(filename, bucket_filename)

        if bucket_name:
            # Pass `now` so log names get the run timestamp and `batch_metrics` reaches
            # its own parameter, matching the other save_logs_to_bucket call sites.
            save_logs_to_bucket(bucket, bucket_output_path, output_path, now, batch_metrics)

        cur_elapsed = time.time() - since
        logging.info(f'Current elapsed time {cur_elapsed // 60:.0f}m {cur_elapsed % 60:.0f}s')

    # load checkpoint model and evaluate it on test dataset.
    if num_epochs > 0:  # if num_epochs is set to 0, model is loaded to evaluate on test set
        checkpoint = load_checkpoint(filename)
        model, _ = load_from_checkpoint(checkpoint, model)

    if tst_dataloader:
        tst_report = evaluation(eval_loader=tst_dataloader,
                                model=model,
                                criterion=criterion,
                                num_classes=num_classes,
                                batch_size=batch_size,
                                ep_idx=num_epochs,
                                progress_log=progress_log,
                                batch_metrics=batch_metrics,
                                dataset='tst',
                                device=device)
        tst_log.add_values(tst_report, num_epochs, ignore=['iou'])

        if bucket_name:
            bucket_filename = os.path.join(bucket_output_path, 'last_epoch.pth.tar')
            bucket.upload_file("output.txt", os.path.join(bucket_output_path, f"Logs/{now}_output.txt"))
            bucket.upload_file(filename, bucket_filename)

    time_elapsed = time.time() - since
    logging.info('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
Esempio n. 7
0
def train(cfg: DictConfig) -> None:
    """
    Function to train and validate a model for semantic segmentation.

    -------

    1. Model is instantiated and checkpoint is loaded from path, if provided in
       `your_config.yaml`.
    2. GPUs are requested according to desired amount of `num_gpus` and
       available GPUs.
    3. If more than 1 GPU is requested, model is cast to DataParallel model
    4. Dataloaders are created with `create_dataloader()`
    5. Loss criterion, optimizer and learning rate scheduler are returned by
       `net()` as requested in `config.yaml`.
    6. Using these hyperparameters, the application will try to minimize the
       loss on the training data and evaluate every epoch on the validation
       data.
    7. For every epoch, the application shows and logs the loss on "trn" and
       "val" datasets.
    8. For every epoch (if `batch_metrics` is set), the application shows and
       logs the accuracy, recall and f-score on "val" dataset. Those metrics
       are also computed on each class.
    9. At the end of the training process, the best checkpoint is copied to
       `general.save_weights_dir`, reloaded (when `max_epochs > 0`), and the
       application shows and logs the accuracy, recall and f-score on "tst"
       dataset. Those metrics are also computed on each class.

    -------
    :param cfg: (DictConfig) Parameters found in the yaml config file.
    :raises FileNotFoundError: if `meta_map`, the pretrained `state_dict_path`
        or the HDF5 samples directory cannot be found.
    :raises ValueError: if the task is not segmentation, `num_gpus` is negative
        or `vis_batch_range` is not a sequence of three integers.
    """
    from collections.abc import Sequence

    now = datetime.now().strftime("%Y-%m-%d_%H-%M")

    def _critical(err: Exception) -> Exception:
        # `logging.critical()` returns None, so `raise logging.critical(err)` would
        # raise a TypeError instead of `err`. Log first, then hand `err` back to raise.
        logging.critical(err)
        return err

    # MANDATORY PARAMETERS
    num_classes = len(get_key_def('classes_dict', cfg['dataset']).keys())
    num_classes_corrected = num_classes + 1  # + 1 for background # FIXME temporary patch for num_classes problem.
    num_bands = len(read_modalities(cfg.dataset.modalities))
    batch_size = get_key_def('batch_size', cfg['training'], expected_type=int)
    eval_batch_size = get_key_def('eval_batch_size',
                                  cfg['training'],
                                  expected_type=int,
                                  default=batch_size)
    num_epochs = get_key_def('max_epochs', cfg['training'], expected_type=int)
    model_name = get_key_def('model_name', cfg['model'],
                             expected_type=str).lower()
    # TODO need to keep in parameters? see victor stuff
    # BGR_to_RGB = get_key_def('BGR_to_RGB', params['global'], expected_type=bool)
    BGR_to_RGB = False

    # OPTIONAL PARAMETERS
    debug = get_key_def('debug', cfg)
    task = get_key_def('task_name', cfg['task'], default='segmentation')
    dontcare_val = get_key_def("ignore_index", cfg['dataset'], default=-1)
    bucket_name = get_key_def('bucket_name', cfg['AWS'])
    scale = get_key_def('scale_data', cfg['augmentation'], default=[0, 1])
    batch_metrics = get_key_def('batch_metrics', cfg['training'], default=None)
    meta_map = get_key_def("meta_map", cfg['training'],
                           default=None)  # TODO what is that?
    crop_size = get_key_def('target_size', cfg['training'], default=None)
    # if error
    if meta_map and not Path(meta_map).is_file():
        raise _critical(FileNotFoundError(f'\nCouldn\'t locate {meta_map}'))
    if task != 'segmentation':
        raise _critical(
            ValueError(
                f"\nThe task should be segmentation. The provided value is {task}"
            ))

    # MODEL PARAMETERS
    class_weights = get_key_def('class_weights', cfg['dataset'], default=None)
    loss_fn = get_key_def('loss_fn', cfg['training'], default='CrossEntropy')
    optimizer = get_key_def(
        'optimizer_name', cfg['optimizer'], default='adam',
        expected_type=str)  # TODO change something to call the function
    pretrained = get_key_def('pretrained',
                             cfg['model'],
                             default=True,
                             expected_type=bool)
    train_state_dict_path = get_key_def('state_dict_path',
                                        cfg['general'],
                                        default=None,
                                        expected_type=str)
    # NOTE(review): the dropout probability is read from the scheduler's `factor`;
    # confirm this key is intended.
    dropout_prob = get_key_def('factor',
                               cfg['scheduler']['params'],
                               default=None,
                               expected_type=float)
    # if error
    if train_state_dict_path and not Path(train_state_dict_path).is_file():
        raise _critical(
            FileNotFoundError(
                f'\nCould not locate pretrained checkpoint for training: {train_state_dict_path}'
            ))
    if class_weights:
        verify_weights(num_classes_corrected, class_weights)
    # Read the concatenation point
    # TODO: find a way to maybe implement it in classification one day
    conc_point = None
    # conc_point = get_key_def('concatenate_depth', params['global'], None)

    # GPU PARAMETERS
    num_devices = get_key_def('num_gpus', cfg['training'], default=0)
    if num_devices and not num_devices >= 0:
        raise _critical(ValueError("\nmissing mandatory num gpus parameter"))
    default_max_used_ram = 15
    max_used_ram = get_key_def('max_used_ram',
                               cfg['training'],
                               default=default_max_used_ram)
    max_used_perc = get_key_def('max_used_perc', cfg['training'], default=15)

    # LOGGING PARAMETERS TODO put option not just mlflow
    experiment_name = get_key_def('project_name',
                                  cfg['general'],
                                  default='gdl-training')
    # Tracking is enabled only if both the tracker uri and the run name were read.
    use_tracker = False
    try:
        tracker_uri = get_key_def('uri', cfg['tracker'])
        Path(tracker_uri).mkdir(exist_ok=True)
        run_name = get_key_def(
            'run_name', cfg['tracker'],
            default='gdl')  # TODO change for something meaningful
        use_tracker = True
    # meaning no logging tracker as been assigned or it doesnt exist in config/logging
    except ConfigKeyError:
        logging.info(
            "\nNo logging tracker as been assigned or the yaml config doesnt exist in 'config/tracker'."
            "\nNo tracker file will be saved in this case.")

    # PARAMETERS FOR hdf5 SAMPLES
    # info on the hdf5 name
    samples_size = get_key_def("input_dim",
                               cfg['dataset'],
                               expected_type=int,
                               default=256)
    overlap = get_key_def("overlap",
                          cfg['dataset'],
                          expected_type=int,
                          default=0)
    min_annot_perc = get_key_def('min_annotated_percent',
                                 cfg['dataset'],
                                 expected_type=int,
                                 default=0)
    samples_folder_name = (
        f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands_{experiment_name}'
    )
    try:
        my_hdf5_path = Path(str(
            cfg.dataset.sample_data_dir)).resolve(strict=True)
        samples_folder = Path(
            my_hdf5_path.joinpath(samples_folder_name)).resolve(strict=True)
        logging.info("\nThe HDF5 directory used '{}'".format(samples_folder))
    except FileNotFoundError:
        samples_folder = Path(str(
            cfg.dataset.sample_data_dir)).joinpath(samples_folder_name)
        logging.info(
            f"\nThe HDF5 directory '{samples_folder}' doesn't exist, please change the path."
            +
            f"\nWe will try to find '{samples_folder_name}' in '{cfg.dataset.raw_data_dir}'."
        )
        try:
            my_data_path = Path(cfg.dataset.raw_data_dir).resolve(strict=True)
            samples_folder = Path(
                my_data_path.joinpath(samples_folder_name)).resolve(
                    strict=True)
            logging.info(
                "\nThe HDF5 directory used '{}'".format(samples_folder))
            # Update the key the samples directory is read from (cfg.dataset), so the
            # config saved with the checkpoint points at the directory actually used.
            cfg.dataset.sample_data_dir = str(my_data_path)
        except FileNotFoundError as err:
            raise _critical(
                FileNotFoundError(
                    f"\nThe HDF5 directory '{samples_folder_name}' doesn't exist in '{cfg.dataset.raw_data_dir}'"
                    +
                    f"\n or in '{cfg.dataset.sample_data_dir}', please verify the location of your HDF5."
                )) from err

    # visualization parameters
    vis_at_train = get_key_def('vis_at_train',
                               cfg['visualization'],
                               default=False)
    vis_at_eval = get_key_def('vis_at_evaluation',
                              cfg['visualization'],
                              default=False)
    vis_batch_range = get_key_def('vis_batch_range',
                                  cfg['visualization'],
                                  default=None)
    vis_at_checkpoint = get_key_def('vis_at_checkpoint',
                                    cfg['visualization'],
                                    default=False)
    ep_vis_min_thresh = get_key_def('vis_at_ckpt_min_ep_diff',
                                    cfg['visualization'],
                                    default=1)
    vis_at_ckpt_dataset = get_key_def('vis_at_ckpt_dataset',
                                      cfg['visualization'], 'val')
    colormap_file = get_key_def('colormap_file', cfg['visualization'], None)
    heatmaps = get_key_def('heatmaps', cfg['visualization'], False)
    heatmaps_inf = get_key_def('heatmaps', cfg['inference'], False)
    grid = get_key_def('grid', cfg['visualization'], False)
    mean = get_key_def('mean', cfg['augmentation']['normalization'])
    std = get_key_def('std', cfg['augmentation']['normalization'])
    vis_params = {
        'colormap_file': colormap_file,
        'heatmaps': heatmaps,
        'heatmaps_inf': heatmaps_inf,
        'grid': grid,
        'mean': mean,
        'std': std,
        'vis_batch_range': vis_batch_range,
        'vis_at_train': vis_at_train,
        'vis_at_eval': vis_at_eval,
        'ignore_index': dontcare_val,
        'inference_input_path': None
    }

    # coordconv parameters TODO
    # coordconv_params = {}
    # for param, val in params['global'].items():
    #     if 'coordconv' in param:
    #         coordconv_params[param] = val
    coordconv_params = get_key_def('coordconv', cfg['model'])

    # automatic model naming with unique id for each training
    config_path = None
    for list_path in cfg.general.config_path:
        if list_path['provider'] == 'main':
            config_path = list_path['path']
    config_name = str(cfg.general.config_name)
    model_id = config_name
    output_path = Path(f'model/{model_id}')
    output_path.mkdir(parents=True, exist_ok=False)
    logging.info(
        f'\nModel and log files will be saved to: {os.getcwd()}/{output_path}')
    if debug:
        logging.warning(
            f'\nDebug mode activated. Some debug features may mobilize extra disk space and '
            f'cause delays in execution.')
    if dontcare_val < 0 and vis_batch_range:
        logging.warning(
            f'\nVisualization: expected positive value for ignore_index, got {dontcare_val}.'
            f'Will be overridden to 255 during visualization only. Problems may occur.'
        )

    # overwrite dontcare values in label if loss is not lovasz or crossentropy. FIXME: hacky fix.
    dontcare2backgr = False
    if loss_fn not in ['Lovasz', 'CrossEntropy', 'OhemCrossEntropy']:
        dontcare2backgr = True
        logging.warning(
            f'\nDontcare is not implemented for loss function "{loss_fn}". '
            f'\nDontcare values ({dontcare_val}) in label will be replaced with background value (0)'
        )

    # Will check if batch size needs to be a lower value only if cropping samples during training
    calc_eval_bs = True if crop_size else False
    # INSTANTIATE MODEL AND LOAD CHECKPOINT FROM PATH
    model, model_name, criterion, optimizer, lr_scheduler, device, gpu_devices_dict = \
        net(model_name=model_name,
            num_bands=num_bands,
            num_channels=num_classes_corrected,
            dontcare_val=dontcare_val,
            num_devices=num_devices,
            train_state_dict_path=train_state_dict_path,
            pretrained=pretrained,
            dropout_prob=dropout_prob,
            loss_fn=loss_fn,
            class_weights=class_weights,
            optimizer=optimizer,
            net_params=cfg,
            conc_point=conc_point,
            coordconv_params=coordconv_params)

    logging.info(
        f'Instantiated {model_name} model with {num_classes_corrected} output channels.\n'
    )

    logging.info(f'Creating dataloaders from data in {samples_folder}...\n')
    trn_dataloader, val_dataloader, tst_dataloader = create_dataloader(
        samples_folder=samples_folder,
        batch_size=batch_size,
        eval_batch_size=eval_batch_size,
        gpu_devices_dict=gpu_devices_dict,
        sample_size=samples_size,
        dontcare_val=dontcare_val,
        crop_size=crop_size,
        meta_map=meta_map,
        num_bands=num_bands,
        BGR_to_RGB=BGR_to_RGB,
        scale=scale,
        cfg=cfg,
        dontcare2backgr=dontcare2backgr,
        calc_eval_bs=calc_eval_bs,
        debug=debug)

    # Save tracking TODO put option not just mlflow
    # Loggers stay None when no tracker is configured; values are then not recorded.
    trn_log = val_log = tst_log = None
    if use_tracker:
        mode = get_key_def('mode', cfg, expected_type=str)
        task = get_key_def('task_name', cfg['task'], expected_type=str)
        run_name = '{}_{}_{}'.format(run_name, mode, task)
        # tracking path + parameters logging
        set_tracking_uri(tracker_uri)
        set_experiment(experiment_name)
        start_run(run_name=run_name)
        log_params(dict_path(cfg, 'training'))
        log_params(dict_path(cfg, 'dataset'))
        log_params(dict_path(cfg, 'model'))
        log_params(dict_path(cfg, 'optimizer'))
        log_params(dict_path(cfg, 'scheduler'))
        log_params(dict_path(cfg, 'augmentation'))
        # TODO change something to not only have the mlflow option
        trn_log = InformationLogger('trn')
        val_log = InformationLogger('val')
        tst_log = InformationLogger('tst')

    if bucket_name:
        from utils.aws import download_s3_files
        # FIXME: `data_path` is not defined in this function; this branch raises
        # NameError until the intended data path is passed here.
        bucket, bucket_output_path, output_path, data_path = download_s3_files(
            bucket_name=bucket_name,
            data_path=data_path,  # FIXME
            output_path=output_path)

    since = time.time()
    best_loss = 999
    last_vis_epoch = 0

    progress_log = output_path / 'progress.log'
    if not progress_log.exists():
        # Write the header and close the handle right away.
        with progress_log.open('w', buffering=1) as progress_file:
            progress_file.write(
                tsv_line('ep_idx', 'phase', 'iter', 'i_p_ep',
                         'time'))  # Add header

    # create the checkpoint file
    filename = output_path.joinpath('checkpoint.pth.tar')

    # VISUALIZATION: generate pngs of inputs, labels and outputs
    if vis_batch_range is not None:
        # Make sure user-provided range is a sequence of 3 integers (start, finish, increment).
        # Check once for all visualization tasks.
        is_valid_range = (isinstance(vis_batch_range, Sequence)
                          and not isinstance(vis_batch_range, str)
                          and len(vis_batch_range) == 3
                          and all(isinstance(x, int) for x in vis_batch_range))
        if not is_valid_range:
            raise _critical(
                ValueError(
                    f'\nVis_batch_range expects three integers in a list: start batch, end batch, increment.'
                    f'Got {vis_batch_range}'))
        vis_at_init_dataset = get_key_def('vis_at_init_dataset',
                                          cfg['visualization'], 'val')

        # Visualization at initialization. Visualize batch range before first eopch.
        if get_key_def('vis_at_init', cfg['visualization'], False):
            logging.info(
                f'\nVisualizing initialized model on batch range {vis_batch_range} '
                f'from {vis_at_init_dataset} dataset...\n')
            vis_from_dataloader(
                vis_params=vis_params,
                eval_loader=val_dataloader
                if vis_at_init_dataset == 'val' else tst_dataloader,
                model=model,
                ep_num=0,
                output_path=output_path,
                dataset=vis_at_init_dataset,
                scale=scale,
                device=device,
                vis_batch_range=vis_batch_range)

    for epoch in range(0, num_epochs):
        logging.info(f'\nEpoch {epoch}/{num_epochs - 1}\n' +
                     "-" * len(f'Epoch {epoch}/{num_epochs - 1}'))
        # creating trn_report
        trn_report = training(train_loader=trn_dataloader,
                              model=model,
                              criterion=criterion,
                              optimizer=optimizer,
                              scheduler=lr_scheduler,
                              num_classes=num_classes_corrected,
                              batch_size=batch_size,
                              ep_idx=epoch,
                              progress_log=progress_log,
                              device=device,
                              scale=scale,
                              vis_params=vis_params,
                              debug=debug)
        if trn_log is not None:  # only save the value if a tracker is setup
            trn_log.add_values(trn_report,
                               epoch,
                               ignore=['precision', 'recall', 'fscore', 'iou'])
        val_report = evaluation(eval_loader=val_dataloader,
                                model=model,
                                criterion=criterion,
                                num_classes=num_classes_corrected,
                                batch_size=batch_size,
                                ep_idx=epoch,
                                progress_log=progress_log,
                                batch_metrics=batch_metrics,
                                dataset='val',
                                device=device,
                                scale=scale,
                                vis_params=vis_params,
                                debug=debug)
        val_loss = val_report['loss'].avg
        if val_log is not None:  # only save the value if a tracker is setup
            if batch_metrics is not None:
                val_log.add_values(val_report, epoch)
            else:
                val_log.add_values(
                    val_report,
                    epoch,
                    ignore=['precision', 'recall', 'fscore', 'iou'])

        if val_loss < best_loss:
            logging.info(
                "\nSave checkpoint with a validation loss of {:.4f}".format(
                    val_loss))  # only allow 4 decimals
            best_loss = val_loss
            # Unwrap DataParallel so saved keys carry no 'module.' prefix. Check the
            # model itself: the requested GPU count may exceed the devices actually used.
            # More info:
            # https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-torch-nn-dataparallel-models
            state_dict = model.module.state_dict() if isinstance(
                model, torch.nn.DataParallel) else model.state_dict()
            torch.save(
                {
                    'epoch': epoch,
                    'params': cfg,
                    'model': state_dict,
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict()
                }, filename)
            if bucket_name:
                bucket_filename = bucket_output_path.joinpath(
                    'checkpoint.pth.tar')
                bucket.upload_file(filename, bucket_filename)

            # VISUALIZATION: generate pngs of img samples, labels and outputs as alternative to follow training
            if vis_batch_range is not None and vis_at_checkpoint and epoch - last_vis_epoch >= ep_vis_min_thresh:
                if last_vis_epoch == 0:
                    logging.info(
                        f'\nVisualizing with {vis_at_ckpt_dataset} dataset samples on checkpointed model for'
                        f'batches in range {vis_batch_range}')
                vis_from_dataloader(
                    vis_params=vis_params,
                    eval_loader=val_dataloader
                    if vis_at_ckpt_dataset == 'val' else tst_dataloader,
                    model=model,
                    ep_num=epoch + 1,
                    output_path=output_path,
                    dataset=vis_at_ckpt_dataset,
                    scale=scale,
                    device=device,
                    vis_batch_range=vis_batch_range)
                last_vis_epoch = epoch

        if bucket_name:
            save_logs_to_bucket(bucket, bucket_output_path, output_path, now,
                                cfg['training']['batch_metrics'])

    # copy the checkpoint in 'save_weights_dir'
    Path(cfg['general']['save_weights_dir']).mkdir(parents=True, exist_ok=True)
    if filename.is_file():
        copy(filename, cfg['general']['save_weights_dir'])
    else:
        # e.g. max_epochs == 0: no checkpoint was written during this run.
        logging.warning(
            f'\nNo checkpoint found at {filename}; nothing copied to save_weights_dir.')

    # load checkpoint model and evaluate it on test dataset.
    if num_epochs > 0:  # if num_epochs is set to 0, model is loaded to evaluate on test set
        checkpoint = load_checkpoint(filename)
        model, _ = load_from_checkpoint(checkpoint, model)

    if tst_dataloader:
        tst_report = evaluation(eval_loader=tst_dataloader,
                                model=model,
                                criterion=criterion,
                                num_classes=num_classes_corrected,
                                batch_size=batch_size,
                                ep_idx=num_epochs,
                                progress_log=progress_log,
                                batch_metrics=batch_metrics,
                                dataset='tst',
                                scale=scale,
                                vis_params=vis_params,
                                device=device)
        if tst_log is not None:  # only save the value if a tracker is setup
            tst_log.add_values(tst_report, num_epochs)

        if bucket_name:
            bucket_filename = bucket_output_path.joinpath('last_epoch.pth.tar')
            bucket.upload_file(
                "output.txt",
                bucket_output_path.joinpath(f"Logs/{now}_output.txt"))
            bucket.upload_file(filename, bucket_filename)

    time_elapsed = time.time() - since
    logging.info('\nTraining complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
Esempio n. 8
0
def main(params, config_path):
    """
    Function to train and validate a model for semantic segmentation or classification.

    The model built by `net(params)` is trained for
    `params['training']['num_epochs']` epochs. A checkpoint is saved whenever the
    validation loss improves, and the best checkpoint is then evaluated on the
    test set.

    :param params: (dict) Parameters found in the yaml config file.
    :param config_path: (pathlib.Path) Path to the yaml config file. Its stem
        names the output directory `<data_path>/model/<stem>`.
    """
    # 24-hour clock: '%I' (12-hour, no AM/PM) gave runs 12 h apart the same suffix,
    # silently reusing the earlier run's output directory below.
    now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")

    model, checkpoint, model_name = net(params)
    bucket_name = params['global']['bucket_name']
    data_path = params['global']['data_path']
    modelname = config_path.stem
    output_path = Path(data_path).joinpath('model') / modelname
    try:
        output_path.mkdir(parents=True, exist_ok=False)
    except FileExistsError:
        # Keep previous results: suffix the new output directory with a timestamp.
        output_path = Path(str(output_path) + '_' + now)
        output_path.mkdir(exist_ok=True)
    print(f'Model and log files will be saved to: {output_path}')
    task = params['global']['task']
    num_classes = params['global']['num_classes']
    batch_size = params['training']['batch_size']

    if num_classes == 1:
        # assume background is implicitly needed (makes no sense to train with one class otherwise)
        # this will trigger some warnings elsewhere, but should succeed nonetheless
        num_classes = 2

    if bucket_name:
        bucket, bucket_output_path, output_path, data_path = download_s3_files(
            bucket_name=bucket_name,
            data_path=data_path,
            output_path=output_path,
            num_classes=num_classes,
            task=task)
    elif task == 'classification':
        get_local_classes(num_classes, data_path, output_path)

    since = time.time()
    best_loss = 999

    progress_log = Path(output_path) / 'progress.log'
    if not progress_log.exists():
        # Write the header and close the handle right away.
        with progress_log.open('w', buffering=1) as progress_file:
            progress_file.write(
                tsv_line('ep_idx', 'phase', 'iter', 'i_p_ep',
                         'time'))  # Add header

    trn_log = InformationLogger(output_path, 'trn')
    val_log = InformationLogger(output_path, 'val')
    tst_log = InformationLogger(output_path, 'tst')

    num_devices = params['global']['num_gpus']
    assert num_devices is not None and num_devices >= 0, "missing mandatory num gpus parameter"
    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    lst_device_ids = get_device_ids(
        num_devices) if torch.cuda.is_available() else []
    num_devices = len(lst_device_ids) if lst_device_ids else 0
    device = torch.device(f'cuda:{lst_device_ids[0]}' if torch.cuda.
                          is_available() and lst_device_ids else 'cpu')
    if num_devices == 1:
        print(f"Using Cuda device {lst_device_ids[0]}")
    elif num_devices > 1:
        print(f"Using data parallel on devices {str(lst_device_ids)[1:-1]}")
        model = nn.DataParallel(model, device_ids=lst_device_ids
                                )  # adds prefix 'module.' to state_dict keys
    else:
        warnings.warn(
            f"No Cuda device available. This process will only run on CPU")

    trn_dataloader, val_dataloader, tst_dataloader = create_dataloader(
        data_path=data_path,
        batch_size=batch_size,
        task=task,
        num_devices=num_devices,
        params=params)

    model, criterion, optimizer, lr_scheduler = set_hyperparameters(
        params, model, checkpoint)

    criterion = criterion.to(device)
    model = model.to(device)

    filename = os.path.join(output_path, 'checkpoint.pth.tar')

    num_epochs = params['training']['num_epochs']
    batch_metrics = params['training']['batch_metrics']

    for epoch in range(0, num_epochs):
        print(f'\nEpoch {epoch}/{num_epochs - 1}\n{"-" * 20}')

        trn_report = train(train_loader=trn_dataloader,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           scheduler=lr_scheduler,
                           num_classes=num_classes,
                           batch_size=batch_size,
                           task=task,
                           ep_idx=epoch,
                           progress_log=progress_log,
                           device=device)
        trn_log.add_values(trn_report,
                           epoch,
                           ignore=['precision', 'recall', 'fscore', 'iou'])

        val_report = evaluation(eval_loader=val_dataloader,
                                model=model,
                                criterion=criterion,
                                num_classes=num_classes,
                                batch_size=batch_size,
                                task=task,
                                ep_idx=epoch,
                                progress_log=progress_log,
                                batch_metrics=batch_metrics,
                                dataset='val',
                                device=device)
        val_loss = val_report['loss'].avg
        if batch_metrics is not None:
            val_log.add_values(val_report, epoch)
        else:
            val_log.add_values(val_report,
                               epoch,
                               ignore=['precision', 'recall', 'fscore', 'iou'])

        if val_loss < best_loss:
            print("save checkpoint")
            best_loss = val_loss
            # Unwrap DataParallel so saved keys carry no 'module.' prefix.
            # More info: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-torch-nn-dataparallel-models
            state_dict = model.module.state_dict(
            ) if num_devices > 1 else model.state_dict()
            torch.save(
                {
                    'epoch': epoch,
                    'arch': model_name,
                    'model': state_dict,
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict()
                }, filename)

            if bucket_name:
                bucket_filename = os.path.join(bucket_output_path,
                                               'checkpoint.pth.tar')
                bucket.upload_file(filename, bucket_filename)

        if bucket_name:
            save_logs_to_bucket(bucket, bucket_output_path, output_path, now,
                                batch_metrics)

        cur_elapsed = time.time() - since
        print(
            f'Current elapsed time {cur_elapsed // 60:.0f}m {cur_elapsed % 60:.0f}s'
        )

    # load checkpoint model and evaluate it on test dataset.
    if int(num_epochs) > 0:  # if num_epochs is set to 0, the loaded model is evaluated on test set
        checkpoint = load_checkpoint(filename)
        model, _ = load_from_checkpoint(checkpoint, model)
    tst_report = evaluation(eval_loader=tst_dataloader,
                            model=model,
                            criterion=criterion,
                            num_classes=num_classes,
                            batch_size=batch_size,
                            task=task,
                            ep_idx=num_epochs,
                            progress_log=progress_log,
                            batch_metrics=batch_metrics,
                            dataset='tst',
                            device=device)
    tst_log.add_values(tst_report, num_epochs)

    if bucket_name:
        bucket_filename = os.path.join(bucket_output_path,
                                       'last_epoch.pth.tar')
        bucket.upload_file(
            "output.txt",
            os.path.join(bucket_output_path, f"Logs/{now}_output.txt"))
        bucket.upload_file(filename, bucket_filename)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))