def generator_run_from_sqa(
    self, generator_run_sqa: SQAGeneratorRun, reduced_state: bool = False
) -> GeneratorRun:
    """Decode an `SQAGeneratorRun` row into an Ax `GeneratorRun`.

    Args:
        generator_run_sqa: The SQLAlchemy generator run to decode.
        reduced_state: If True, the optimization config, the search space and
            the JSON-encoded `model_kwargs`, `bridge_kwargs`, `gen_metadata`
            and `model_state_after_gen` are not decoded; they are passed to
            the `GeneratorRun` as None instead.

    Returns:
        The decoded `GeneratorRun`. Its creation time, generator run type,
        index and database id are copied from `generator_run_sqa` after
        construction.

    Raises:
        SQADecodeError: If the stored metrics decode to any tracking metrics.
    """
    # Decode every arm together with its weight in one pass so the two
    # resulting lists stay aligned.
    arms_and_weights = [
        (self.arm_from_sqa(arm_sqa=arm_row), arm_row.weight)
        for arm_row in generator_run_sqa.arms
    ]
    arms = [arm for arm, _ in arms_and_weights]
    weights = [weight for _, weight in arms_and_weights]

    if reduced_state:
        opt_config = None
        search_space = None
    else:
        opt_config, tracking_metrics = (
            self.opt_config_and_tracking_metrics_from_sqa(
                metrics_sqa=generator_run_sqa.metrics
            )
        )
        if len(tracking_metrics) > 0:
            raise SQADecodeError(  # pragma: no cover
                "GeneratorRun should not have tracking metrics."
            )
        search_space = self.search_space_from_sqa(
            parameters_sqa=generator_run_sqa.parameters,
            parameter_constraints_sqa=generator_run_sqa.parameter_constraints,
        )

    # The best arm is only rebuilt when both its parameters and its
    # predictions were stored.
    best_arm_predictions = None
    has_best_arm = (
        generator_run_sqa.best_arm_parameters is not None
        and generator_run_sqa.best_arm_predictions is not None
    )
    if has_best_arm:
        best_arm = Arm(
            name=generator_run_sqa.best_arm_name,
            parameters=not_none(generator_run_sqa.best_arm_parameters),
        )
        best_arm_predictions = (
            best_arm,
            tuple(not_none(generator_run_sqa.best_arm_predictions)),
        )

    stored_model_predictions = generator_run_sqa.model_predictions
    model_predictions = (
        None
        if stored_model_predictions is None
        else tuple(not_none(stored_model_predictions))
    )

    # The JSON columns below are only read when they are actually decoded.
    # NOTE(review): presumably this lets reduced-state loads leave these
    # columns unloaded — confirm before changing the access pattern.
    if reduced_state:
        model_kwargs = None
        bridge_kwargs = None
        gen_metadata = None
        model_state_after_gen = None
    else:
        model_kwargs = object_from_json(generator_run_sqa.model_kwargs)
        bridge_kwargs = object_from_json(generator_run_sqa.bridge_kwargs)
        gen_metadata = object_from_json(generator_run_sqa.gen_metadata)
        model_state_after_gen = object_from_json(
            generator_run_sqa.model_state_after_gen
        )

    generator_run = GeneratorRun(
        arms=arms,
        weights=weights,
        optimization_config=opt_config,
        search_space=search_space,
        fit_time=generator_run_sqa.fit_time,
        gen_time=generator_run_sqa.gen_time,
        best_arm_predictions=best_arm_predictions,  # pyre-ignore[6]
        model_predictions=model_predictions,
        model_key=generator_run_sqa.model_key,
        model_kwargs=model_kwargs,
        bridge_kwargs=bridge_kwargs,
        gen_metadata=gen_metadata,
        model_state_after_gen=model_state_after_gen,
        generation_step_index=generator_run_sqa.generation_step_index,
        candidate_metadata_by_arm_signature=object_from_json(
            generator_run_sqa.candidate_metadata_by_arm_signature
        ),
    )

    # These attributes are restored after construction rather than passed
    # to the constructor.
    generator_run._time_created = generator_run_sqa.time_created
    generator_run._generator_run_type = self.get_enum_name(
        value=generator_run_sqa.generator_run_type,
        enum=self.config.generator_run_type_enum,
    )
    generator_run._index = generator_run_sqa.index
    generator_run.db_id = generator_run_sqa.id
    return generator_run
