def run(args: Dict[str, Any], unknown: Optional[List[str]]) -> None:
    """Invoke the fastestimator_run function from a file.

    If the file does not define ``fastestimator_run`` but does define ``get_estimator``, the estimator is built and
    then fit (when its pipeline has "train" data) and tested (when its pipeline has "test" data).

    Args:
        args: A dictionary containing location of the FE file under the 'entry_point' key, as well as an optional
            'hyperparameters_json' key if the user is storing their parameters in a file.
        unknown: The remainder of the command line arguments to be passed along to the fastestimator_run() method.
            Values parsed from these arguments override same-named entries loaded from the hyperparameters json.

    Raises:
        ValueError: If the file contains neither 'fastestimator_run' nor 'get_estimator'.
    """
    entry_point = args['entry_point']
    hyperparameters = {}
    # The key is documented as optional, so tolerate its absence instead of raising KeyError.
    hyperparameters_json = args.get('hyperparameters_json')
    if hyperparameters_json:
        with open(os.path.abspath(hyperparameters_json), 'r') as f:
            hyperparameters = json.load(f)
    # CLI arguments are applied last so they take precedence over values from the json file.
    if unknown:
        hyperparameters.update(parse_cli_to_dictionary(unknown))
    module_name = os.path.splitext(os.path.basename(entry_point))[0]
    dir_name = os.path.abspath(os.path.dirname(entry_point))
    # Put the file's directory first on the search path so the file can be imported by its module name.
    sys.path.insert(0, dir_name)
    # NOTE(review): __import__ returns any module already present in sys.modules under this name, so a file whose
    # name collides with a previously imported module would not be loaded from entry_point — confirm acceptable.
    spec_module = __import__(module_name, globals(), locals())
    if hasattr(spec_module, "fastestimator_run"):
        spec_module.fastestimator_run(**hyperparameters)
    elif hasattr(spec_module, "get_estimator"):
        est = spec_module.get_estimator(**hyperparameters)
        if "train" in est.pipeline.data:
            est.fit()
        if "test" in est.pipeline.data:
            est.test()
    else:
        raise ValueError(
            "The file {} does not contain 'fastestimator_run' or 'get_estimator'"
            .format(module_name))
def _get_estimator(args: Dict[str, Any], unknown: Optional[List[str]]) -> Estimator:
    """A helper method to invoke the get_estimator method from a file using provided command line arguments as input.

    Args:
        args: A dictionary containing location of the FE file under the 'entry_point' key, as well as an optional
            'hyperparameters_json' key if the user is storing their parameters in a file.
        unknown: The remainder of the command line arguments to be passed along to the get_estimator() method.
            Values parsed from these arguments override same-named entries loaded from the hyperparameters json.

    Returns:
        The estimator generated by a file's get_estimator() function.
    """
    entry_point = args['entry_point']
    hyperparameters = {}
    # The key is documented as optional, so tolerate its absence instead of raising KeyError.
    hyperparameters_json = args.get('hyperparameters_json')
    if hyperparameters_json:
        with open(os.path.abspath(hyperparameters_json), 'r') as f:
            hyperparameters = json.load(f)
    # CLI arguments are applied last so they take precedence over values from the json file.
    if unknown:
        hyperparameters.update(parse_cli_to_dictionary(unknown))
    module_name = os.path.splitext(os.path.basename(entry_point))[0]
    dir_name = os.path.abspath(os.path.dirname(entry_point))
    # Put the file's directory first on the search path so the file can be imported by its module name.
    sys.path.insert(0, dir_name)
    # NOTE(review): __import__ returns any module already present in sys.modules under this name, so a file whose
    # name collides with a previously imported module would not be loaded from entry_point — confirm acceptable.
    spec_module = __import__(module_name, globals(), locals(), ["get_estimator"])
    return spec_module.get_estimator(**hyperparameters)
def test_parse_cli_to_dictionary_consecutive_value(self):
    # Several bare values following one key are expected to be joined into a single string.
    cli_tokens = ["--abc", "hello", "world"]
    expected = {"abc": "helloworld"}
    parsed = parse_cli_to_dictionary(cli_tokens)
    self.assertEqual(parsed, expected)
def test_parse_cli_to_dictionary_no_value(self):
    # Keys given without any following value are expected to map to an empty string.
    cli_tokens = ["--abc", "--def"]
    expected = {"abc": "", "def": ""}
    parsed = parse_cli_to_dictionary(cli_tokens)
    self.assertEqual(parsed, expected)
def test_parse_cli_to_dictionary_no_key(self):
    # Tokens that are not introduced by a "--key" flag are expected to produce no entries.
    cli_tokens = ["abc", "def"]
    parsed = parse_cli_to_dictionary(cli_tokens)
    self.assertEqual(parsed, {})
def test_parse_cli_to_dictionary(self):
    # Numeric-looking values are expected to come back as int/float; other values stay as strings.
    cli_tokens = ["--epochs", "5", "--test", "this", "--lr", "0.74"]
    expected = {'epochs': 5, 'test': 'this', 'lr': 0.74}
    parsed = parse_cli_to_dictionary(cli_tokens)
    self.assertEqual(parsed, expected)