Example #1
0
    def Do(self, input_dict: Dict[str, List[types.Artifact]],
           output_dict: Dict[str, List[types.Artifact]],
           exec_properties: Dict[str, Any]) -> None:
        """Runs an optional metalearning warmup, then the final tuning search.

        If `exec_properties` provides a truthy `metalearning_algorithm`, a
        warmup tuner is run first via `self.warmup`. The number of warmup
        trials it reports is deducted from the final tuner's trial budget
        (never going below one trial). The final tuner is built from the
        user's `tuner_fn` and run via `self.search`.

        The tuner selected by the index returned from `merge_trial_data`
        (warmup or final) supplies the best hyperparameters written to
        `output_dict`. Trial progress data is serialized as JSON to
        `tuner_plot_data.txt` under the `trial_summary_plot` output artifact.

        Args:
          input_dict: Input artifacts, forwarded to `self.warmup` and to
            `fn_args_utils.get_common_fn_args`.
          output_dict: Output artifacts. Must contain a single
            `trial_summary_plot` artifact.
          exec_properties: Execution properties. `metalearning_algorithm` is
            optional. `tuner_fn` is resolved via `udf_utils.get_fn`.

        Raises:
          ValueError: If `tfx_tuner.get_tune_args` finds tune args in
            `exec_properties`; they are not supported by this executor.
        """
        if tfx_tuner.get_tune_args(exec_properties):
            raise ValueError(
                "TuneArgs is not supported by this Tuner's Executor.")

        # A missing key and an explicit falsy value are treated identically:
        # both skip the warmup stage.
        metalearning_algorithm = exec_properties.get('metalearning_algorithm')

        warmup_tuner = None
        warmup_trials = 0
        warmup_trial_data = None
        if metalearning_algorithm:
            warmup_tuner, warmup_trials = self.warmup(input_dict,
                                                      exec_properties,
                                                      metalearning_algorithm)
            warmup_trial_data = extract_tuner_trial_progress(warmup_tuner)
        else:
            logging.info('MetaLearning Algorithm not provided.')

        # Create new fn_args for final tuning stage.
        fn_args = fn_args_utils.get_common_fn_args(
            input_dict, exec_properties, working_dir=self._get_tmp_dir())
        tuner_fn = udf_utils.get_fn(exec_properties, 'tuner_fn')
        tuner_fn_result = tuner_fn(fn_args)

        # Deduct trials already spent on warmup from the final budget, keeping
        # at least one trial. An oracle may carry `max_trials=None`.
        # Subtracting from None would raise TypeError, so that budget is left
        # as-is.
        oracle = tuner_fn_result.tuner.oracle
        if oracle.max_trials is not None:
            oracle.max_trials = max(oracle.max_trials - warmup_trials, 1)

        tuner = self.search(tuner_fn_result)
        tuner_trial_data = extract_tuner_trial_progress(tuner)

        if warmup_trial_data:
            # Combine both stages' progress into one record for plotting.
            # Keep each stage's own cumulative-best series alongside it.
            cumulative_tuner_trial_data, best_tuner_ix = merge_trial_data(
                warmup_trial_data, tuner_trial_data)
            cumulative_tuner_trial_data[
                'warmup_trial_data'] = warmup_trial_data[BEST_CUMULATIVE_SCORE]
            cumulative_tuner_trial_data['tuner_trial_data'] = tuner_trial_data[
                BEST_CUMULATIVE_SCORE]

            if isinstance(tuner.oracle.objective, kerastuner.Objective):
                cumulative_tuner_trial_data[
                    'objective'] = tuner.oracle.objective.name
            else:
                cumulative_tuner_trial_data[
                    'objective'] = 'objective not understood'

            tuner_trial_data = cumulative_tuner_trial_data
            # Index 0 presumably refers to the first argument passed to
            # merge_trial_data (the warmup data); confirm in merge_trial_data.
            best_tuner = warmup_tuner if best_tuner_ix == 0 else tuner
        else:
            best_tuner = tuner
        tfx_tuner.write_best_hyperparameters(best_tuner, output_dict)
        tuner_plot_path = os.path.join(
            artifact_utils.get_single_uri(output_dict['trial_summary_plot']),
            'tuner_plot_data.txt')
        io_utils.write_string_file(tuner_plot_path,
                                   json.dumps(tuner_trial_data))
        logging.info('Tuner plot data written at: %s', tuner_plot_path)
Example #2
0
File: executor.py  Project: jasonz1112/tfx
  def Do(self, input_dict: Dict[Text, List[types.Artifact]],
         output_dict: Dict[Text, List[types.Artifact]],
         exec_properties: Dict[Text, Any]) -> None:
    """Runs the tuning search and lets only the chief persist the outcome.

    Every worker runs `self._search`. When no tuner id is set, or this worker
    is the chief, the best hyperparameters are written to `output_dict` and
    `self._close()` is called. Any other worker only logs and returns.

    Args:
      input_dict: Input artifacts passed through to `self._search`.
      output_dict: Output artifacts that receive the best hyperparameters.
      exec_properties: Execution properties passed through to `self._search`.
    """
    searched_tuner = self._search(input_dict, exec_properties)

    # Equivalent to "not (tuner id set and not chief)": single-process runs
    # (no tuner id) and the chief both write results and clean up.
    if self._tuner_id is None or self._is_chief:
      tuner_executor.write_best_hyperparameters(searched_tuner, output_dict)
      self._close()
    else:
      logging.info('Returning since this is not chief worker.')
Example #3
0
File: executor.py  Project: gbaned/tfx
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]) -> None:
        """Runs the tuning search; only the chief writes results.

        Every worker runs `self._search`. Non-chief workers log and return
        immediately. The chief writes the best hyperparameters to
        `output_dict`, then terminates `self._chief_process` if it is still
        alive.

        Args:
          input_dict: Input artifacts passed through to `self._search`.
          output_dict: Output artifacts that receive the best hyperparameters.
          exec_properties: Execution properties passed through to
            `self._search`.
        """
        tuner = self._search(input_dict, exec_properties)

        if not self._is_chief:
            logging.info('Returning since this is not chief worker.')
            # Previously this branch fell through despite the log message, so
            # non-chief workers also wrote the best hyperparameters.
            return

        tuner_executor.write_best_hyperparameters(tuner, output_dict)

        # Stop the chief oracle process so it does not outlive this executor.
        # It is presumably a multiprocessing.Process (it exposes is_alive,
        # pid, and terminate); confirm where _chief_process is created.
        if self._chief_process and self._chief_process.is_alive():
            logging.info('Terminating chief oracle at PID: %s',
                         self._chief_process.pid)
            self._chief_process.terminate()