def generator_run_from_json(
    object_json: Dict[str, Any],
    decoder_registry: Dict[str, Type],
    class_decoder_registry: Dict[str, Callable[[Dict[str, Any]], Any]],
) -> GeneratorRun:
    """Load Ax GeneratorRun from JSON.

    Args:
        object_json: Serialized generator run. Must contain the keys
            ``"time_created"``, ``"generator_run_type"`` and ``"index"``;
            every other key is decoded and passed to ``GeneratorRun`` as a
            keyword argument of the same name.
        decoder_registry: Passed through to every ``object_from_json`` call.
        class_decoder_registry: Passed through to every ``object_from_json``
            call.

    Returns:
        The decoded ``GeneratorRun`` with ``_time_created``,
        ``_generator_run_type`` and ``_index`` set from the decoded values of
        the three keys above.

    Raises:
        KeyError: If any of the three required keys is missing.

    Keys are popped from a shallow copy, so the caller's top-level
    ``object_json`` is left unchanged.
    """

    def _decode(value: Any) -> Any:
        # Every value is decoded with the same pair of registries.
        return object_from_json(
            value,
            decoder_registry=decoder_registry,
            class_decoder_registry=class_decoder_registry,
        )

    # Popping directly from ``object_json`` would strip these keys from the
    # caller's dict, so a second decode of the same JSON would raise KeyError.
    remaining = dict(object_json)
    time_created_json = remaining.pop("time_created")
    type_json = remaining.pop("generator_run_type")
    index_json = remaining.pop("index")

    generator_run = GeneratorRun(**{k: _decode(v) for k, v in remaining.items()})
    # The three popped fields are assigned to private attributes after
    # construction rather than passed to the constructor.
    generator_run._time_created = _decode(time_created_json)
    generator_run._generator_run_type = _decode(type_json)
    generator_run._index = _decode(index_json)
    return generator_run
def generator_run_from_sqa(
    self, generator_run_sqa: SQAGeneratorRun
) -> GeneratorRun:
    """Decode a SQLAlchemy ``SQAGeneratorRun`` into an Ax ``GeneratorRun``.

    The following are rebuilt:

    - the arms and their weights;
    - the optimization config and the search space;
    - the best-arm and model predictions, when stored.

    The creation time, generator-run type and index are then copied onto the
    new object.

    Args:
        generator_run_sqa: Row to decode.

    Returns:
        The reconstructed ``GeneratorRun``.

    Raises:
        SQADecodeError: If the stored metrics decode to any tracking metrics.
    """
    arm_rows = generator_run_sqa.arms
    arms = [self.arm_from_sqa(arm_sqa=row) for row in arm_rows]
    weights = [row.weight for row in arm_rows]

    opt_config, tracking_metrics = self.opt_config_and_tracking_metrics_from_sqa(
        metrics_sqa=generator_run_sqa.metrics
    )
    if tracking_metrics:
        raise SQADecodeError(  # pragma: no cover
            "GeneratorRun should not have tracking metrics."
        )

    search_space = self.search_space_from_sqa(
        parameters_sqa=generator_run_sqa.parameters,
        parameter_constraints_sqa=generator_run_sqa.parameter_constraints,
    )

    # Best-arm predictions are only rebuilt when both halves were stored.
    stored_params = generator_run_sqa.best_arm_parameters
    stored_predictions = generator_run_sqa.best_arm_predictions
    if stored_params is None or stored_predictions is None:
        best_arm_predictions = None
    else:
        best_arm_predictions = (
            Arm(name=generator_run_sqa.best_arm_name, parameters=stored_params),
            tuple(stored_predictions),
        )

    stored_model_predictions = generator_run_sqa.model_predictions
    model_predictions = (
        None if stored_model_predictions is None else tuple(stored_model_predictions)
    )

    generator_run = GeneratorRun(
        arms=arms,
        weights=weights,
        optimization_config=opt_config,
        search_space=search_space,
        fit_time=generator_run_sqa.fit_time,
        gen_time=generator_run_sqa.gen_time,
        best_arm_predictions=best_arm_predictions,
        model_predictions=model_predictions,
    )
    # Bookkeeping fields are copied onto the instance after construction.
    generator_run._time_created = generator_run_sqa.time_created
    generator_run._generator_run_type = self.get_enum_name(
        value=generator_run_sqa.generator_run_type,
        enum=self.config.generator_run_type_enum,
    )
    generator_run._index = generator_run_sqa.index
    return generator_run
