def testSortable(self):
    """Check that `<` works on batch trials, abandoned arms, and run structs.

    Uses `assertLess` rather than `assertTrue(a < b)` so that a failure
    reports both operands instead of just "False is not true".
    """
    # The batch trial created here is expected to compare greater than the
    # fixture batch stored on the test case.
    new_batch_trial = self.experiment.new_batch_trial()
    self.assertLess(self.batch, new_batch_trial)

    # Two abandoned arms; the second is renamed to "0_1" and is expected to
    # compare greater than the first.
    abandoned_arm = get_abandoned_arm()
    abandoned_arm_2 = get_abandoned_arm()
    abandoned_arm_2.name = "0_1"
    self.assertLess(abandoned_arm, abandoned_arm_2)

    # Both structs wrap the same generator run and differ only in weight;
    # the one with weight 1.0 is expected to compare less than weight 2.0.
    generator_run = get_generator_run()
    generator_run_struct = GeneratorRunStruct(
        generator_run=generator_run, weight=1.0
    )
    generator_run_struct_2 = GeneratorRunStruct(
        generator_run=generator_run, weight=2.0
    )
    self.assertLess(generator_run_struct, generator_run_struct_2)
def trial_from_sqa(self, trial_sqa: SQATrial, experiment: Experiment) -> BaseTrial:
    """Decode a SQLAlchemy trial record into the corresponding Ax trial.

    Args:
        trial_sqa: `SQATrial` to decode.
        experiment: Experiment handed to the constructor of the decoded trial.

    Returns:
        A `BatchTrial` when `trial_sqa.is_batch` is truthy, otherwise a `Trial`.

    Raises:
        SQADecodeError: If a non-batch trial carries more than one generator
            run.
    """
    if not trial_sqa.is_batch:
        trial = Trial(experiment=experiment)
        sqa_runs = trial_sqa.generator_runs
        if sqa_runs:
            # A non-batch trial can hold at most one generator run.
            if len(sqa_runs) != 1:
                raise SQADecodeError(  # pragma: no cover
                    "Cannot decode SQATrial to Trial because trial is not batched "
                    "but has more than one generator run."
                )
            trial._generator_run = self.generator_run_from_sqa(
                generator_run_sqa=sqa_runs[0]
            )
    else:
        trial = BatchTrial(
            experiment=experiment,
            optimize_for_power=trial_sqa.optimize_for_power,
        )

        # Decode every stored generator run; a falsy stored weight becomes 1.0.
        decoded_structs = []
        for run_sqa in trial_sqa.generator_runs:
            decoded_run = self.generator_run_from_sqa(generator_run_sqa=run_sqa)
            decoded_structs.append(
                GeneratorRunStruct(
                    generator_run=decoded_run,
                    weight=run_sqa.weight or 1.0,
                )
            )

        if trial_sqa.status_quo_name is not None:
            # Generator runs typed as status quo are not kept in the struct
            # list; their first arm and first weight are stored on the trial.
            status_quo_type = GeneratorRunType.STATUS_QUO.name
            regular_structs = []
            for gr_struct in decoded_structs:
                run = gr_struct.generator_run
                if run.generator_run_type != status_quo_type:
                    regular_structs.append(gr_struct)
                    continue
                override_weight = run.weights[0]
                trial._status_quo = run.arms[0]
                trial._status_quo_weight_override = override_weight
            decoded_structs = regular_structs

        trial._generator_run_structs = decoded_structs

        # Abandoned-arm metadata is keyed by arm name.
        abandoned_by_name = {}
        for arm_sqa in trial_sqa.abandoned_arms:
            abandoned_by_name[arm_sqa.name] = self.abandoned_arm_from_sqa(
                abandoned_arm_sqa=arm_sqa
            )
        trial._abandoned_arms_metadata = abandoned_by_name

    trial._index = trial_sqa.index
    trial._trial_type = trial_sqa.trial_type

    # `DISPATCHED` is deprecated and nearly equivalent to `RUNNING`, so it is
    # replaced with `RUNNING` on load.
    decoded_status = trial_sqa.status
    if decoded_status == TrialStatus.DISPATCHED:
        decoded_status = TrialStatus.RUNNING
    trial._status = decoded_status

    trial._time_created = trial_sqa.time_created
    trial._time_completed = trial_sqa.time_completed
    trial._time_staged = trial_sqa.time_staged
    trial._time_run_started = trial_sqa.time_run_started
    trial._abandoned_reason = trial_sqa.abandoned_reason
    # pyre-fixme[9]: _run_metadata has type `Dict[str, Any]`; used as
    #  `Optional[Dict[str, Any]]`.
    trial._run_metadata = (
        dict(trial_sqa.run_metadata) if trial_sqa.run_metadata is not None else None
    )
    trial._num_arms_created = trial_sqa.num_arms_created

    # Only decode a runner when one is stored.
    if trial_sqa.runner:
        trial._runner = self.runner_from_sqa(trial_sqa.runner)
    else:
        trial._runner = None
    return trial
