Exemplo n.º 1
0
Arquivo: decoder.py Projeto: raff7/Ax
 def trial_from_sqa(self, trial_sqa: SQATrial,
                    experiment: Experiment) -> BaseTrial:
     """Convert SQLAlchemy Trial to Ax Trial.

     Args:
         trial_sqa: `SQATrial` to decode.
         experiment: Experiment the decoded trial is constructed against.

     Returns:
         A `BatchTrial` if `trial_sqa.is_batch` is set, otherwise a `Trial`.

     Raises:
         SQADecodeError: If a non-batch `trial_sqa` has more than one
             generator run.
     """
     if trial_sqa.is_batch:
         trial = BatchTrial(experiment=experiment,
                            optimize_for_power=trial_sqa.optimize_for_power)
         generator_run_structs = [
             GeneratorRunStruct(
                 generator_run=self.generator_run_from_sqa(
                     generator_run_sqa=generator_run_sqa),
                 # A missing (or otherwise falsy) stored weight falls back to 1.0.
                 weight=generator_run_sqa.weight or 1.0,
             ) for generator_run_sqa in trial_sqa.generator_runs
         ]
         if trial_sqa.status_quo_name is not None:
             # The status quo is persisted as its own STATUS_QUO generator run.
             # Restore it onto the trial's status-quo fields and drop it from
             # the regular generator run structs.
             new_generator_run_structs = []
             for struct in generator_run_structs:
                 if (struct.generator_run.generator_run_type ==
                         GeneratorRunType.STATUS_QUO.name):
                     status_quo_weight = struct.generator_run.weights[0]
                     trial._status_quo = struct.generator_run.arms[0]
                     trial._status_quo_weight_override = status_quo_weight
                 else:
                     new_generator_run_structs.append(struct)
             generator_run_structs = new_generator_run_structs
         trial._generator_run_structs = generator_run_structs
         trial._abandoned_arms_metadata = {
             abandoned_arm_sqa.name: self.abandoned_arm_from_sqa(
                 abandoned_arm_sqa=abandoned_arm_sqa)
             for abandoned_arm_sqa in trial_sqa.abandoned_arms
         }
     else:
         # Validate before constructing the Trial, so a malformed record fails
         # without first creating a trial object against `experiment`.
         if trial_sqa.generator_runs and len(trial_sqa.generator_runs) != 1:
             raise SQADecodeError(  # pragma: no cover
                 "Cannot decode SQATrial to Trial because trial is not batched "
                 "but has more than one generator run.")
         trial = Trial(experiment=experiment)
         if trial_sqa.generator_runs:
             trial._generator_run = self.generator_run_from_sqa(
                 generator_run_sqa=trial_sqa.generator_runs[0])
     trial._index = trial_sqa.index
     trial._trial_type = trial_sqa.trial_type
     # Swap `DISPATCHED` for `RUNNING`, since `DISPATCHED` is deprecated and nearly
     # equivalent to `RUNNING`.
     trial._status = (trial_sqa.status
                      if trial_sqa.status != TrialStatus.DISPATCHED else
                      TrialStatus.RUNNING)
     trial._time_created = trial_sqa.time_created
     trial._time_completed = trial_sqa.time_completed
     trial._time_staged = trial_sqa.time_staged
     trial._time_run_started = trial_sqa.time_run_started
     trial._abandoned_reason = trial_sqa.abandoned_reason
     # pyre-fixme[9]: _run_metadata has type `Dict[str, Any]`; used as
     #  `Optional[Dict[str, Any]]`.
     trial._run_metadata = (dict(trial_sqa.run_metadata)
                            if trial_sqa.run_metadata is not None else None)
     trial._num_arms_created = trial_sqa.num_arms_created
     trial._runner = (self.runner_from_sqa(trial_sqa.runner)
                      if trial_sqa.runner else None)
     return trial