def generator_run_from_json(object_json: Dict[str, Any]) -> GeneratorRun:
    """Load Ax GeneratorRun from JSON.

    Args:
        object_json: Serialized generator run. Must contain the keys
            ``"time_created"``, ``"generator_run_type"`` and ``"index"``;
            every other key is decoded with ``object_from_json`` and passed
            to ``GeneratorRun`` as a keyword argument of the same name.

    Returns:
        The decoded ``GeneratorRun`` with ``_time_created``,
        ``_generator_run_type`` and ``_index`` set from the decoded values of
        the three keys above.

    Raises:
        KeyError: If any of the three required keys is missing.

    Keys are popped from a shallow copy, so the caller's top-level
    ``object_json`` is left unchanged.
    """
    # Popping directly from ``object_json`` would strip these keys from the
    # caller's dict, so a second decode of the same JSON would raise KeyError.
    remaining = dict(object_json)
    time_created_json = remaining.pop("time_created")
    type_json = remaining.pop("generator_run_type")
    index_json = remaining.pop("index")

    generator_run = GeneratorRun(
        **{k: object_from_json(v) for k, v in remaining.items()}
    )
    # The three popped fields are assigned to private attributes after
    # construction rather than passed to the constructor.
    generator_run._time_created = object_from_json(time_created_json)
    generator_run._generator_run_type = object_from_json(type_json)
    generator_run._index = object_from_json(index_json)
    return generator_run
def generator_run_from_sqa(
    self, generator_run_sqa: SQAGeneratorRun, reduced_state: bool = False
) -> GeneratorRun:
    """Convert SQLAlchemy GeneratorRun to Ax GeneratorRun.

    Args:
        generator_run_sqa: `SQAGeneratorRun` to decode.
        reduced_state: When True, the following are not read from the row
            and are left as ``None``:

            - the optimization config and search space;
            - ``model_kwargs`` and ``bridge_kwargs``;
            - ``gen_metadata`` and ``model_state_after_gen``.

            These are decoded either way: arms, weights, predictions,
            ``model_key``, ``generation_step_index`` and
            ``candidate_metadata_by_arm_signature``.

    Returns:
        The reconstructed ``GeneratorRun``, with ``db_id`` set to the row's id.

    Raises:
        SQADecodeError: If, on a full load, the stored metrics decode to any
            tracking metrics.
    """
    arm_rows = generator_run_sqa.arms
    arms = [self.arm_from_sqa(arm_sqa=row) for row in arm_rows]
    weights = [row.weight for row in arm_rows]

    opt_config = None
    search_space = None
    if not reduced_state:
        opt_config, tracking_metrics = self.opt_config_and_tracking_metrics_from_sqa(
            metrics_sqa=generator_run_sqa.metrics
        )
        if tracking_metrics:
            raise SQADecodeError(  # pragma: no cover
                "GeneratorRun should not have tracking metrics."
            )
        search_space = self.search_space_from_sqa(
            parameters_sqa=generator_run_sqa.parameters,
            parameter_constraints_sqa=generator_run_sqa.parameter_constraints,
        )

    # Best-arm predictions are only rebuilt when both halves were stored.
    stored_params = generator_run_sqa.best_arm_parameters
    stored_predictions = generator_run_sqa.best_arm_predictions
    best_arm_predictions = None
    if stored_params is not None and stored_predictions is not None:
        best_arm_predictions = (
            Arm(name=generator_run_sqa.best_arm_name, parameters=stored_params),
            tuple(stored_predictions),
        )

    stored_model_predictions = generator_run_sqa.model_predictions
    model_predictions = (
        None if stored_model_predictions is None else tuple(stored_model_predictions)
    )

    # Model-related JSON blobs: on a reduced load their row attributes are
    # never read at all.
    if reduced_state:
        model_kwargs = bridge_kwargs = gen_metadata = model_state_after_gen = None
    else:
        model_kwargs = object_from_json(generator_run_sqa.model_kwargs)
        bridge_kwargs = object_from_json(generator_run_sqa.bridge_kwargs)
        gen_metadata = object_from_json(generator_run_sqa.gen_metadata)
        model_state_after_gen = object_from_json(
            generator_run_sqa.model_state_after_gen
        )
    candidate_metadata = object_from_json(
        generator_run_sqa.candidate_metadata_by_arm_signature
    )

    generator_run = GeneratorRun(
        arms=arms,
        weights=weights,
        optimization_config=opt_config,
        search_space=search_space,
        fit_time=generator_run_sqa.fit_time,
        gen_time=generator_run_sqa.gen_time,
        best_arm_predictions=best_arm_predictions,  # pyre-ignore[6]
        model_predictions=model_predictions,
        model_key=generator_run_sqa.model_key,
        model_kwargs=model_kwargs,
        bridge_kwargs=bridge_kwargs,
        gen_metadata=gen_metadata,
        model_state_after_gen=model_state_after_gen,
        generation_step_index=generator_run_sqa.generation_step_index,
        candidate_metadata_by_arm_signature=candidate_metadata,
    )
    # Bookkeeping fields are copied onto the instance after construction.
    generator_run._time_created = generator_run_sqa.time_created
    generator_run._generator_run_type = self.get_enum_name(
        value=generator_run_sqa.generator_run_type,
        enum=self.config.generator_run_type_enum,
    )
    generator_run._index = generator_run_sqa.index
    generator_run.db_id = generator_run_sqa.id
    return generator_run
