예제 #1
0
def create_runner_parser(model_config_class: type = None) -> argparse.ArgumentParser:
    """
    Builds a commandline parser that accepts every argument needed to run a script in Azure (AzureConfig)
    plus the arguments registered by ModelConfigLoader, optionally extended with the arguments of a
    model configuration class.
    :param model_config_class: Optional class holding model-specific parameters. If given, it must be a
        subclass of GenericConfig; its arguments are added via its add_args method.
    :return: The populated ArgumentParser.
    :raises ValueError: If model_config_class is a class that is not a subclass of GenericConfig.
    """
    result = AzureConfig.create_argparser()
    ModelConfigLoader.add_args(result)
    # Without a model class there is nothing more to register.
    if model_config_class is None:
        return result
    if not issubclass(model_config_class, GenericConfig):
        raise ValueError(f"The given class must be a subclass of GenericConfig, but got: {model_config_class}")
    model_config_class.add_args(result)
    return result
def main() -> None:
    """
    Commandline entry point for training: reads the mandatory --model argument, creates the model
    configuration with that name via ModelConfigLoader, and passes it to model_train.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--model",
        help="The name of the model to train",
        type=empty_string_to_none,
        required=True,
    )
    parsed = arg_parser.parse_args()
    config = ModelConfigLoader().create_model_config_from_name(parsed.model)
    model_train(config)
def find_models() -> List[str]:
    """
    Lists all Python files in the "segmentation", "classification" and "regression" sub-folders of the
    default model config search module. Each of them is assumed to contain one model config.
    Sub-folders that do not exist are skipped. Names ending in "Base" and names starting with "__"
    (e.g. __init__) are excluded.
    :return: The model names (file names without the .py suffix), grouped by sub-folder in the order
        segmentation, classification, regression, and sorted alphabetically within each sub-folder.
    """
    path = namespace_to_path(ModelConfigLoader.get_default_search_module())
    folders = [path / "segmentation", path / "classification", path / "regression"]
    # Check existence once per folder, before globbing it. Sort the glob results because their order
    # depends on the file system, and callers should get a deterministic list.
    names = [
        f.stem
        for folder in folders
        if folder.exists()
        for f in sorted(folder.glob("*.py"))
    ]
    return [name for name in names if not name.endswith("Base") and not name.startswith("__")]
예제 #4
0
def get_configs(
        default_model_config: SegmentationModelBase, yaml_file_path: Path
) -> Tuple[SegmentationModelBase, AzureConfig, Dict]:
    """
    Builds the Azure runner configuration and the model configuration from the result of
    create_parser(yaml_file_path), and sets up logging to stdout at the requested log level.
    :param default_model_config: Model configuration to use. If None, a configuration is created instead
        from the "model" value of the runner configuration via ModelConfigLoader.
    :param yaml_file_path: Path of the YAML settings file handed to create_parser.
    :return: A tuple (model config with overrides applied, Azure runner config, dictionary of parsed args).
    """
    parser_result = create_parser(yaml_file_path)
    args = parser_result.args
    runner_config = AzureConfig(**args)
    logging_to_stdout(args["log_level"])
    # Explicit None check: a truthiness test ("or") would also discard a given config object that
    # happens to evaluate as falsy.
    if default_model_config is not None:
        config = default_model_config
    else:
        config = ModelConfigLoader().create_model_config_from_name(runner_config.model)
    # NOTE(review): when default_model_config is given, apply_overrides is called on the caller's object
    # itself, not on a copy — presumably mutating it in place; confirm this is intended.
    config.apply_overrides(parser_result.overrides, should_validate=False)
    return config, runner_config, args
예제 #5
0
 def parse_and_load_model(self) -> Optional[ParserResult]:
     """
     Parses the command line arguments (plus the YAML settings in self.yaml_config_file) and creates
     configuration objects for the Azure-related parameters and for the model itself.

     Parsing happens in two stages:
     1. A runner parser (create_runner_parser) extracts the AzureConfig / ModelConfigLoader arguments;
        unknown arguments and YAML settings are tolerated at this stage.
     2. Once the model is known, a parser built from the model config class parses everything the first
        stage did not recognize, this time with fail_on_unknown_args=True.

     Side effects: sets self.azure_config, sets self.model_config (to None first, and to the final model
     config only after all steps succeeded), and calls model_config.create_filesystem(self.project_root).

     :return: The result of the second-stage (model) parser. If no "model" argument is provided,
         self.model_config is left as None and the return value is None.
     """
     # Create a parser that will understand only the args we need for an AzureConfig
     parser1 = create_runner_parser()
     parser1_result = parse_args_and_add_yaml_variables(parser1,
                                                        yaml_config_file=self.yaml_config_file,
                                                        project_root=self.project_root,
                                                        args=self.command_line_args,
                                                        fail_on_unknown_args=False)
     azure_config = AzureConfig(**parser1_result.args)
     azure_config.project_root = self.project_root
     self.azure_config = azure_config
     # Reset first: self.model_config stays None if no model is given or if any later step raises,
     # because it is only assigned the real config at the very end.
     self.model_config = None  # type: ignore
     if not azure_config.model:
         return None
     model_config_loader: ModelConfigLoader = ModelConfigLoader(**parser1_result.args)
     # Create the model as per the "model" commandline option
     model_config = model_config_loader.create_model_config_from_name(
         model_name=azure_config.model
     )
     # This model will be either a classification model or a segmentation model. Those have different
     # fields that could be overridden on the command line. Create a parser that understands the fields we need
     # for the actual model type. We feed this parser with the YAML settings and commandline arguments that the
     # first parser did not recognize.
     parser2 = type(model_config).create_argparser()
     parser2_result = parse_arguments(parser2,
                                      settings_from_yaml=parser1_result.unknown_settings_from_yaml,
                                      args=parser1_result.unknown,
                                      fail_on_unknown_args=True)
     # Apply the overrides and validate. Overrides can come from either YAML settings or the commandline.
     # The YAML settings are applied first and the parser2 overrides second, presumably so that commandline
     # values win over YAML values.
     # NOTE(review): the same unknown YAML settings were also fed to parser2 above, so they may be applied
     # twice — confirm whether the first apply_overrides call is still needed.
     model_config.apply_overrides(parser1_result.unknown_settings_from_yaml)
     model_config.apply_overrides(parser2_result.overrides)
     model_config.validate()
     # Set the file system related configs, they might be affected by the overrides that were applied.
     logging.info("Creating the adjusted output folder structure.")
     model_config.create_filesystem(self.project_root)
     # Informational only: log whether the extra code directory exists.
     # NOTE(review): a relative extra_code_directory is checked against the current working directory,
     # not self.project_root — confirm this is intended.
     if azure_config.extra_code_directory:
         exist = "exists" if Path(azure_config.extra_code_directory).exists() else "does not exist"
         logging.info(f"extra_code_directory is {azure_config.extra_code_directory}, which {exist}")
     else:
         logging.info("extra_code_directory is unset")
     self.model_config = model_config
     return parser2_result
예제 #6
0

if __name__ == '__main__':
    # Commandline for testing a trained model: --model is mandatory, all other arguments are optional.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=empty_string_to_none, required=True,
                        help="The name of the model to test.")
    parser.add_argument("--local_dataset", type=string_to_path,
                        help="Path to local dataset for testing")
    parser.add_argument("--outputs_folder", type=empty_string_to_none,
                        help="Path to outputs folder where checkpoints are stored")
    parser.add_argument("--test_series_ids", nargs="+", type=int, required=False,
                        help="Subset of test cases for which the model testing is applied")
    parser.add_argument("--run_recovery_id", type=str, required=False,
                        help="Id of a run to recover from")
    args = parser.parse_args()
    # Every parsed argument (including "model" itself) is passed to the loader as an override.
    test_config: ModelConfigBase = ModelConfigLoader().create_model_config_from_name(
        args.model,
        overrides=vars(args),
    )
    model_test(config=test_config, data_split=ModelExecutionMode.TEST)