Example #1
0
    def generator_run_from_sqa(self,
                               generator_run_sqa: SQAGeneratorRun,
                               reduced_state: bool = False) -> GeneratorRun:
        """Decode a stored `SQAGeneratorRun` into an Ax `GeneratorRun`.

        Args:
            generator_run_sqa: `SQAGeneratorRun` to decode.
            reduced_state: If True, the optimization config, search space,
                model kwargs, bridge kwargs, generation metadata and
                post-generation model state are not decoded; each of them is
                passed to the `GeneratorRun` constructor as None.

        Returns:
            The decoded `GeneratorRun`. Its `_time_created`,
            `_generator_run_type`, `_index` and `db_id` are filled in from
            the SQA object after construction.

        Raises:
            SQADecodeError: If decoding the generator run's metrics yields any
                tracking metrics (only checked when `reduced_state` is False).
        """
        # One pass over the stored arms keeps every arm paired with its weight.
        arm_weight_pairs = [
            (self.arm_from_sqa(arm_sqa=sqa_arm), sqa_arm.weight)
            for sqa_arm in generator_run_sqa.arms
        ]
        arms = [arm for arm, _ in arm_weight_pairs]
        weights = [weight for _, weight in arm_weight_pairs]

        if reduced_state:
            opt_config = None
            search_space = None
        else:
            decoded_metrics = self.opt_config_and_tracking_metrics_from_sqa(
                metrics_sqa=generator_run_sqa.metrics)
            opt_config, tracking_metrics = decoded_metrics
            if len(tracking_metrics) > 0:
                raise SQADecodeError(  # pragma: no cover
                    "GeneratorRun should not have tracking metrics.")

            parameters_sqa = generator_run_sqa.parameters
            constraints_sqa = generator_run_sqa.parameter_constraints
            search_space = self.search_space_from_sqa(
                parameters_sqa=parameters_sqa,
                parameter_constraints_sqa=constraints_sqa,
            )

        # A best arm is only rebuilt when both its parameters and its
        # predictions were stored.
        best_arm_missing = (
            generator_run_sqa.best_arm_parameters is None
            or generator_run_sqa.best_arm_predictions is None
        )
        if best_arm_missing:
            best_arm_predictions = None
        else:
            best_arm_predictions = (
                Arm(
                    name=generator_run_sqa.best_arm_name,
                    parameters=not_none(generator_run_sqa.best_arm_parameters),
                ),
                tuple(not_none(generator_run_sqa.best_arm_predictions)),
            )

        if generator_run_sqa.model_predictions is None:
            model_predictions = None
        else:
            model_predictions = tuple(
                not_none(generator_run_sqa.model_predictions))

        # The JSON-encoded model fields are only read from the SQA object when
        # the full state is requested.
        if reduced_state:
            model_kwargs = None
            bridge_kwargs = None
            gen_metadata = None
            model_state_after_gen = None
        else:
            model_kwargs = object_from_json(generator_run_sqa.model_kwargs)
            bridge_kwargs = object_from_json(generator_run_sqa.bridge_kwargs)
            gen_metadata = object_from_json(generator_run_sqa.gen_metadata)
            model_state_after_gen = object_from_json(
                generator_run_sqa.model_state_after_gen)

        # Candidate metadata is decoded regardless of `reduced_state`.
        candidate_metadata = object_from_json(
            generator_run_sqa.candidate_metadata_by_arm_signature)

        generator_run = GeneratorRun(
            arms=arms,
            weights=weights,
            optimization_config=opt_config,
            search_space=search_space,
            fit_time=generator_run_sqa.fit_time,
            gen_time=generator_run_sqa.gen_time,
            best_arm_predictions=best_arm_predictions,  # pyre-ignore[6]
            model_predictions=model_predictions,
            model_key=generator_run_sqa.model_key,
            model_kwargs=model_kwargs,
            bridge_kwargs=bridge_kwargs,
            gen_metadata=gen_metadata,
            model_state_after_gen=model_state_after_gen,
            generation_step_index=generator_run_sqa.generation_step_index,
            candidate_metadata_by_arm_signature=candidate_metadata,
        )

        # These fields are not constructor arguments here, so they are copied
        # onto the new object afterwards.
        generator_run._time_created = generator_run_sqa.time_created
        generator_run._generator_run_type = self.get_enum_name(
            value=generator_run_sqa.generator_run_type,
            enum=self.config.generator_run_type_enum,
        )
        generator_run._index = generator_run_sqa.index
        generator_run.db_id = generator_run_sqa.id
        return generator_run
