Example #1
    def from_paths(
        cls,
        checkpoints: List[str],
        auto_update_checkpoints=True,
        predictor_params: PredictorParams = None,
        voter_params: VoterParams = None,
        **kwargs,
    ) -> "tfaip_cls.MultiModelPredictor":
        if not checkpoints:
            raise Exception("No checkpoints provided.")

        if predictor_params is None:
            predictor_params = PredictorParams(silent=True, progress_bar=True)

        DeviceConfig(predictor_params.device)  # devices must be configured before any model is loaded
        # Wrap each checkpoint path; auto_update upgrades older checkpoint formats if requested
        checkpoints = [
            SavedCalamariModel(ckpt, auto_update=auto_update_checkpoints)
            for ckpt in checkpoints
        ]
        multi_predictor = super(MultiPredictor, cls).from_paths(
            [ckpt.json_path for ckpt in checkpoints],
            predictor_params,
            CalamariScenario,
            model_paths=[ckpt.ckpt_path + ".h5" for ckpt in checkpoints],
            predictor_args={"voter_params": voter_params},
        )

        return multi_predictor
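A hedged usage sketch for the classmethod above: the checkpoint file names are placeholders and the import location of MultiPredictor/PredictorParams is an assumption, not confirmed by the snippet.

# Hypothetical usage; checkpoint paths are placeholders.
from calamari_ocr.ocr.predict.predictor import MultiPredictor, PredictorParams  # import path assumed

multi_predictor = MultiPredictor.from_paths(
    checkpoints=["fold_0/best.ckpt", "fold_1/best.ckpt"],  # placeholder checkpoint files
    predictor_params=PredictorParams(silent=True, progress_bar=True),
)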
Example #2
def setup_test_init():
    """Function that should be called in the root __init__ of all tests

    The call ensures that the Devices and the logging are initialized correctly
    """

    DeviceConfig(DeviceConfigParams())

    logging.logger(__name__).debug("Set up device config for testing")
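A sketch of how this helper is typically wired up: the file location tests/__init__.py and the import path of setup_test_init are assumptions.

# tests/__init__.py (hypothetical location)
from tfaip.util.testing.setup import setup_test_init  # import path assumed

setup_test_init()  # runs once when the test package is imported, before any test module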
Example #3
 def __init__(self, params: PredictorParams, data: "DataBase", **kwargs):
     assert len(kwargs) == 0, f"Not all kwargs processed by subclasses: {kwargs}"
     post_init(params)
     self._params = params
     self._params.pipeline.mode = PipelineMode.PREDICTION
     if params.include_targets:
         self._params.pipeline.mode = PipelineMode.EVALUATION
     self.device_config = DeviceConfig(self._params.device)
     self._data = data
     self.benchmark_results = BenchmarkResults()
     self._keras_model: Optional[keras.Model] = None
Example #4
 def from_checkpoint(params: PredictorParams,
                     checkpoint: str,
                     auto_update_checkpoints=True):
     DeviceConfig(params.device)  # Device must be specified first
     ckpt = SavedCalamariModel(checkpoint,
                               auto_update=auto_update_checkpoints)
     scenario_params = CalamariScenario.params_from_dict(ckpt.dict)
     scenario = CalamariScenario(scenario_params)
     predictor = Predictor(params, scenario.create_data())
     predictor.set_model(
         keras.models.load_model(
             ckpt.ckpt_path + ".h5",
             custom_objects=CalamariScenario.model_cls().all_custom_objects(),
         )
     )
     return predictor
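For comparison, a minimal sketch of calling this constructor; the checkpoint path is a placeholder and the import path is assumed.

# Hypothetical usage; the checkpoint path is a placeholder.
from calamari_ocr.ocr.predict.predictor import Predictor, PredictorParams  # import path assumed

predictor = Predictor.from_checkpoint(
    params=PredictorParams(progress_bar=True),
    checkpoint="models/best.ckpt",  # placeholder path
)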
Example #5
 def __init__(
     self,
     params: LAVParams,
     scenario_params: ScenarioBaseParams,
     **kwargs,
 ):
     assert len(kwargs) == 0, f"Not all kwargs processed by subclasses: {kwargs}"
     assert params.model_path
     self._scenario_cls = scenario_params.cls()
     self._scenario_params = scenario_params
     self._params = params
     self.device_config = DeviceConfig(self._params.device)
     self._data: Optional["DataBase"] = None
     self._model: Optional["ModelBase"] = None
     self.benchmark_results = BenchmarkResults()
Example #6
 def __init__(
     self,
     params: LAVParams,
     scenario_params: ScenarioBaseParams,
     predictor_params: PredictorParams,
     **kwargs,
 ):
     assert len(kwargs) == 0, f"Not all kwargs processed by subclasses: {kwargs}"
     assert params.model_path
     self._scenario_cls = scenario_params.cls()
     self._scenario_params = scenario_params
     self._params = params
     self.device_config = DeviceConfig(self._params.device)
     self.benchmark_results = BenchmarkResults()
     self.predictor_params = predictor_params
     self.predictor_params.pipeline = self._params.pipeline
     predictor_params.silent = True
     predictor_params.progress_bar = True
     predictor_params.include_targets = True
Example #7
    def __init__(self, params: TTrainerParams, scenario: ScenarioBase, restore=False):
        super().__init__()
        self._params = params
        self._training_graph_only = params.export_training_graph_path is not None  # only required for JAVA-Training
        if self._training_graph_only:
            EXPORT_TENSORFLOW_1["metric_aggregation"] = "mean"
            tf.compat.v1.disable_eager_execution()
            self.params.export_final = False
            self.params.scenario.print_eval_limit = 0

        self.restore = restore
        if self._params.random_seed is not None:
            set_global_random_seed(self._params.random_seed)

        if restore and not self._params.output_dir:
            raise ValueError("To restore a training, a checkpoint dir must be provided")

        self.device_config = DeviceConfig(self._params.device)

        # The default value of export_best shall be True if an output dir is provided;
        # if the user manually sets it to True, an output dir must be provided.
        if params.export_best is None:
            params.export_best = params.output_dir is not None
        if params.export_best and not params.output_dir:
            raise ValueError("To use 'export_best' a 'output_dir' must be specified")
        if self._params.output_dir:
            scenario.params.id = os.path.basename(self._params.output_dir) + "_"
        else:
            scenario.params.id = ""
        scenario.params.id = (
            scenario.params.id + scenario.params.scenario_id + "_" + datetime.today().strftime("%Y-%m-%d")
        )

        self._scenario = scenario
        self.stop_training = False
        self._steps_per_epoch: Optional[int] = None  # Not initialized yet
        self._callbacks = []
        self._data = None
        self._model = None

        os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(self._params.tf_cpp_min_log_level)
Example #8
 def setUp(self) -> None:
     # Setup device
     DeviceConfig(DeviceConfigParams())
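A self-contained sketch of how this setUp could be embedded in a unittest test case; the import path of DeviceConfig/DeviceConfigParams and the test body are assumptions.

import unittest

from tfaip.device.device_config import DeviceConfig, DeviceConfigParams  # import path assumed


class DeviceSetupTest(unittest.TestCase):
    def setUp(self) -> None:
        # Initialize the default device configuration before each test
        DeviceConfig(DeviceConfigParams())

    def test_device_config_initializes(self):
        # Placeholder assertion; a real test would exercise a predictor or trainer here
        self.assertTrue(True)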