def trial_from_sqa(
    self,
    trial_sqa: SQATrial,
    experiment: Experiment,
    reduced_state: bool = False,
) -> BaseTrial:
    """Decode a SQLAlchemy trial record into the corresponding Ax trial.

    Args:
        trial_sqa: `SQATrial` to decode.
        experiment: Experiment handed to the constructor of the decoded trial.
        reduced_state: Whether to load trial's generator run(s) with a
            slightly reduced state (without model state, search space, and
            optimization config). When set, abandoned-arm metadata of a
            batch trial is not decoded either.

    Returns:
        A `BatchTrial` when `trial_sqa.is_batch` is truthy, otherwise a `Trial`.

    Raises:
        SQADecodeError: If a non-batch trial carries more than one generator
            run.
    """
    if not trial_sqa.is_batch:
        trial = Trial(
            experiment=experiment,
            ttl_seconds=trial_sqa.ttl_seconds,
            index=trial_sqa.index,
        )
        sqa_runs = trial_sqa.generator_runs
        if sqa_runs:
            # A non-batch trial can hold at most one generator run.
            if len(sqa_runs) != 1:
                raise SQADecodeError(  # pragma: no cover
                    "Cannot decode SQATrial to Trial because trial is not batched "
                    "but has more than one generator run."
                )
            trial._generator_run = self.generator_run_from_sqa(
                generator_run_sqa=sqa_runs[0],
                reduced_state=reduced_state,
            )
    else:
        trial = BatchTrial(
            experiment=experiment,
            optimize_for_power=trial_sqa.optimize_for_power,
            ttl_seconds=trial_sqa.ttl_seconds,
            index=trial_sqa.index,
        )

        # Decode every stored generator run; a falsy stored weight becomes 1.0.
        decoded_structs = []
        for run_sqa in trial_sqa.generator_runs:
            decoded_run = self.generator_run_from_sqa(
                generator_run_sqa=run_sqa,
                reduced_state=reduced_state,
            )
            decoded_structs.append(
                GeneratorRunStruct(
                    generator_run=decoded_run,
                    weight=run_sqa.weight or 1.0,
                )
            )

        if trial_sqa.status_quo_name is not None:
            # Generator runs typed as status quo are not kept in the struct
            # list; their first arm and first weight are stored on the trial.
            status_quo_type = GeneratorRunType.STATUS_QUO.name
            regular_structs = []
            for gr_struct in decoded_structs:
                run = gr_struct.generator_run
                if run.generator_run_type != status_quo_type:
                    regular_structs.append(gr_struct)
                    continue
                override_weight = run.weights[0]
                trial._status_quo = run.arms[0]
                trial._status_quo_weight_override = override_weight
            decoded_structs = regular_structs

        trial._generator_run_structs = decoded_structs

        if not reduced_state:
            # Abandoned-arm metadata is keyed by arm name.
            abandoned_by_name = {}
            for arm_sqa in trial_sqa.abandoned_arms:
                abandoned_by_name[arm_sqa.name] = self.abandoned_arm_from_sqa(
                    abandoned_arm_sqa=arm_sqa
                )
            trial._abandoned_arms_metadata = abandoned_by_name
        trial._refresh_arms_by_name()  # Trigger cache build

    trial._trial_type = trial_sqa.trial_type

    # `DISPATCHED` is deprecated and nearly equivalent to `RUNNING`, so it is
    # replaced with `RUNNING` on load.
    decoded_status = trial_sqa.status
    if decoded_status == TrialStatus.DISPATCHED:
        decoded_status = TrialStatus.RUNNING
    trial._status = decoded_status

    trial._time_created = trial_sqa.time_created
    trial._time_completed = trial_sqa.time_completed
    trial._time_staged = trial_sqa.time_staged
    trial._time_run_started = trial_sqa.time_run_started
    trial._abandoned_reason = trial_sqa.abandoned_reason
    # pyre-fixme[9]: _run_metadata has type `Dict[str, Any]`; used as
    #  `Optional[Dict[str, Any]]`.
    # pyre-fixme[8]: Attribute has type `Dict[str, typing.Any]`; used as
    #  `Optional[typing.Dict[Variable[_KT], Variable[_VT]]]`.
    trial._run_metadata = (
        # pyre-fixme[6]: Expected `Mapping[Variable[_KT], Variable[_VT]]` for
        #  1st param but got `Optional[Dict[str, typing.Any]]`.
        dict(trial_sqa.run_metadata)
        if trial_sqa.run_metadata is not None
        else None
    )
    trial._num_arms_created = trial_sqa.num_arms_created

    # Only decode a runner when one is stored.
    if trial_sqa.runner:
        trial._runner = self.runner_from_sqa(trial_sqa.runner)
    else:
        trial._runner = None

    trial._generation_step_index = trial_sqa.generation_step_index
    trial._properties = trial_sqa.properties or {}
    trial.db_id = trial_sqa.id
    return trial