コード例 #1
0
 def __call__(self,
              parser: TFAIPArgumentParser,
              namespace,
              values,
              option_string=None):
     trainer_params, scenario = Trainer.parse_trainer_params(values)
     parser.add_root_argument("trainer",
                              trainer_params.__class__,
                              default=trainer_params)
     setattr(namespace, "scenario_cls", scenario)
コード例 #2
0
    def __call__(self,
                 parser: TFAIPArgumentParser,
                 namespace,
                 values,
                 option_string=None):
        scenario, scenario_params = ScenarioBase.from_path(values)
        evaluator_params = scenario_params.evaluator

        parser = TFAIPArgumentParser()
        parser.add_root_argument("evaluator", evaluator_params.__class__,
                                 evaluator_params)
        setattr(namespace, self.dest, scenario)
        setattr(namespace, self.dest + "_params", scenario_params)
コード例 #3
0
    def __call__(self,
                 parser: TFAIPArgumentParser,
                 namespace,
                 values,
                 option_string=None):
        from tfaip.scenario.scenariobase import import_scenario

        scenario = import_scenario(values)

        # Now pass the real args of the scenario
        default_trainer_params = scenario.default_trainer_params()
        parser.add_root_argument("trainer",
                                 default_trainer_params.__class__,
                                 default=default_trainer_params)
        setattr(namespace, self.dest, scenario)
コード例 #4
0
    def __call__(self,
                 parser: TFAIPArgumentParser,
                 namespace,
                 values,
                 option_string=None):
        from tfaip.imports import ScenarioBase

        export_dir = values
        scenario, scenario_params = ScenarioBase.from_path(export_dir)

        parser.add_root_argument("scenario",
                                 scenario_params.__class__,
                                 default=scenario_params)
        setattr(namespace, self.dest, values)
        setattr(namespace, "scenario_cls", scenario)
コード例 #5
0
    def __call__(self,
                 parser: TFAIPArgumentParser,
                 namespace,
                 values,
                 option_string=None):
        from tfaip.imports import Trainer

        output_dir = values
        trainer_params, scenario = Trainer.parse_trainer_params(output_dir)

        # parse additional args
        parser.add_root_argument("trainer",
                                 trainer_params.__class__,
                                 default=trainer_params)
        setattr(namespace, self.dest, values)
        setattr(namespace, "scenario_cls", scenario)
コード例 #6
0
ファイル: predict.py プロジェクト: Planet-AI-GmbH/tfaip
    def __call__(self,
                 parser: TFAIPArgumentParser,
                 namespace,
                 values,
                 option_string=None):
        scenario, scenario_params = ScenarioBase.from_path(values[0])
        predict_params = scenario.predictor_cls().params_cls()()
        parser.add_root_argument("data",
                                 scenario.predict_generator_params_cls())
        parser.add_root_argument("predict",
                                 predict_params.__class__,
                                 default=predict_params)

        setattr(namespace, self.dest, values)
        setattr(namespace, "scenario", scenario)
        setattr(namespace, "scenario_params", scenario_params)
コード例 #7
0
    def __call__(self,
                 parser: TFAIPArgumentParser,
                 namespace,
                 values,
                 option_string=None):
        from tfaip.imports import ScenarioBase

        export_dirs = values
        scenario, scenario_params = ScenarioBase.from_path(
            export_dirs[0])  # scenario based on first model
        lav_params = scenario.lav_cls().params_cls()()
        lav_params.model_path = export_dirs
        predictor_params = scenario.multi_predictor_cls().params_cls()()

        parser.add_root_argument("lav",
                                 scenario.lav_cls().params_cls(),
                                 default=lav_params)
        parser.add_root_argument("predictor",
                                 scenario.multi_predictor_cls().params_cls(),
                                 default=predictor_params,
                                 ignore=["pipeline"]),
        parser.add_root_argument("data",
                                 scenario.predict_generator_params_cls())

        setattr(namespace, self.dest, values)
        setattr(namespace, "scenario", scenario)
        setattr(namespace, "scenario_params", scenario_params)
コード例 #8
0
ファイル: lav.py プロジェクト: Planet-AI-GmbH/tfaip
    def __call__(self, parser: TFAIPArgumentParser, namespace, values, option_string=None):
        from tfaip.imports import ScenarioBase

        export_dir = values
        scenario, scenario_params = ScenarioBase.from_path(export_dir)

        default_gen_params = scenario.predict_generator_params_cls()()
        if os.path.exists(os.path.join(export_dir, "trainer_params.json")):
            # if trainer_params exist load val generator as default
            with open(os.path.join(export_dir, "trainer_params.json")) as f:
                p = scenario.trainer_cls().params_cls().from_json(f.read())
                default_gen_params = p.gen.lav_gen()[0]

        lav_params = scenario.lav_cls().params_cls()()
        lav_params.model_path = export_dir

        parser.add_root_argument("data", DataGeneratorParams, default=default_gen_params)
        parser.add_root_argument("lav", scenario.lav_cls().params_cls(), default=lav_params)

        setattr(namespace, self.dest, values)
        setattr(namespace, "scenario", scenario)
        setattr(namespace, "scenario_params", scenario_params)