def generator_run_from_sqa(
    self,
    generator_run_sqa: SQAGeneratorRun,
    reduced_state: bool,
    immutable_search_space_and_opt_config: bool,
) -> GeneratorRun:
    """Decode an `SQAGeneratorRun` row into an Ax `GeneratorRun`.

    Args:
        generator_run_sqa: The SQLAlchemy generator run to decode.
        reduced_state: If True, the optimization config, the search space and
            the JSON-encoded `model_kwargs`, `bridge_kwargs`, `gen_metadata`
            and `model_state_after_gen` are not decoded; they are passed to
            the `GeneratorRun` as None instead.
        immutable_search_space_and_opt_config: If True, only the optimization
            config and the search space are skipped (passed as None). Unlike
            `reduced_state`, the model kwargs, bridge kwargs, generation
            metadata and model state are still decoded.

    Returns:
        The decoded `GeneratorRun`. Its creation time, generator run type,
        index and database id are copied from `generator_run_sqa` after
        construction.

    Raises:
        SQADecodeError: If the stored metrics decode to any tracking metrics.
    """
    # Decode every arm together with its weight in one pass so the two
    # resulting lists stay aligned.
    arms_and_weights = [
        (self.arm_from_sqa(arm_sqa=arm_row), arm_row.weight)
        for arm_row in generator_run_sqa.arms
    ]
    arms = [arm for arm, _ in arms_and_weights]
    weights = [weight for _, weight in arms_and_weights]

    opt_config = None
    search_space = None
    if not (reduced_state or immutable_search_space_and_opt_config):
        opt_config, tracking_metrics = (
            self.opt_config_and_tracking_metrics_from_sqa(
                metrics_sqa=generator_run_sqa.metrics
            )
        )
        if len(tracking_metrics) > 0:
            raise SQADecodeError(  # pragma: no cover
                "GeneratorRun should not have tracking metrics."
            )
        search_space = self.search_space_from_sqa(
            parameters_sqa=generator_run_sqa.parameters,
            parameter_constraints_sqa=generator_run_sqa.parameter_constraints,
        )

    # The best arm is only rebuilt when both its parameters and its
    # predictions were stored.
    best_arm_predictions = None
    has_best_arm = (
        generator_run_sqa.best_arm_parameters is not None
        and generator_run_sqa.best_arm_predictions is not None
    )
    if has_best_arm:
        best_arm = Arm(
            name=generator_run_sqa.best_arm_name,
            parameters=not_none(generator_run_sqa.best_arm_parameters),
        )
        best_arm_predictions = (
            best_arm,
            tuple(not_none(generator_run_sqa.best_arm_predictions)),
        )

    stored_model_predictions = generator_run_sqa.model_predictions
    model_predictions = (
        None
        if stored_model_predictions is None
        else tuple(not_none(stored_model_predictions))
    )

    # All JSON payloads are decoded with the registries from the decoder's
    # config.
    decoder_registry = self.config.json_decoder_registry
    class_decoder_registry = self.config.json_class_decoder_registry

    # The JSON columns below are only read when they are actually decoded.
    # NOTE(review): presumably this lets reduced-state loads leave these
    # columns unloaded — confirm before changing the access pattern.
    if reduced_state:
        model_kwargs = None
        bridge_kwargs = None
        gen_metadata = None
        model_state_after_gen = None
    else:
        model_kwargs = object_from_json(
            generator_run_sqa.model_kwargs,
            decoder_registry=decoder_registry,
            class_decoder_registry=class_decoder_registry,
        )
        bridge_kwargs = object_from_json(
            generator_run_sqa.bridge_kwargs,
            decoder_registry=decoder_registry,
            class_decoder_registry=class_decoder_registry,
        )
        gen_metadata = object_from_json(
            generator_run_sqa.gen_metadata,
            decoder_registry=decoder_registry,
            class_decoder_registry=class_decoder_registry,
        )
        model_state_after_gen = object_from_json(
            generator_run_sqa.model_state_after_gen,
            decoder_registry=decoder_registry,
            class_decoder_registry=class_decoder_registry,
        )

    generator_run = GeneratorRun(
        arms=arms,
        weights=weights,
        optimization_config=opt_config,
        search_space=search_space,
        fit_time=generator_run_sqa.fit_time,
        gen_time=generator_run_sqa.gen_time,
        best_arm_predictions=best_arm_predictions,  # pyre-ignore[6]
        # pyre-fixme[6]: Expected `Optional[Tuple[typing.Dict[str, List[float]],
        #  typing.Dict[str, typing.Dict[str, List[float]]]]]` for 8th param but got
        #  `Optional[typing.Tuple[Union[typing.Dict[str, List[float]],
        #  typing.Dict[str, typing.Dict[str, List[float]]]], ...]]`.
        model_predictions=model_predictions,
        model_key=generator_run_sqa.model_key,
        model_kwargs=model_kwargs,
        bridge_kwargs=bridge_kwargs,
        gen_metadata=gen_metadata,
        model_state_after_gen=model_state_after_gen,
        generation_step_index=generator_run_sqa.generation_step_index,
        candidate_metadata_by_arm_signature=object_from_json(
            generator_run_sqa.candidate_metadata_by_arm_signature,
            decoder_registry=decoder_registry,
            class_decoder_registry=class_decoder_registry,
        ),
    )

    # These attributes are restored after construction rather than passed
    # to the constructor.
    generator_run._time_created = generator_run_sqa.time_created
    generator_run._generator_run_type = self.get_enum_name(
        value=generator_run_sqa.generator_run_type,
        enum=self.config.generator_run_type_enum,
    )
    generator_run._index = generator_run_sqa.index
    generator_run.db_id = generator_run_sqa.id
    return generator_run