コード例 #1
0
ファイル: run.py プロジェクト: Vivek305/fastestimator
def run(args: Dict[str, Any], unknown: Optional[List[str]]) -> None:
    """Invoke the fastestimator_run function from a file, falling back to get_estimator if it is not defined.

    Hyperparameters are loaded from an optional JSON file and then overridden by any matching command line
    arguments. The module named by `entry_point` is imported from its containing directory, and either its
    `fastestimator_run(**hyperparameters)` is called, or the estimator returned by
    `get_estimator(**hyperparameters)` is fit (if the pipeline has "train" data) and tested (if it has "test" data).

    Args:
        args: A dictionary containing location of the FE file under the 'entry_point' key, as well as an optional
            'hyperparameters_json' key if the user is storing their parameters in a file.
        unknown: The remainder of the command line arguments to be passed along to the fastestimator_run() method.

    Raises:
        ValueError: If the file defines neither `fastestimator_run` nor `get_estimator`.
    """
    entry_point = args['entry_point']
    hyperparameters = {}
    # The JSON key is documented as optional, so tolerate it being absent as well as None/empty.
    hyperparameters_json = args.get('hyperparameters_json')
    if hyperparameters_json:
        json_path = os.path.abspath(hyperparameters_json)
        # JSON text is UTF-8 by specification; don't depend on the platform's locale encoding.
        with open(json_path, 'r', encoding='utf-8') as f:
            hyperparameters = json.load(f)
    # Command line values are applied last, so they override values loaded from the JSON file.
    hyperparameters.update(parse_cli_to_dictionary(unknown))
    module_name = os.path.splitext(os.path.basename(entry_point))[0]
    dir_name = os.path.abspath(os.path.dirname(entry_point))
    # Put the entry point's directory first on the search path so the import below can locate the module.
    # NOTE(review): this path entry is never removed, and if a module with the same name is already in
    # sys.modules, __import__ returns that cached module instead of re-reading entry_point.
    sys.path.insert(0, dir_name)
    spec_module = __import__(module_name, globals(), locals())
    if hasattr(spec_module, "fastestimator_run"):
        spec_module.fastestimator_run(**hyperparameters)
    elif hasattr(spec_module, "get_estimator"):
        est = spec_module.get_estimator(**hyperparameters)
        if "train" in est.pipeline.data:
            est.fit()
        if "test" in est.pipeline.data:
            est.test()
    else:
        raise ValueError(
            "The file {} does not contain 'fastestimator_run' or 'get_estimator'"
            .format(module_name))
コード例 #2
0
ファイル: train.py プロジェクト: Vivek305/fastestimator
def _get_estimator(args: Dict[str, Any], unknown: Optional[List[str]]) -> Estimator:
    """A helper method to invoke the get_estimator method from a file using provided command line arguments as input.

    Hyperparameters are loaded from an optional JSON file and then overridden by any matching command line
    arguments before being passed as keyword arguments to `get_estimator`.

    Args:
        args: A dictionary containing location of the FE file under the 'entry_point' key, as well as an optional
            'hyperparameters_json' key if the user is storing their parameters in a file.
        unknown: The remainder of the command line arguments to be passed along to the get_estimator() method.

    Returns:
        The estimator generated by a file's get_estimator() function.
    """
    entry_point = args['entry_point']
    hyperparameters = {}
    # The JSON key is documented as optional, so tolerate it being absent as well as None/empty.
    hyperparameters_json = args.get('hyperparameters_json')
    if hyperparameters_json:
        json_path = os.path.abspath(hyperparameters_json)
        # JSON text is UTF-8 by specification; don't depend on the platform's locale encoding.
        with open(json_path, 'r', encoding='utf-8') as f:
            hyperparameters = json.load(f)
    # Command line values are applied last, so they override values loaded from the JSON file.
    hyperparameters.update(parse_cli_to_dictionary(unknown))
    module_name = os.path.splitext(os.path.basename(entry_point))[0]
    dir_name = os.path.abspath(os.path.dirname(entry_point))
    # Put the entry point's directory first on the search path so the import below can locate the module.
    # NOTE(review): this path entry is never removed, and a same-named module already in sys.modules
    # would be returned by __import__ instead of re-reading entry_point.
    sys.path.insert(0, dir_name)
    spec_module = __import__(module_name, globals(), locals(), ["get_estimator"])
    return spec_module.get_estimator(**hyperparameters)
コード例 #3
0
 def test_parse_cli_to_dictionary_consecutive_value(self):
     """Several values following a single flag are concatenated into one string."""
     cli_args = ["--abc", "hello", "world"]
     expected = {"abc": "helloworld"}
     self.assertEqual(parse_cli_to_dictionary(cli_args), expected)
コード例 #4
0
 def test_parse_cli_to_dictionary_no_value(self):
     """A flag with no following value maps to an empty string."""
     cli_args = ["--abc", "--def"]
     expected = {"abc": "", "def": ""}
     self.assertEqual(parse_cli_to_dictionary(cli_args), expected)
コード例 #5
0
 def test_parse_cli_to_dictionary_no_key(self):
     """Tokens that are not introduced by a --flag are ignored, producing an empty dict."""
     cli_args = ["abc", "def"]
     self.assertEqual(parse_cli_to_dictionary(cli_args), {})
コード例 #6
0
 def test_parse_cli_to_dictionary(self):
     """Each flag maps to its following value; numeric-looking values come back as int/float."""
     cli_args = ["--epochs", "5", "--test", "this", "--lr", "0.74"]
     expected = {'epochs': 5, 'test': 'this', 'lr': 0.74}
     self.assertEqual(parse_cli_to_dictionary(cli_args), expected)