Example 1
def run(args, return_prediction=False, dump_args=None):
    """
    Run the script according to args - Please refer to the argparser.
    """
    assert_args(args)
    # Check project folder is valid
    from utime.utils.scriptutils.scriptutils import assert_project_folder
    project_dir = os.path.abspath(args.project_dir)
    assert_project_folder(project_dir, evaluation=True)

    # Get a logger
    logger = get_logger(project_dir, True, name="prediction_log")
    if dump_args:
        logger("Args dump: \n{}".format(vars(args)))

    # Get hyperparameters and init all described datasets
    from utime.hyperparameters import YAMLHParams
    hparams = YAMLHParams(Defaults.get_hparams_path(project_dir),
                          logger,
                          no_version_control=True)

    # Get the sleep study
    logger("Loading and pre-processing PSG file...")
    hparams['prediction_params']['channels'] = args.channels
    study, channel_groups = get_sleep_study(psg_path=args.f,
                                            logger=logger,
                                            **hparams['prediction_params'])

    # Set GPU and get model
    set_gpu_vis(args.num_GPUs, args.force_GPU, logger)
    hparams["build"]["data_per_prediction"] = args.data_per_prediction
    logger("Predicting with {} data per prediction".format(
        args.data_per_prediction))
    model = get_and_load_one_shot_model(
        n_periods=study.n_periods,
        project_dir=project_dir,
        hparams=hparams,
        logger=logger,
        weights_file_name=hparams.get_from_anywhere('weight_file_name'))
    logger("Predicting...")
    pred = predict_study(study, model, channel_groups, args.no_argmax, logger)
    logger("--> Predicted shape: {}".format(pred.shape))
    if return_prediction:
        return pred
    else:
        save_prediction(pred=pred,
                        out_path=args.o,
                        input_file_path=study.psg_file_path,
                        logger=logger)
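A minimal usage sketch for the function above, assuming a valid U-Time project folder and PSG file exist at the illustrative paths. The Namespace fields mirror the attribute accesses inside run(); the actual argparser (not shown here) defines the full option set and defaults.

# Sketch only: field names are taken from the attribute accesses in run() above;
# all paths, channel names and values are purely illustrative.
from argparse import Namespace

args = Namespace(project_dir="my_project",                    # illustrative project folder
                 f="data/subject_01/psg.edf",                 # illustrative PSG file
                 o="predictions/subject_01.npy",              # output path (used when not returning)
                 channels=["EEG Fpz-Cz", "EOG horizontal"],   # example channel names
                 data_per_prediction=None,
                 num_GPUs=1,
                 force_GPU="",
                 no_argmax=False)
pred = run(args, return_prediction=True)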
Example 2
def get_all_dataset_hparams(hparams):
    """
    Takes a YAMLHParams object and returns a dictionary of one or more
    dataset ID to YAMLHParams object pairs; one for each dataset described
    in 'hparams'.

    If 'hparams' has the 'datasets' attribute, each dataset mentioned under
    this field will be loaded and returned. Otherwise, it is assumed that a
    single dataset is described directly in 'hparams', in which case 'hparams'
    as-is will be the only returned value (stored under an empty string ID).

    Args:
        hparams: (YAMLHParams) A hyperparameter object storing references to
                               one or more datasets in the 'datasets' field, or
                               describing a single dataset directly.

    Returns:
        A dictionary of dataset ID to YAMLHParams object pairs;
        one entry for each dataset.
    """
    from utime.hyperparameters import YAMLHParams
    dataset_hparams = {}
    if hparams.get("datasets"):
        # Multiple datasets specified in hparams configuration files
        ids_and_paths = hparams["datasets"].items()
        for id_, path in ids_and_paths:
            yaml_path = os.path.join(hparams.project_path, path)
            dataset_hparams[id_] = YAMLHParams(yaml_path,
                                               no_log=True,
                                               no_version_control=True)
    else:
        # Return as-is with no ID
        dataset_hparams[""] = hparams
    return dataset_hparams
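A short usage sketch, assuming a project-level hparams.yaml whose optional 'datasets' section maps dataset IDs to relative YAML paths; the file path, dataset IDs and relative paths shown are purely illustrative.

from utime.hyperparameters import YAMLHParams

# The project-level hparams.yaml may contain e.g.:
#   datasets:
#     sedf_sc: dataset_configurations/sedf_sc.yaml
#     dcsm:    dataset_configurations/dcsm.yaml
hparams = YAMLHParams("my_project/hparams.yaml", no_log=True, no_version_control=True)
for dataset_id, ds_hparams in get_all_dataset_hparams(hparams).items():
    # dataset_id is "" when 'hparams' itself describes a single dataset
    print(dataset_id or "<single dataset>")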
Example 3
def copy_yaml_and_set_data_dirs(in_path, out_path, data_dir=None):
    """
    Creates a YAMLHParams object from 'in_path' (a hyperparameter .yaml file),
    inserts the 'data_dir' argument into the 'data_dir' fields of the .yaml
    file (if present) and saves the hyperparameter file to 'out_path'.

    Note: If data_dir is set, it is assumed that the folder contains data in
          sub-folders 'train', 'val' and 'test' (not required to exist).

    Args:
        in_path:  (string) Path to a .yaml file storing the hyperparameters
        out_path: (string) Path to save the hyperparameters to
        data_dir: (string) Optional path to a directory storing data to use
                           for this project.
    """
    from utime.hyperparameters import YAMLHParams
    hparams = YAMLHParams(in_path, no_log=True, no_version_control=True)

    # Set values in parameter file and save to new location
    data_ids = ("train", "val", "test")
    for dataset in data_ids:
        path = os.path.join(data_dir, dataset) if data_dir else "Null"
        dataset = dataset + "_data"
        if hparams.get(dataset) and not hparams[dataset].get("data_dir"):
            hparams.set_value(dataset, "data_dir", path, True, True)
    hparams.save_current(out_path)
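A hypothetical call, assuming a prototype parameter file and a data folder with 'train'/'val'/'test' sub-folders; all paths are illustrative.

# Copy a prototype parameter file into a new project and point its
# train/val/test 'data_dir' fields at sub-folders of 'my_data/'.
copy_yaml_and_set_data_dirs(in_path="prototypes/hparams.yaml",
                            out_path="my_project/hparams.yaml",
                            data_dir="my_data")
# With data_dir=None, the 'data_dir' fields are written as "Null" instead.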
Example 4
def get_dataset_splits_from_hparams_file(hparams_path,
                                         splits_to_load,
                                         logger=None,
                                         id=""):
    """
    Loads one or more datasets according to hyperparameters described in the
    yaml file at path 'hparams_path'. Specifically, this function creates a
    temporary YAMLHParams object from the yaml file data and redirects to the
    'get_dataset_splits_from_hparams' function.

    Please refer to the docstring of 'get_dataset_splits_from_hparams' for
    details.
    """
    from utime.hyperparameters import YAMLHParams
    hparams = YAMLHParams(hparams_path, no_log=True, no_version_control=True)
    return get_dataset_splits_from_hparams(hparams, splits_to_load, logger, id)
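An illustrative call, assuming a dataset-level YAML file exists at the path shown; the return format is described by 'get_dataset_splits_from_hparams'.

splits = get_dataset_splits_from_hparams_file(
    "dataset_configurations/sedf_sc.yaml",   # illustrative path
    splits_to_load=("val",),                 # illustrative; must match split names in the YAML
    id="sedf_sc")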
Example 5
def prepare_hparams_dir(hparams_dir):
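    """
    Ensure a hyperparameter directory exists at 'hparams_dir'.

    If the directory does not exist but a 'hparams.yaml' file is found in the
    current working directory, the directory is created and the file - along
    with any dataset .yaml files referenced under its 'datasets' field - is
    moved into it. Otherwise, a RuntimeError is raised.
    """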
    if not os.path.exists(hparams_dir):
        # Check local hparams.yaml file, move into hparams_dir
        if os.path.exists("hparams.yaml"):
            os.mkdir(hparams_dir)
            hparams = YAMLHParams("hparams.yaml",
                                  no_log=True,
                                  no_version_control=True)
            for dataset, path in hparams['datasets'].items():
                destination = os.path.join(hparams_dir, path)
                os.makedirs(os.path.dirname(destination), exist_ok=True)
                shutil.move(path, destination)
            shutil.move("hparams.yaml", hparams_dir)
        else:
            raise RuntimeError("Must specifiy hyperparameters in a folder at path --hparams_prototype_dir <path> OR " + \
                               "have a hparams.yaml file at the current working directory (i.e. project folder)")
Example 6
def run(args, gpu_mon):
    """
    Run the script according to args - Please refer to the argparser.

    args:
        args:    (Namespace)  command-line arguments
        gpu_mon: (GPUMonitor) Initialized mpunet GPUMonitor object
    """
    assert_args(args)
    from mpunet.logging import Logger
    from utime.train import Trainer
    from utime.hyperparameters import YAMLHParams
    from utime.utils.scriptutils import (assert_project_folder,
                                         make_multi_gpu_model)
    from utime.utils.scriptutils.train import (get_train_and_val_datasets,
                                               get_h5_train_and_val_datasets,
                                               get_data_queues, get_generators,
                                               find_and_set_gpus,
                                               get_samples_per_epoch,
                                               save_final_weights)

    project_dir = os.path.abspath("./")
    assert_project_folder(project_dir)
    if args.overwrite and not args.continue_training:
        from mpunet.bin.train import remove_previous_session
        remove_previous_session(project_dir)

    # Get logger object
    logger = Logger(project_dir,
                    overwrite_existing=args.overwrite,
                    append_existing=args.continue_training,
                    log_prefix=args.log_file_prefix)
    logger("Args dump: {}".format(vars(args)))

    # Settings depending on --preprocessed flag.
    if args.preprocessed:
        yaml_path = utime.Defaults.get_pre_processed_hparams_path(project_dir)
        dataset_func = get_h5_train_and_val_datasets
        train_queue_type = 'eager'
        val_queue_type = 'eager'
    else:
        yaml_path = utime.Defaults.get_hparams_path(project_dir)
        dataset_func = get_train_and_val_datasets
        train_queue_type = args.train_queue_type
        val_queue_type = args.val_queue_type

    # Load hparams
    hparams = YAMLHParams(yaml_path, logger=logger)
    update_hparams_with_command_line_arguments(hparams, args)

    # Initialize and load (potentially multiple) datasets
    train_datasets, val_datasets = dataset_func(hparams, args.no_val,
                                                args.train_on_val, logger)

    if args.just:
        keep_n_random(*train_datasets,
                      *val_datasets,
                      keep=args.just,
                      logger=logger)

    # Get a data loader queue object for each dataset
    train_datasets_queues = get_data_queues(
        datasets=train_datasets,
        queue_type=train_queue_type,
        max_loaded_per_dataset=args.max_loaded_per_dataset,
        num_access_before_reload=args.num_access_before_reload,
        logger=logger)
    if val_datasets:
        val_dataset_queues = get_data_queues(
            datasets=val_datasets,
            queue_type=val_queue_type,
            max_loaded_per_dataset=args.max_loaded_per_dataset,
            num_access_before_reload=args.num_access_before_reload,
            study_loader=getattr(train_datasets_queues[0], 'study_loader',
                                 None),
            logger=logger)
    else:
        val_dataset_queues = None

    # Get sequence generators for all datasets
    train_seq, val_seq = get_generators(train_datasets_queues,
                                        val_dataset_queues=val_dataset_queues,
                                        hparams=hparams)

    # Add additional (inferred) parameters to parameter file
    hparams.set_value("build",
                      "n_classes",
                      train_seq.n_classes,
                      overwrite=True)
    hparams.set_value("build",
                      "batch_shape",
                      train_seq.batch_shape,
                      overwrite=True)
    hparams.save_current()

    if args.continue_training:
        # Prepare the project directory for continued training.
        # Please refer to the function docstring for details
        from utime.models.model_init import prepare_for_continued_training
        parameter_file = prepare_for_continued_training(
            hparams=hparams, project_dir=project_dir, logger=logger)
    else:
        parameter_file = args.initialize_from  # most often None

    # Set the GPU visibility
    num_GPUs = find_and_set_gpus(gpu_mon, args.force_GPU, args.num_GPUs)
    # Initialize the model and potentially load parameters into it
    from utime.models.model_init import init_model, load_from_file
    org_model = init_model(hparams["build"], logger)
    if parameter_file:
        load_from_file(org_model, parameter_file, logger, by_name=True)
    model, org_model = make_multi_gpu_model(org_model, num_GPUs)

    # Prepare a trainer object. Takes care of compiling and training.
    trainer = Trainer(model, org_model=org_model, logger=logger)

    import tensorflow as tf
    trainer.compile_model(n_classes=hparams["build"].get("n_classes"),
                          reduction=tf.keras.losses.Reduction.NONE,
                          **hparams["fit"])

    # Fit the model on a number of samples as specified in args
    samples_pr_epoch = get_samples_per_epoch(train_seq,
                                             args.max_train_samples_per_epoch)

    _ = trainer.fit(train=train_seq,
                    val=val_seq,
                    train_samples_per_epoch=samples_pr_epoch,
                    **hparams["fit"])

    # Save weights to project_dir/model/{final_weights_file_name}.h5
    # Note: these weights are rarely used, as a checkpoint callback also saves
    # weights to this directory during training
    save_final_weights(project_dir,
                       model=model,
                       file_name=args.final_weights_file_name,
                       logger=logger)
Example 7
def run(args):
    """
    Run the script according to args - Please refer to the argparser.
    """
    assert_args(args)
    # Check project folder is valid
    from utime.utils.scriptutils import (assert_project_folder,
                                         get_splits_from_all_datasets)
    project_dir = os.path.abspath(args.project_dir)
    assert_project_folder(project_dir, evaluation=True)

    # Prepare output dir
    out_dir = get_out_dir(args.out_dir, args.data_split)
    prepare_output_dir(out_dir, args.overwrite)
    logger = get_logger(out_dir, args.overwrite)
    logger("Args dump: \n{}".format(vars(args)))

    # Get hyperparameters and init all described datasets
    from utime.hyperparameters import YAMLHParams
    hparams = YAMLHParams(project_dir + "/hparams.yaml", logger)
    if args.channels:
        hparams["select_channels"] = args.channels
        hparams["channel_sampling_groups"] = None
        logger("Evaluating using channels {}".format(args.channels))

    # Get model
    set_gpu_vis(args.num_GPUs, args.force_GPU, logger)
    model, model_func = None, None
    if args.one_shot:
        # Model is initialized for each sleep study later
        def model_func(full_hyp):
            return get_and_load_one_shot_model(full_hyp, project_dir, hparams,
                                               logger, args.weights_file_name)
    else:
        model = get_and_load_model(project_dir, hparams, logger,
                                   args.weights_file_name)

    # Run predictions on all datasets
    datasets = get_splits_from_all_datasets(hparams=hparams,
                                            splits_to_load=(args.data_split, ),
                                            logger=logger)
    eval_dirs = []
    for dataset in datasets:
        dataset = dataset[0]
        if "/" in dataset.identifier:
            # Multiple datasets, separate results into sub-folders
            ds_out_dir = os.path.join(out_dir,
                                      dataset.identifier.split("/")[0])
            if not os.path.exists(ds_out_dir):
                os.mkdir(ds_out_dir)
            eval_dirs.append(ds_out_dir)
        else:
            ds_out_dir = out_dir
        logger("[*] Running eval on dataset {}\n"
               "    Out dir: {}".format(dataset, ds_out_dir))
        run_pred_and_eval(dataset=dataset,
                          out_dir=ds_out_dir,
                          model=model,
                          model_func=model_func,
                          hparams=hparams,
                          args=args,
                          logger=logger)
    if len(eval_dirs) > 1:
        cross_dataset_eval(eval_dirs, out_dir)
Example 8
def run(args, gpu_mon):
    """
    Run the script according to args - Please refer to the argparser.

    args:
        args:    (Namespace)  command-line arguments
        gpu_mon: (GPUMonitor) Initialized MultiPlanarUNet GPUMonitor object
    """
    assert_args(args)
    from mpunet.logging import Logger
    from utime.train import Trainer
    from utime.hyperparameters import YAMLHParams
    from utime.utils.scriptutils import (assert_project_folder,
                                         make_multi_gpu_model)
    from utime.utils.scriptutils.train import (get_train_and_val_datasets,
                                               get_generators,
                                               find_and_set_gpus,
                                               get_samples_per_epoch,
                                               save_final_weights)

    project_dir = os.path.abspath("./")
    assert_project_folder(project_dir)
    if args.overwrite and not args.continue_training:
        from mpunet.bin.train import remove_previous_session
        remove_previous_session(project_dir)

    # Get logger object
    logger = Logger(project_dir,
                    overwrite_existing=args.overwrite
                    or args.continue_training,
                    log_prefix=args.log_file_prefix)
    logger("Args dump: {}".format(vars(args)))

    # Load hparams
    hparams = YAMLHParams(os.path.join(project_dir, "hparams.yaml"),
                          logger=logger)
    update_hparams_with_command_line_arguments(hparams, args)

    # Initialize and load (potentially multiple) datasets
    datasets, no_val = get_train_and_val_datasets(hparams, args.no_val,
                                                  args.train_on_val, logger)

    # Load data in all datasets
    for data in datasets:
        for d in data:
            d.load(1 if args.just_one else None)
            d.pairs = d.loaded_pairs  # remove the other pairs

    # Get sequence generators for all datasets
    train_seq, val_seq = get_generators(datasets, hparams, no_val)

    # Add additional (inferred) parameters to parameter file
    hparams.set_value("build",
                      "n_classes",
                      train_seq.n_classes,
                      overwrite=True)
    hparams.set_value("build",
                      "batch_shape",
                      train_seq.batch_shape,
                      overwrite=True)
    hparams.save_current()

    if args.continue_training:
        # Prepare the project directory for continued training.
        # Please refer to the function docstring for details
        from utime.models.model_init import prepare_for_continued_training
        parameter_file = prepare_for_continued_training(
            hparams=hparams, project_dir=project_dir, logger=logger)
    else:
        parameter_file = args.initialize_from  # most often None

    # Set the GPU visibility
    num_GPUs = find_and_set_gpus(gpu_mon, args.force_GPU, args.num_GPUs)
    # Initialize the model and potentially load parameters into it
    from utime.models.model_init import init_model, load_from_file
    org_model = init_model(hparams["build"], logger)
    if parameter_file:
        load_from_file(org_model, parameter_file, logger, by_name=True)
    model, org_model = make_multi_gpu_model(org_model, num_GPUs)

    # Prepare a trainer object. Takes care of compiling and training.
    trainer = Trainer(model, org_model=org_model, logger=logger)
    trainer.compile_model(n_classes=hparams["build"].get("n_classes"),
                          **hparams["fit"])

    # Fit the model on a number of samples as specified in args
    samples_pr_epoch = get_samples_per_epoch(train_seq,
                                             args.max_train_samples_per_epoch,
                                             args.val_samples_per_epoch)
    _ = trainer.fit(train=train_seq,
                    val=val_seq,
                    train_samples_per_epoch=samples_pr_epoch[0],
                    val_samples_per_epoch=samples_pr_epoch[1],
                    **hparams["fit"])

    # Save weights to project_dir/model/{final_weights_file_name}.h5
    # Note: these weights are rarely used, as a checkpoint callback also saves
    # weights to this directory during training
    save_final_weights(project_dir,
                       model=model,
                       file_name=args.final_weights_file_name,
                       logger=logger)
Example 9
def run(args):
    """
    Run the script according to args - Please refer to the argparser.
    """
    assert_args(args)
    # Check project folder is valid
    from utime.utils.scriptutils import (assert_project_folder,
                                         get_dataset_from_regex_pattern,
                                         get_splits_from_all_datasets,
                                         get_all_dataset_hparams)
    project_dir = os.path.abspath(args.project_dir)
    assert_project_folder(project_dir, evaluation=True)

    # Prepare output dir
    if not args.folder_regex:
        out_dir = get_out_dir(args.out_dir, args.data_split)
    else:
        out_dir = args.out_dir
    prepare_output_dir(out_dir, args.overwrite)
    logger = get_logger(out_dir, args.overwrite, name="prediction_log")
    logger("Args dump: \n{}".format(vars(args)))

    # Get hyperparameters and init all described datasets
    from utime.hyperparameters import YAMLHParams
    hparams = YAMLHParams(project_dir + "/hparams.yaml", logger)
    hparams["build"]["data_per_prediction"] = args.data_per_prediction
    if args.channels:
        hparams["select_channels"] = args.channels
        hparams["channel_sampling_groups"] = None
        logger("Evaluating using channels {}".format(args.channels))

    # Get model
    set_gpu_vis(args.num_GPUs, args.force_GPU, logger)
    model, model_func = None, None
    if args.one_shot:
        # Model is initialized for each sleep study later
        def model_func(full_hyp):
            return get_and_load_one_shot_model(full_hyp, project_dir, hparams,
                                               logger, args.weights_file_name)
    else:
        model = get_and_load_model(project_dir, hparams, logger,
                                   args.weights_file_name)

    if args.folder_regex:
        # We predict on a single dataset, specified by the folder_regex arg
        # We load the dataset hyperparameters of one of those specified in
        # the stored hyperparameter files and use it as a guide for how to
        # handle this new, undescribed dataset
        dataset_hparams = list(get_all_dataset_hparams(hparams).values())[0]
        datasets = [(get_dataset_from_regex_pattern(args.folder_regex,
                                                    hparams=dataset_hparams,
                                                    logger=logger), )]
    else:
        # predict on datasets described in the hyperparameter files
        datasets = get_splits_from_all_datasets(
            hparams=hparams, splits_to_load=(args.data_split, ), logger=logger)

    for dataset in datasets:
        dataset = dataset[0]
        if "/" in dataset.identifier:
            # Multiple datasets, separate results into sub-folders
            ds_out_dir = os.path.join(out_dir,
                                      dataset.identifier.split("/")[0])
            if not os.path.exists(ds_out_dir):
                os.mkdir(ds_out_dir)
        else:
            ds_out_dir = out_dir
        logger("[*] Running eval on dataset {}\n"
               "    Out dir: {}".format(dataset, ds_out_dir))
        run_pred(dataset=dataset,
                 out_dir=ds_out_dir,
                 model=model,
                 model_func=model_func,
                 hparams=hparams,
                 args=args,
                 logger=logger)
Example 10
def run(args):
    """
    Run the script according to args - Please refer to the argparser.

    args:
        args:    (Namespace)  command-line arguments
    """
    from mpunet.logging import Logger
    from utime.hyperparameters import YAMLHParams
    from utime.utils.scriptutils import assert_project_folder
    from utime.utils.scriptutils import get_splits_from_all_datasets

    project_dir = os.path.abspath("./")
    assert_project_folder(project_dir)

    # Get logger object
    logger = Logger(project_dir + "/preprocessing_logs",
                    active_file='preprocessing',
                    overwrite_existing=args.overwrite,
                    no_sub_folder=True)
    logger("Args dump: {}".format(vars(args)))

    # Load hparams
    hparams = YAMLHParams(Defaults.get_hparams_path(project_dir),
                          logger=logger,
                          no_version_control=True)

    # Initialize and load (potentially multiple) datasets
    datasets = get_splits_from_all_datasets(hparams,
                                            splits_to_load=args.dataset_splits,
                                            logger=logger,
                                            return_data_hparams=True)

    # Check if file exists, and overwrite if specified
    if os.path.exists(args.out_path):
        if args.overwrite:
            os.remove(args.out_path)
        else:
            from sys import exit
            logger("Out file at {} exists, and --overwrite was not set."
                   "".format(args.out_path))
            exit(0)

    # Create dataset hparams output directory
    out_dir = Defaults.get_pre_processed_data_configurations_dir(project_dir)
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    with ThreadPoolExecutor(args.num_threads) as pool:
        with h5py.File(args.out_path, "w") as h5_file:
            for dataset, dataset_hparams in datasets:
                # Create a new version of the dataset-specific hyperparameters
                # that contain only the fields needed for pre-processed data
                name = dataset[0].identifier.split("/")[0]
                hparams_out_path = os.path.join(out_dir, name + ".yaml")
                copy_dataset_hparams(dataset_hparams, hparams_out_path)

                # Update paths to dataset hparams in main hparams file
                hparams.set_value(subdir='datasets',
                                  name=name,
                                  value=hparams_out_path,
                                  overwrite=True)
                # Save the hyperparameters to the pre-processed main hparams
                hparams.save_current(
                    Defaults.get_pre_processed_hparams_path(project_dir))

                # Process each dataset
                for split in dataset:
                    # Add this split to the dataset-specific hparams
                    add_dataset_entry(hparams_out_path, args.out_path,
                                      split.identifier.split("/")[-1].lower(),
                                      split.period_length_sec)
                    # Overwrite any load-time channel sampler with None
                    split.set_load_time_channel_sampling_groups(None)

                    # Create dataset group
                    split_group = h5_file.create_group(split.identifier)

                    # Run the preprocessing
                    process_func = partial(preprocess_study, split_group)

                    logger.print_to_screen = True
                    logger("Preprocessing dataset:", split)
                    logger.print_to_screen = False
                    n_pairs = len(split.pairs)
                    for i, _ in enumerate(pool.map(process_func, split.pairs)):
                        print("  {}/{}".format(i + 1, n_pairs),
                              end='\r',
                              flush=True)
                    print("")