Exemplo n.º 2
0
    def trial_from_sqa(self,
                       trial_sqa: SQATrial,
                       experiment: Experiment,
                       reduced_state: bool = False) -> BaseTrial:
        """Convert SQLAlchemy Trial to Ax Trial.

        Args:
            trial_sqa: `SQATrial` to decode.
            experiment: Experiment the decoded trial is constructed against.
            reduced_state: Whether to load trial's generator run(s) with a
                slightly reduced state (without model state, search space, and
                optimization config). For batch trials this also skips
                decoding the abandoned-arm metadata.

        Returns:
            A `BatchTrial` if `trial_sqa.is_batch` is set, otherwise a `Trial`.

        Raises:
            SQADecodeError: If a non-batch `trial_sqa` has more than one
                generator run.
        """
        if trial_sqa.is_batch:
            trial = BatchTrial(
                experiment=experiment,
                optimize_for_power=trial_sqa.optimize_for_power,
                ttl_seconds=trial_sqa.ttl_seconds,
                index=trial_sqa.index,
            )
            generator_run_structs = [
                GeneratorRunStruct(
                    generator_run=self.generator_run_from_sqa(
                        generator_run_sqa=generator_run_sqa,
                        reduced_state=reduced_state,
                    ),
                    # A missing (or otherwise falsy) stored weight falls back
                    # to 1.0.
                    weight=generator_run_sqa.weight or 1.0,
                ) for generator_run_sqa in trial_sqa.generator_runs
            ]
            if trial_sqa.status_quo_name is not None:
                # The status quo is persisted as its own STATUS_QUO generator
                # run. Restore it onto the trial's status-quo fields and drop
                # it from the regular generator run structs.
                new_generator_run_structs = []
                for struct in generator_run_structs:
                    if (struct.generator_run.generator_run_type ==
                            GeneratorRunType.STATUS_QUO.name):
                        status_quo_weight = struct.generator_run.weights[0]
                        trial._status_quo = struct.generator_run.arms[0]
                        trial._status_quo_weight_override = status_quo_weight
                    else:
                        new_generator_run_structs.append(struct)
                generator_run_structs = new_generator_run_structs
            trial._generator_run_structs = generator_run_structs
            if not reduced_state:
                trial._abandoned_arms_metadata = {
                    abandoned_arm_sqa.name: self.abandoned_arm_from_sqa(
                        abandoned_arm_sqa=abandoned_arm_sqa)
                    for abandoned_arm_sqa in trial_sqa.abandoned_arms
                }
            trial._refresh_arms_by_name()  # Trigger cache build
        else:
            # Validate before constructing the Trial, so a malformed record
            # fails without first creating a trial object against `experiment`.
            if trial_sqa.generator_runs and len(trial_sqa.generator_runs) != 1:
                raise SQADecodeError(  # pragma: no cover
                    "Cannot decode SQATrial to Trial because trial is not batched "
                    "but has more than one generator run.")
            trial = Trial(
                experiment=experiment,
                ttl_seconds=trial_sqa.ttl_seconds,
                index=trial_sqa.index,
            )
            if trial_sqa.generator_runs:
                trial._generator_run = self.generator_run_from_sqa(
                    generator_run_sqa=trial_sqa.generator_runs[0],
                    reduced_state=reduced_state,
                )
        trial._trial_type = trial_sqa.trial_type
        # Swap `DISPATCHED` for `RUNNING`, since `DISPATCHED` is deprecated and nearly
        # equivalent to `RUNNING`.
        trial._status = (trial_sqa.status
                         if trial_sqa.status != TrialStatus.DISPATCHED else
                         TrialStatus.RUNNING)
        trial._time_created = trial_sqa.time_created
        trial._time_completed = trial_sqa.time_completed
        trial._time_staged = trial_sqa.time_staged
        trial._time_run_started = trial_sqa.time_run_started
        trial._abandoned_reason = trial_sqa.abandoned_reason
        # pyre-fixme[9]: _run_metadata has type `Dict[str, Any]`; used as
        #  `Optional[Dict[str, Any]]`.
        # pyre-fixme[8]: Attribute has type `Dict[str, typing.Any]`; used as
        #  `Optional[typing.Dict[Variable[_KT], Variable[_VT]]]`.
        trial._run_metadata = (
            # pyre-fixme[6]: Expected `Mapping[Variable[_KT], Variable[_VT]]` for
            #  1st param but got `Optional[Dict[str, typing.Any]]`.
            dict(trial_sqa.run_metadata)
            if trial_sqa.run_metadata is not None else None)
        trial._num_arms_created = trial_sqa.num_arms_created
        trial._runner = (self.runner_from_sqa(trial_sqa.runner)
                         if trial_sqa.runner else None)
        trial._generation_step_index = trial_sqa.generation_step_index
        # A missing properties payload decodes to an empty dict.
        trial._properties = trial_sqa.properties or {}
        trial.db_id = trial_sqa.id
        return trial