def generator_run_from_sqa(
    self, generator_run_sqa: SQAGeneratorRun
) -> GeneratorRun:
    """Convert SQLAlchemy GeneratorRun to Ax GeneratorRun.

    Arms, weights, optimization config, search space, predictions and the
    JSON-encoded model fields are all decoded from ``generator_run_sqa``.
    The model fields are ``model_kwargs``, ``bridge_kwargs``,
    ``gen_metadata``, ``model_state_after_gen`` and
    ``candidate_metadata_by_arm_signature``.

    Args:
        generator_run_sqa: Row to decode.

    Returns:
        The reconstructed ``GeneratorRun``.

    Raises:
        SQADecodeError: If the stored metrics decode to any tracking metrics.
    """
    arm_rows = generator_run_sqa.arms
    arms = [self.arm_from_sqa(arm_sqa=row) for row in arm_rows]
    weights = [row.weight for row in arm_rows]

    opt_config, tracking_metrics = self.opt_config_and_tracking_metrics_from_sqa(
        metrics_sqa=generator_run_sqa.metrics
    )
    if tracking_metrics:
        raise SQADecodeError(  # pragma: no cover
            "GeneratorRun should not have tracking metrics."
        )
    search_space = self.search_space_from_sqa(
        parameters_sqa=generator_run_sqa.parameters,
        parameter_constraints_sqa=generator_run_sqa.parameter_constraints,
    )

    # Best-arm predictions are only rebuilt when both halves were stored.
    if (
        generator_run_sqa.best_arm_parameters is None
        or generator_run_sqa.best_arm_predictions is None
    ):
        best_arm_predictions = None
    else:
        best_arm = Arm(
            name=generator_run_sqa.best_arm_name,
            # pyre-fixme[6]: Expected `Dict[str, Optional[Union[bool, float,
            #  int, str]]]` for 2nd param but got `Optional[Dict[str,
            #  Optional[Union[bool, float, int, str]]]]`.
            parameters=generator_run_sqa.best_arm_parameters,
        )
        best_arm_predictions = (
            best_arm,
            # pyre-fixme[6]: Expected `Iterable[_T_co]` for 1st param but got
            #  `Optional[Tuple[Dict[str, float], Optional[Dict[str, Dict[str,
            #  float]]]]]`.
            tuple(generator_run_sqa.best_arm_predictions),
        )

    if generator_run_sqa.model_predictions is None:
        model_predictions = None
    else:
        # pyre-fixme[6]: Expected `Iterable[_T_co]` for 1st param but got
        #  `Optional[Tuple[Dict[str, List[float]], Dict[str, Dict[str,
        #  List[float]]]]]`.
        model_predictions = tuple(generator_run_sqa.model_predictions)

    # JSON-encoded fields, decoded in the same order they are passed below.
    model_kwargs = object_from_json(generator_run_sqa.model_kwargs)
    bridge_kwargs = object_from_json(generator_run_sqa.bridge_kwargs)
    gen_metadata = object_from_json(generator_run_sqa.gen_metadata)
    model_state_after_gen = object_from_json(generator_run_sqa.model_state_after_gen)
    candidate_metadata = object_from_json(
        generator_run_sqa.candidate_metadata_by_arm_signature
    )

    generator_run = GeneratorRun(
        arms=arms,
        weights=weights,
        optimization_config=opt_config,
        search_space=search_space,
        fit_time=generator_run_sqa.fit_time,
        gen_time=generator_run_sqa.gen_time,
        # pyre-fixme[6]: Expected `Optional[Tuple[Arm, Optional[Tuple[Dict[str,
        #  float], Optional[Dict[str, Dict[str, float]]]]]]]` for 7th param but got
        #  `Optional[Tuple[Arm, Tuple[Any, ...]]]`.
        best_arm_predictions=best_arm_predictions,
        model_predictions=model_predictions,
        model_key=generator_run_sqa.model_key,
        model_kwargs=model_kwargs,
        bridge_kwargs=bridge_kwargs,
        gen_metadata=gen_metadata,
        model_state_after_gen=model_state_after_gen,
        generation_step_index=generator_run_sqa.generation_step_index,
        candidate_metadata_by_arm_signature=candidate_metadata,
    )
    # Bookkeeping fields are copied onto the instance after construction.
    generator_run._time_created = generator_run_sqa.time_created
    generator_run._generator_run_type = self.get_enum_name(
        value=generator_run_sqa.generator_run_type,
        enum=self.config.generator_run_type_enum,
    )
    generator_run._index = generator_run_sqa.index
    return generator_run
