Example 1
def check_params(datasets, models, results_path, parameters, metrics,
                 csv_filename):
    """Validate experiment inputs before any work starts.

    Checks that at least one dataset path is given and exists on disk,
    that the parameters file contains every required hyper-parameter key,
    that every requested model has an entry in ``parameters["model_params"]``,
    and that every requested metric is a known key of ``METRICS``.

    Raises:
        AssertionError: with a descriptive message when any check fails.

    NOTE(review): ``results_path`` and ``csv_filename`` are accepted but not
    validated here — kept in the signature for interface compatibility.
    """
    assert len(datasets) > 0, "dataset parameter is not well defined."
    assert all(os.path.exists(ds_path)
               for ds_path in datasets), "dataset paths are not well defined."
    required_keys = (
        "normalization_method",
        "past_history_factor",
        "batch_size",
        "epochs",
        "max_steps_per_epoch",
        "learning_rate",
        "model_params",
    )
    # `key in dict` is equivalent to (and cheaper than) `key in dict.keys()`.
    assert all(
        param in parameters for param in required_keys
    ), "Some parameters are missing in the parameters file."
    assert all(model in parameters["model_params"]
               for model in models), "models parameter is not well defined."
    # Fix: this assert previously carried no failure message, unlike the others.
    assert metrics is None or all(
        m in METRICS for m in metrics
    ), "metrics parameter is not well defined."
Example 2
def main(args):
    """Run every experiment combination described by the parameters file.

    For each dataset, iterates the cartesian product of training
    hyper-parameters and per-model architecture parameters, skips the
    combinations already recorded in the results CSV (resume support),
    and runs each remaining experiment in a fresh subprocess so GPU
    memory is fully released between runs.

    Args:
        args: parsed CLI namespace with attributes ``datasets``, ``models``,
            ``output``, ``gpu``, ``metrics``, ``csv_filename`` and
            ``parameters`` (path to a JSON file).
    """
    datasets = args.datasets
    models = args.models
    results_path = args.output
    gpu_device = args.gpu
    metrics = args.metrics
    csv_filename = args.csv_filename

    # Fix: dropped the dead `parameters = None` pre-assignment — the `with`
    # block below always rebinds it, and an open/parse failure propagates.
    with open(args.parameters, "r") as params_file:
        parameters = json.load(params_file)

    check_params(datasets, models, results_path, parameters, metrics, csv_filename)

    # Empty CLI selections mean "use everything defined in the config".
    if len(models) == 0:
        models = list(parameters["model_params"].keys())

    if metrics is None:
        metrics = list(METRICS.keys())

    if not os.path.exists(results_path):
        os.makedirs(results_path)

    for dataset_index, dataset_path in enumerate(datasets):
        dataset = os.path.basename(os.path.normpath(dataset_path))

        # Resume support: the results CSV row count tells how many
        # combinations were already completed for this dataset.
        csv_filepath = results_path + "/{}/{}".format(dataset, csv_filename)
        results = read_results_file(csv_filepath, metrics)
        current_index = results.shape[0]
        print("CURRENT INDEX", current_index)

        experiments_index = 0
        # Total = product of all global hyper-parameter list sizes, times the
        # sum over models of the product of their per-model param list sizes.
        num_total_experiments = np.prod(
            [len(parameters[k]) for k in parameters.keys() if k != "model_params"]
            + [
                np.sum(
                    [
                        np.prod(
                            [
                                len(parameters["model_params"][m][k])
                                for k in parameters["model_params"][m].keys()
                            ]
                        )
                        for m in models
                    ]
                )
            ]
        )

        for epochs, normalization_method, past_history_factor in itertools.product(
                parameters["epochs"],
                parameters["normalization_method"],
                parameters["past_history_factor"],
        ):
            for batch_size, learning_rate in itertools.product(
                    parameters["batch_size"], parameters["learning_rate"],
            ):
                for model_name in models:
                    # NOTE(review): `product` (bare, not `itertools.product`)
                    # presumably is a project helper expanding a dict of
                    # parameter lists into keyword-combination dicts — confirm.
                    for model_index, model_args in enumerate(
                            product(**parameters["model_params"][model_name])
                    ):
                        experiments_index += 1
                        # Skip combinations already present in the CSV.
                        if experiments_index <= current_index:
                            continue

                        # Run each experiment in a new Process to avoid GPU memory leaks
                        manager = Manager()
                        error_dict = manager.dict()

                        p = Process(
                            target=run_experiment,
                            args=(
                                error_dict,
                                gpu_device,
                                dataset,
                                dataset_path,
                                results_path,
                                csv_filepath,
                                metrics,
                                epochs,
                                normalization_method,
                                past_history_factor,
                                # NOTE(review): only the first configured
                                # max_steps_per_epoch value is ever used.
                                parameters["max_steps_per_epoch"][0],
                                batch_size,
                                learning_rate,
                                model_name,
                                model_index,
                                model_args,
                            ),
                        )
                        p.start()
                        p.join()

                        # The child reports status through the managed dict;
                        # abort the whole sweep on the first failure.
                        assert error_dict["status"] == 1, error_dict["message"]

                        notify_slack(
                            "{}/{} {}:{}/{} ({})".format(
                                dataset_index + 1,
                                len(datasets),
                                dataset,
                                experiments_index,
                                num_total_experiments,
                                model_name,
                            )
                        )
Example 3
 # Paths to the test labels and a pre-trained model (both optional).
 parser.add_argument("--test_label_path", type=str)
 parser.add_argument("--model_path", type=str)
 # Where checkpoints are written; experiment name identifies the run.
 parser.add_argument("--checkpoint_dir", type=str, required=True)
 parser.add_argument("--exp_name", type=str, required=True)
 # Data-augmentation transforms to apply (any combination; default: none).
 parser.add_argument("--da",
                     type=str,
                     nargs='+',
                     default=[],
                     choices=[
                         'flip', 'blur', 'noise', 'resized_crop', 'affine',
                         'ghosting', 'motion', 'spike', 'biasfield', 'swap'
                     ])
 # Restricted to the metric names registered in the METRICS mapping.
 parser.add_argument("--metrics",
                     nargs='+',
                     type=str,
                     choices=list(METRICS.keys()),
                     help="Metrics to be computed on validation/test set")
 parser.add_argument("--labels",
                     nargs='+',
                     type=str,
                     help="Label(s) to be predicted")
 # Training loss; GaussianLogLkd presumably is a Gaussian log-likelihood
 # variant defined in the project — confirm against the training code.
 parser.add_argument("--loss",
                     type=str,
                     choices=['BCE', 'l1', 'l2', 'GaussianLogLkd'],
                     required=True)
 parser.add_argument("--net",
                     type=str,
                     choices=[
                         "resnet18", "resnet34", "resnet50", "resnext50",
                         "vgg11", "vgg16", "sfcn", "densenet121",
                         "tiny_densenet121", "tiny_vgg"