def Do(self, input_dict: Dict[str, List[types.Artifact]],
       output_dict: Dict[str, List[types.Artifact]],
       exec_properties: Dict[str, Any]) -> None:
    """Runs an optional metalearning warmup search, then the final search.

    If ``exec_properties`` provides a truthy ``metalearning_algorithm``,
    ``self.warmup`` is run first. It returns a warmup tuner and the number of
    trials it used. That count is subtracted from the final tuner's trial
    budget, which is clamped to at least one trial. The final tuner comes
    from the user-supplied ``tuner_fn`` and is run via ``self.search``.

    When warmup trial data exists, it is merged with the final tuner's data
    by ``merge_trial_data``. The best hyperparameters are then taken from
    whichever tuner that merge selects. Otherwise they come from the final
    tuner. The (possibly merged) trial-progress data is written as JSON to
    ``tuner_plot_data.txt`` inside the ``trial_summary_plot`` output
    artifact.

    Args:
      input_dict: Input artifacts, forwarded to ``self.warmup`` and
        ``fn_args_utils.get_common_fn_args``.
      output_dict: Output artifacts. Passed to
        ``tfx_tuner.write_best_hyperparameters`` and must contain a
        ``trial_summary_plot`` entry.
      exec_properties: Execution properties. May contain
        ``metalearning_algorithm``; ``tuner_fn`` is resolved from it via
        ``udf_utils.get_fn``.

    Raises:
      ValueError: If tune args are present in ``exec_properties``.
    """
    if tfx_tuner.get_tune_args(exec_properties):
        raise ValueError(
            "TuneArgs is not supported by this Tuner's Executor.")

    # ``dict.get`` already yields None for a missing key.
    metalearning_algorithm = exec_properties.get('metalearning_algorithm')

    warmup_trials = 0
    warmup_trial_data = None
    if metalearning_algorithm:
        warmup_tuner, warmup_trials = self.warmup(
            input_dict, exec_properties, metalearning_algorithm)
        warmup_trial_data = extract_tuner_trial_progress(warmup_tuner)
    else:
        logging.info('MetaLearning Algorithm not provided.')

    # Create new fn_args for final tuning stage.
    fn_args = fn_args_utils.get_common_fn_args(
        input_dict, exec_properties, working_dir=self._get_tmp_dir())
    tuner_fn = udf_utils.get_fn(exec_properties, 'tuner_fn')
    tuner_fn_result = tuner_fn(fn_args)

    # Spend only the budget left after warmup, but always run >= 1 trial.
    # Depending on the oracle, max_trials may be None (no explicit budget);
    # leave it untouched instead of failing on ``None - int``.
    oracle = tuner_fn_result.tuner.oracle
    if oracle.max_trials is not None:
        oracle.max_trials = max(oracle.max_trials - warmup_trials, 1)

    tuner = self.search(tuner_fn_result)
    tuner_trial_data = extract_tuner_trial_progress(tuner)

    best_tuner = tuner
    if warmup_trial_data:
        cumulative_tuner_trial_data, best_tuner_ix = merge_trial_data(
            warmup_trial_data, tuner_trial_data)
        cumulative_tuner_trial_data['warmup_trial_data'] = warmup_trial_data[
            BEST_CUMULATIVE_SCORE]
        cumulative_tuner_trial_data['tuner_trial_data'] = tuner_trial_data[
            BEST_CUMULATIVE_SCORE]
        objective = tuner.oracle.objective
        if isinstance(objective, kerastuner.Objective):
            cumulative_tuner_trial_data['objective'] = objective.name
        else:
            cumulative_tuner_trial_data[
                'objective'] = 'objective not understood'
        tuner_trial_data = cumulative_tuner_trial_data
        # Index 0 selects the warmup tuner; any other index keeps the final one.
        if best_tuner_ix == 0:
            best_tuner = warmup_tuner

    tfx_tuner.write_best_hyperparameters(best_tuner, output_dict)
    tuner_plot_path = os.path.join(
        artifact_utils.get_single_uri(output_dict['trial_summary_plot']),
        'tuner_plot_data.txt')
    io_utils.write_string_file(tuner_plot_path, json.dumps(tuner_trial_data))
    logging.info('Tuner plot data written at: %s', tuner_plot_path)
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
    """Runs one tuning search and publishes the results from a single worker.

    Every worker runs ``self._search``. Results are published, meaning the
    best hyperparameters are written to ``output_dict`` and ``self._close()``
    is called, in two cases: when no tuner id is set, or when this worker is
    the chief. A non-chief worker with a tuner id only logs and returns.

    Args:
      input_dict: Input artifacts forwarded to ``self._search``.
      output_dict: Output artifacts receiving the best hyperparameters.
      exec_properties: Execution properties forwarded to ``self._search``.
    """
    searched_tuner = self._search(input_dict, exec_properties)

    should_publish = self._tuner_id is None or self._is_chief
    if should_publish:
        tuner_executor.write_best_hyperparameters(searched_tuner, output_dict)
        self._close()
        return

    logging.info('Returning since this is not chief worker.')
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
    """Runs one tuning search; only one worker writes the best hyperparameters.

    Every worker runs ``self._search``. A non-chief worker in a run with a
    tuner id logs and returns without writing anything. Otherwise the best
    hyperparameters are written to ``output_dict``. On every path after the
    search, a still-alive ``self._chief_process`` is terminated, even if
    writing the results raises.

    NOTE(review): assumes the class defines ``self._tuner_id`` and that it is
    None for single-machine tuning, where ``_is_chief`` may be False. Confirm
    against ``__init__``.

    Args:
      input_dict: Input artifacts forwarded to ``self._search``.
      output_dict: Output artifacts receiving the best hyperparameters.
      exec_properties: Execution properties forwarded to ``self._search``.
    """
    tuner = self._search(input_dict, exec_properties)

    try:
        # Previously this only logged "Returning" and then fell through,
        # so every non-chief worker also wrote the best hyperparameters.
        if self._tuner_id is not None and not self._is_chief:
            logging.info('Returning since this is not chief worker.')
            return

        tuner_executor.write_best_hyperparameters(tuner, output_dict)
    finally:
        # Stop the chief oracle sub-process, if one was started and is alive.
        if self._chief_process and self._chief_process.is_alive():
            logging.info('Terminating chief oracle at PID: %s',
                         self._chief_process.pid)
            self._chief_process.terminate()