Example #1
def get_train_and_val_datasets(hparams, no_val, train_on_val, logger):
    """
    Return all pairs of (train, validation) SleepStudyDatasets as described in
    the YAMLHParams object 'hparams'. A list is returned, as more than 1
    dataset may be described in the parameter file.

    Note: if 'train_on_val' is True, the validation data is loaded and merged
    into the training data, after which the function behaves as if 'no_val'
    were True (no separate validation datasets are returned).

    Args:
        hparams:      (YAMLHParams) A hyperparameter object to load dataset
                                    configurations from.
        no_val:       (bool)        Do not load validation data
        train_on_val: (bool)        Load validation data, but merge it into
                                    the training data. Then return only the
                                    'training' (train+val) dataset.
        logger:       (Logger)      A Logger object

    Returns:
        A list of training SleepStudyDataset objects
        A list of validation SleepStudyDataset objects, or an empty list if
        no validation data is loaded.
    """
    if no_val:
        load = ("train_data", )
        if train_on_val:
            raise ValueError("Should not specify --no_val with --train_on_val")
    else:
        load = ("train_data", "val_data")
    from utime.utils.scriptutils import get_splits_from_all_datasets
    datasets = [*get_splits_from_all_datasets(hparams, load, logger)]
    if train_on_val:
        if any(len(ds) != 2 for ds in datasets):
            raise ValueError("Did not find a validation set for one or more "
                             "pairs in {}".format(datasets))
        logger("[OBS] Merging training and validation sets")
        datasets = [merge_train_and_val(*ds) for ds in datasets]
        no_val = True
    if not no_val:
        train_datasets, val_datasets = zip(*datasets)
    else:
        train_datasets = [d[0] for d in datasets]
        val_datasets = []
    return train_datasets, val_datasets
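A minimal usage sketch for the function above (hypothetical: 'hparams' and 'logger' stand in for a YAMLHParams object and a Logger created by the calling script):

# Illustrative call: load (train, val) pairs for every dataset in 'hparams'
train_sets, val_sets = get_train_and_val_datasets(hparams=hparams,
                                                  no_val=False,
                                                  train_on_val=False,
                                                  logger=logger)
# With train_on_val=True, the val data would be merged into the training
# data and val_sets would be an empty list.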
Example #2
def run(args):
    """
    Run the script according to 'args'; see the argparser for details.
    """
    assert_args(args)
    # Check project folder is valid
    import os
    from utime.utils.scriptutils import (assert_project_folder,
                                         get_splits_from_all_datasets)
    project_dir = os.path.abspath(args.project_dir)
    assert_project_folder(project_dir, evaluation=True)

    # Prepare output dir
    out_dir = get_out_dir(args.out_dir, args.data_split)
    prepare_output_dir(out_dir, args.overwrite)
    logger = get_logger(out_dir, args.overwrite)
    logger("Args dump: \n{}".format(vars(args)))

    # Get hyperparameters and init all described datasets
    from utime.hyperparameters import YAMLHParams
    hparams = YAMLHParams(project_dir + "/hparams.yaml", logger)
    if args.channels:
        hparams["select_channels"] = args.channels
        hparams["channel_sampling_groups"] = None
        logger("Evaluating using channels {}".format(args.channels))

    # Get model
    set_gpu_vis(args.num_GPUs, args.force_GPU, logger)
    model, model_func = None, None
    if args.one_shot:
        # Model is initialized for each sleep study later
        def model_func(full_hyp):
            return get_and_load_one_shot_model(full_hyp, project_dir, hparams,
                                               logger, args.weights_file_name)
    else:
        model = get_and_load_model(project_dir, hparams, logger,
                                   args.weights_file_name)

    # Run predictions on all datasets
    datasets = get_splits_from_all_datasets(hparams=hparams,
                                            splits_to_load=(args.data_split, ),
                                            logger=logger)
    eval_dirs = []
    for dataset in datasets:
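        # 'datasets' yields 1-tuples here, as only a single split was loaded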
        dataset = dataset[0]
        if "/" in dataset.identifier:
            # Multiple datasets, separate results into sub-folders
            ds_out_dir = os.path.join(out_dir,
                                      dataset.identifier.split("/")[0])
            if not os.path.exists(ds_out_dir):
                os.mkdir(ds_out_dir)
            eval_dirs.append(ds_out_dir)
        else:
            ds_out_dir = out_dir
        logger("[*] Running eval on dataset {}\n"
               "    Out dir: {}".format(dataset, ds_out_dir))
        run_pred_and_eval(dataset=dataset,
                          out_dir=ds_out_dir,
                          model=model,
                          model_func=model_func,
                          hparams=hparams,
                          args=args,
                          logger=logger)
    if len(eval_dirs) > 1:
        cross_dataset_eval(eval_dirs, out_dir)
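A hedged sketch of how this evaluation entry point might be invoked. The attribute names mirror those accessed in the code above; the values are illustrative, and the script's own argparser defines the real defaults:

from argparse import Namespace

# Hypothetical argument set for evaluating on the validation split
args = Namespace(project_dir="./my_project", out_dir="evaluation",
                 data_split="val_data", overwrite=False, channels=None,
                 num_GPUs=1, force_GPU="", one_shot=False,
                 weights_file_name=None)
run(args)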
Example #3
def run(args):
    """
    Run the script according to 'args'; see the argparser for details.
    """
    assert_args(args)
    # Check project folder is valid
    import os
    from utime.utils.scriptutils import (assert_project_folder,
                                         get_dataset_from_regex_pattern,
                                         get_splits_from_all_datasets,
                                         get_all_dataset_hparams)
    project_dir = os.path.abspath(args.project_dir)
    assert_project_folder(project_dir, evaluation=True)

    # Prepare output dir
    if not args.folder_regex:
        out_dir = get_out_dir(args.out_dir, args.data_split)
    else:
        out_dir = args.out_dir
    prepare_output_dir(out_dir, args.overwrite)
    logger = get_logger(out_dir, args.overwrite, name="prediction_log")
    logger("Args dump: \n{}".format(vars(args)))

    # Get hyperparameters and init all described datasets
    from utime.hyperparameters import YAMLHParams
    hparams = YAMLHParams(project_dir + "/hparams.yaml", logger)
    hparams["build"]["data_per_prediction"] = args.data_per_prediction
    if args.channels:
        hparams["select_channels"] = args.channels
        hparams["channel_sampling_groups"] = None
        logger("Evaluating using channels {}".format(args.channels))

    # Get model
    set_gpu_vis(args.num_GPUs, args.force_GPU, logger)
    model, model_func = None, None
    if args.one_shot:
        # Model is initialized for each sleep study later
        def model_func(full_hyp):
            return get_and_load_one_shot_model(full_hyp, project_dir, hparams,
                                               logger, args.weights_file_name)
    else:
        model = get_and_load_model(project_dir, hparams, logger,
                                   args.weights_file_name)

    if args.folder_regex:
        # Predict on a single dataset, specified by the 'folder_regex' arg.
        # The dataset hyperparameters of one of the datasets described in the
        # stored hyperparameter files are loaded and used as a guide for how
        # to handle this new, undescribed dataset.
        dataset_hparams = list(get_all_dataset_hparams(hparams).values())[0]
        datasets = [(get_dataset_from_regex_pattern(args.folder_regex,
                                                    hparams=dataset_hparams,
                                                    logger=logger), )]
    else:
        # Predict on the datasets described in the hyperparameter files
        datasets = get_splits_from_all_datasets(
            hparams=hparams, splits_to_load=(args.data_split, ), logger=logger)

    for dataset in datasets:
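        # Both branches above produce 1-tuples; unpack the single split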
        dataset = dataset[0]
        if "/" in dataset.identifier:
            # Multiple datasets, separate results into sub-folders
            ds_out_dir = os.path.join(out_dir,
                                      dataset.identifier.split("/")[0])
            if not os.path.exists(ds_out_dir):
                os.mkdir(ds_out_dir)
        else:
            ds_out_dir = out_dir
        logger("[*] Running eval on dataset {}\n"
               "    Out dir: {}".format(dataset, ds_out_dir))
        run_pred(dataset=dataset,
                 out_dir=ds_out_dir,
                 model=model,
                 model_func=model_func,
                 hparams=hparams,
                 args=args,
                 logger=logger)
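A similar hedged invocation sketch for this prediction entry point (attribute names come from the code above; values and paths are placeholders):

from argparse import Namespace

# Hypothetical arguments; 'folder_regex' makes the script predict on a
# folder of studies not described in the project's hyperparameter files
args = Namespace(project_dir="./my_project", out_dir="predictions",
                 folder_regex="/data/new_studies/*", data_split="test_data",
                 data_per_prediction=None, overwrite=False, channels=None,
                 num_GPUs=1, force_GPU="", one_shot=False,
                 weights_file_name=None)
run(args)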
Example #4
def run(args):
    """
    Run the script according to 'args'; see the argparser for details.

    Args:
        args:    (Namespace)  Command-line arguments
    """
    import os
    import h5py
    from functools import partial
    from concurrent.futures import ThreadPoolExecutor
    from mpunet.logging import Logger
    from utime.hyperparameters import YAMLHParams
    from utime.utils.scriptutils import assert_project_folder
    from utime.utils.scriptutils import get_splits_from_all_datasets

    project_dir = os.path.abspath("./")
    assert_project_folder(project_dir)

    # Get logger object
    logger = Logger(project_dir + "/preprocessing_logs",
                    active_file='preprocessing',
                    overwrite_existing=args.overwrite,
                    no_sub_folder=True)
    logger("Args dump: {}".format(vars(args)))

    # Load hparams
    hparams = YAMLHParams(Defaults.get_hparams_path(project_dir),
                          logger=logger,
                          no_version_control=True)

    # Initialize and load (potentially multiple) datasets
    datasets = get_splits_from_all_datasets(hparams,
                                            splits_to_load=args.dataset_splits,
                                            logger=logger,
                                            return_data_hparams=True)

    # Check if file exists, and overwrite if specified
    if os.path.exists(args.out_path):
        if args.overwrite:
            os.remove(args.out_path)
        else:
            from sys import exit
            logger("Out file at {} exists, and --overwrite was not set."
                   "".format(args.out_path))
            exit(0)

    # Create dataset hparams output directory
    out_dir = Defaults.get_pre_processed_data_configurations_dir(project_dir)
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    with ThreadPoolExecutor(args.num_threads) as pool:
        with h5py.File(args.out_path, "w") as h5_file:
            for dataset, dataset_hparams in datasets:
                # Create a new version of the dataset-specific hyperparameters
                # that contains only the fields needed for pre-processed data
                name = dataset[0].identifier.split("/")[0]
                hparams_out_path = os.path.join(out_dir, name + ".yaml")
                copy_dataset_hparams(dataset_hparams, hparams_out_path)

                # Update paths to dataset hparams in main hparams file
                hparams.set_value(subdir='datasets',
                                  name=name,
                                  value=hparams_out_path,
                                  overwrite=True)
                # Save the hyperparameters to the pre-processed main hparams
                hparams.save_current(
                    Defaults.get_pre_processed_hparams_path(project_dir))

                # Process each dataset
                for split in dataset:
                    # Add this split to the dataset-specific hparams
                    add_dataset_entry(hparams_out_path, args.out_path,
                                      split.identifier.split("/")[-1].lower(),
                                      split.period_length_sec)
                    # Overwrite potential load time channel sampler to None
                    split.set_load_time_channel_sampling_groups(None)

                    # Create dataset group
                    split_group = h5_file.create_group(split.identifier)

                    # Run the preprocessing
                    process_func = partial(preprocess_study, split_group)
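                    # partial binds the H5 group, so pool.map only needs to
                    # supply each study pair from 'split.pairs'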

                    logger.print_to_screen = True
                    logger("Preprocessing dataset:", split)
                    logger.print_to_screen = False
                    n_pairs = len(split.pairs)
                    for i, _ in enumerate(pool.map(process_func, split.pairs)):
                        print("  {}/{}".format(i + 1, n_pairs),
                              end='\r',
                              flush=True)
                    print("")