Example #2
0
    def generator_run_from_sqa(
        self,
        generator_run_sqa: SQAGeneratorRun,
        reduced_state: bool,
        immutable_search_space_and_opt_config: bool,
    ) -> GeneratorRun:
        """Decode a stored `SQAGeneratorRun` into an Ax `GeneratorRun`.

        Args:
            generator_run_sqa: `SQAGeneratorRun` to decode.
            reduced_state: If True, the optimization config, search space,
                model kwargs, bridge kwargs, generation metadata and
                post-generation model state are not decoded; each of them is
                passed to the `GeneratorRun` constructor as None.
            immutable_search_space_and_opt_config: If True, the search space
                and optimization config are not decoded (passed as None).
                Unlike `reduced_state`, this flag alone leaves the model
                kwargs, bridge kwargs, generation metadata and model state
                decoded.

        Returns:
            The decoded `GeneratorRun`. Its `_time_created`,
            `_generator_run_type`, `_index` and `db_id` are filled in from
            the SQA object after construction.

        Raises:
            SQADecodeError: If decoding the generator run's metrics yields any
                tracking metrics (only checked when neither flag is set).
        """

        def decode_json(json_blob):
            # All JSON columns go through the decoder registries configured
            # on this decoder.
            return object_from_json(
                json_blob,
                decoder_registry=self.config.json_decoder_registry,
                class_decoder_registry=self.config.json_class_decoder_registry,
            )

        # One pass over the stored arms keeps every arm paired with its weight.
        arm_weight_pairs = [
            (self.arm_from_sqa(arm_sqa=sqa_arm), sqa_arm.weight)
            for sqa_arm in generator_run_sqa.arms
        ]
        arms = [arm for arm, _ in arm_weight_pairs]
        weights = [weight for _, weight in arm_weight_pairs]

        # Either flag is enough to leave out search space and opt config.
        if reduced_state or immutable_search_space_and_opt_config:
            opt_config = None
            search_space = None
        else:
            decoded_metrics = self.opt_config_and_tracking_metrics_from_sqa(
                metrics_sqa=generator_run_sqa.metrics)
            opt_config, tracking_metrics = decoded_metrics
            if len(tracking_metrics) > 0:
                raise SQADecodeError(  # pragma: no cover
                    "GeneratorRun should not have tracking metrics.")

            parameters_sqa = generator_run_sqa.parameters
            constraints_sqa = generator_run_sqa.parameter_constraints
            search_space = self.search_space_from_sqa(
                parameters_sqa=parameters_sqa,
                parameter_constraints_sqa=constraints_sqa,
            )

        # A best arm is only rebuilt when both its parameters and its
        # predictions were stored.
        best_arm_missing = (
            generator_run_sqa.best_arm_parameters is None
            or generator_run_sqa.best_arm_predictions is None
        )
        if best_arm_missing:
            best_arm_predictions = None
        else:
            best_arm_predictions = (
                Arm(
                    name=generator_run_sqa.best_arm_name,
                    parameters=not_none(generator_run_sqa.best_arm_parameters),
                ),
                tuple(not_none(generator_run_sqa.best_arm_predictions)),
            )

        if generator_run_sqa.model_predictions is None:
            model_predictions = None
        else:
            model_predictions = tuple(
                not_none(generator_run_sqa.model_predictions))

        # The JSON-encoded model fields are only read from the SQA object when
        # the full state is requested.
        if reduced_state:
            model_kwargs = None
            bridge_kwargs = None
            gen_metadata = None
            model_state_after_gen = None
        else:
            model_kwargs = decode_json(generator_run_sqa.model_kwargs)
            bridge_kwargs = decode_json(generator_run_sqa.bridge_kwargs)
            gen_metadata = decode_json(generator_run_sqa.gen_metadata)
            model_state_after_gen = decode_json(
                generator_run_sqa.model_state_after_gen)

        # Candidate metadata is decoded regardless of either flag.
        candidate_metadata = decode_json(
            generator_run_sqa.candidate_metadata_by_arm_signature)

        generator_run = GeneratorRun(
            arms=arms,
            weights=weights,
            optimization_config=opt_config,
            search_space=search_space,
            fit_time=generator_run_sqa.fit_time,
            gen_time=generator_run_sqa.gen_time,
            best_arm_predictions=best_arm_predictions,  # pyre-ignore[6]
            # pyre-fixme[6]: Expected `Optional[Tuple[typing.Dict[str, List[float]],
            #  typing.Dict[str, typing.Dict[str, List[float]]]]]` for 8th param but got
            #  `Optional[typing.Tuple[Union[typing.Dict[str, List[float]],
            #  typing.Dict[str, typing.Dict[str, List[float]]]], ...]]`.
            model_predictions=model_predictions,
            model_key=generator_run_sqa.model_key,
            model_kwargs=model_kwargs,
            bridge_kwargs=bridge_kwargs,
            gen_metadata=gen_metadata,
            model_state_after_gen=model_state_after_gen,
            generation_step_index=generator_run_sqa.generation_step_index,
            candidate_metadata_by_arm_signature=candidate_metadata,
        )

        # These fields are not constructor arguments here, so they are copied
        # onto the new object afterwards.
        generator_run._time_created = generator_run_sqa.time_created
        generator_run._generator_run_type = self.get_enum_name(
            value=generator_run_sqa.generator_run_type,
            enum=self.config.generator_run_type_enum,
        )
        generator_run._index = generator_run_sqa.index
        generator_run.db_id = generator_run_sqa.id
        return generator_run