def generator_run_from_sqa(
    self,
    generator_run_sqa: SQAGeneratorRun,
    reduced_state: bool,
    immutable_search_space_and_opt_config: bool,
) -> GeneratorRun:
    """Convert SQLAlchemy GeneratorRun to Ax GeneratorRun.

    Args:
        generator_run_sqa: `SQAGeneratorRun` to decode.
        reduced_state: When True, the following are not read from the row
            and are left as ``None``:

            - the optimization config and search space;
            - ``model_kwargs`` and ``bridge_kwargs``;
            - ``gen_metadata`` and ``model_state_after_gen``.
        immutable_search_space_and_opt_config: When True, only the search
            space and optimization config are skipped (left as ``None``).
            Unlike ``reduced_state``, the model-related JSON fields are still
            decoded.

    Returns:
        The reconstructed ``GeneratorRun``, with ``db_id`` set to the row's id.

    Raises:
        SQADecodeError: If the search space and optimization config are loaded
            and the stored metrics decode to any tracking metrics.
    """

    def _decode(serialized):
        # Every JSON blob on the row is decoded with the config's registries.
        return object_from_json(
            serialized,
            decoder_registry=self.config.json_decoder_registry,
            class_decoder_registry=self.config.json_class_decoder_registry,
        )

    arm_rows = generator_run_sqa.arms
    arms = [self.arm_from_sqa(arm_sqa=row) for row in arm_rows]
    weights = [row.weight for row in arm_rows]

    opt_config = None
    search_space = None
    skip_space_and_config = reduced_state or immutable_search_space_and_opt_config
    if not skip_space_and_config:
        opt_config, tracking_metrics = self.opt_config_and_tracking_metrics_from_sqa(
            metrics_sqa=generator_run_sqa.metrics
        )
        if tracking_metrics:
            raise SQADecodeError(  # pragma: no cover
                "GeneratorRun should not have tracking metrics."
            )
        search_space = self.search_space_from_sqa(
            parameters_sqa=generator_run_sqa.parameters,
            parameter_constraints_sqa=generator_run_sqa.parameter_constraints,
        )

    # Best-arm predictions are only rebuilt when both halves were stored.
    stored_params = generator_run_sqa.best_arm_parameters
    stored_predictions = generator_run_sqa.best_arm_predictions
    best_arm_predictions = None
    if stored_params is not None and stored_predictions is not None:
        best_arm_predictions = (
            Arm(name=generator_run_sqa.best_arm_name, parameters=stored_params),
            tuple(stored_predictions),
        )

    stored_model_predictions = generator_run_sqa.model_predictions
    model_predictions = (
        None if stored_model_predictions is None else tuple(stored_model_predictions)
    )

    # Model-related JSON blobs: on a reduced load their row attributes are
    # never read at all.
    if reduced_state:
        model_kwargs = bridge_kwargs = gen_metadata = model_state_after_gen = None
    else:
        model_kwargs = _decode(generator_run_sqa.model_kwargs)
        bridge_kwargs = _decode(generator_run_sqa.bridge_kwargs)
        gen_metadata = _decode(generator_run_sqa.gen_metadata)
        model_state_after_gen = _decode(generator_run_sqa.model_state_after_gen)
    candidate_metadata = _decode(generator_run_sqa.candidate_metadata_by_arm_signature)

    generator_run = GeneratorRun(
        arms=arms,
        weights=weights,
        optimization_config=opt_config,
        search_space=search_space,
        fit_time=generator_run_sqa.fit_time,
        gen_time=generator_run_sqa.gen_time,
        best_arm_predictions=best_arm_predictions,  # pyre-ignore[6]
        # pyre-fixme[6]: Expected `Optional[Tuple[typing.Dict[str, List[float]],
        #  typing.Dict[str, typing.Dict[str, List[float]]]]]` for 8th param but got
        #  `Optional[typing.Tuple[Union[typing.Dict[str, List[float]],
        #  typing.Dict[str, typing.Dict[str, List[float]]]], ...]]`.
        model_predictions=model_predictions,
        model_key=generator_run_sqa.model_key,
        model_kwargs=model_kwargs,
        bridge_kwargs=bridge_kwargs,
        gen_metadata=gen_metadata,
        model_state_after_gen=model_state_after_gen,
        generation_step_index=generator_run_sqa.generation_step_index,
        candidate_metadata_by_arm_signature=candidate_metadata,
    )
    # Bookkeeping fields are copied onto the instance after construction.
    generator_run._time_created = generator_run_sqa.time_created
    generator_run._generator_run_type = self.get_enum_name(
        value=generator_run_sqa.generator_run_type,
        enum=self.config.generator_run_type_enum,
    )
    generator_run._index = generator_run_sqa.index
    generator_run.db_id = generator_run_sqa.id
    return generator_run