Code Example #1
File: ensembler.py  Project: zeng8280/model_search
    def _create_average_ensemble_logits(self, ensemble_logits,
                                        search_logits_specs):
        """Bundles together averaged logits."""

        logits = tf.add_n(ensemble_logits) / len(ensemble_logits)
        ensemble_logits_spec = architecture_utils.LogitsSpec(logits=logits)

        if (trial_utils.is_nonadaptive_ensemble_search(self._ensemble_spec)
                or trial_utils.is_intermixed_ensemble_search(
                    self._ensemble_spec)):
            return EnsembleLogits(train_logits_specs=[],
                                  eval_logits_spec=ensemble_logits_spec)

        if trial_utils.is_adaptive_ensemble_search(self._ensemble_spec):
            return EnsembleLogits(train_logits_specs=search_logits_specs,
                                  eval_logits_spec=ensemble_logits_spec)

        if trial_utils.is_residual_ensemble_search(self._ensemble_spec):
            return EnsembleLogits(train_logits_specs=[ensemble_logits_spec],
                                  eval_logits_spec=ensemble_logits_spec)
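The averaging above is ordinary element-wise arithmetic over the member towers' logits: `tf.add_n` sums the tensors and the sum is divided by the member count. A minimal standalone sketch of that step (the constants are illustrative):

import tensorflow as tf

# Three candidate towers, each emitting logits of shape [batch, num_classes].
ensemble_logits = [
    tf.constant([[1.0, 3.0]]),
    tf.constant([[2.0, 5.0]]),
    tf.constant([[3.0, 1.0]]),
]

# Element-wise sum of all members, divided by the number of members.
logits = tf.add_n(ensemble_logits) / len(ensemble_logits)
print(logits)  # [[2. 3.]]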
Code Example #2
    def first_time_chief_generate(self, features, input_layer_fn, trial_mode,
                                  shared_input_tensor, shared_lengths,
                                  logits_dimension, hparams, run_config,
                                  is_training, trials):
        """Creates the prior for the ensemble."""
        my_id = architecture_utils.DirectoryHandler.get_trial_id(
            run_config.model_dir, self._phoenix_spec)

        prior_build_args = dict(features=features,
                                input_layer_fn=input_layer_fn,
                                shared_input_tensor=shared_input_tensor,
                                shared_lengths=shared_lengths,
                                is_training=is_training,
                                trials=trials,
                                logits_dimension=logits_dimension,
                                my_id=my_id,
                                my_model_dir=run_config.model_dir)

        if trial_mode == trial_utils.TrialMode.DISTILLATION:
            return self.build_priors_distillation(**prior_build_args)

        if trial_utils.is_nonadaptive_ensemble_search(
                self._phoenix_spec.ensemble_spec):
            return self.build_priors_nonadaptively(**prior_build_args)

        if trial_utils.is_adaptive_ensemble_search(
                self._phoenix_spec.ensemble_spec):
            return self.build_priors_adaptively(**prior_build_args)

        if trial_utils.is_residual_ensemble_search(
                self._phoenix_spec.ensemble_spec):
            # Residual search reuses the adaptive prior-building path.
            return self.build_priors_adaptively(**prior_build_args)

        if trial_utils.is_intermixed_ensemble_search(
                self._phoenix_spec.ensemble_spec):
            return self.build_priors_intermixed(**prior_build_args)

        # No ensemble spec or distillation spec was specified.
        architecture_utils.set_number_of_towers(self.generator_name(), 0)
        return [], []
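The method collects the arguments shared by every `build_priors_*` branch into a single dict and forwards it with `**`, so each branch reduces to a one-line call. The same dispatch pattern in miniature (all names here are hypothetical):

def build_priors_a(features, my_id):
    return ("a", features, my_id)

def build_priors_b(features, my_id):
    return ("b", features, my_id)

def generate(mode, features, my_id):
    # Pack the shared arguments once; every branch forwards the same dict.
    build_args = dict(features=features, my_id=my_id)
    if mode == "a":
        return build_priors_a(**build_args)
    return build_priors_b(**build_args)

print(generate("a", {"x": 1}, my_id=7))  # ('a', {'x': 1}, 7)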
Code Example #3
File: ensembler.py  Project: zeng8280/model_search
    def _create_weighted_ensemble_logits(self, ensemble_logits,
                                         search_logits_specs,
                                         logits_dimension):
        """Bundles together weighted logits."""

        logits = tf.keras.layers.Dense(units=logits_dimension)(tf.concat(
            ensemble_logits, axis=-1))
        ensemble_logits_spec = architecture_utils.LogitsSpec(logits=logits)

        if (trial_utils.is_nonadaptive_ensemble_search(self._ensemble_spec)
                or trial_utils.is_intermixed_ensemble_search(
                    self._ensemble_spec)):
            return EnsembleLogits(train_logits_specs=[ensemble_logits_spec],
                                  eval_logits_spec=ensemble_logits_spec)

        if trial_utils.is_adaptive_ensemble_search(self._ensemble_spec):
            return EnsembleLogits(train_logits_specs=[ensemble_logits_spec] +
                                  search_logits_specs,
                                  eval_logits_spec=ensemble_logits_spec)

        if trial_utils.is_residual_ensemble_search(self._ensemble_spec):
            return EnsembleLogits(train_logits_specs=[ensemble_logits_spec],
                                  eval_logits_spec=ensemble_logits_spec)
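In contrast to the simple average in Example #1, the weighted ensemble concatenates the member logits along the last axis and learns a linear combination of them through a `Dense` projection back to `logits_dimension`. A minimal sketch with two members (shapes are illustrative):

import tensorflow as tf

logits_dimension = 2
ensemble_logits = [
    tf.random.normal([4, logits_dimension]),  # tower 1: [batch, classes]
    tf.random.normal([4, logits_dimension]),  # tower 2: [batch, classes]
]

# Concatenate to [batch, num_members * logits_dimension], then learn a
# linear projection back down to [batch, logits_dimension].
logits = tf.keras.layers.Dense(units=logits_dimension)(
    tf.concat(ensemble_logits, axis=-1))
print(logits.shape)  # (4, 2)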
Code Example #4
    def get_generators(self, my_id, all_trials):
        """Determines which generators to run."""
        output = {}
        ensemble_spec = self._phoenix_spec.ensemble_spec
        distillation_spec = self._phoenix_spec.distillation_spec
        logging.info("trial id: %d", my_id)

        # Handling replay
        if self._replay_state.is_replay():
            if self._replay_state.replay_is_training_a_tower(my_id):
                output.update({
                    base_tower_generator.SEARCH_GENERATOR:
                    GeneratorWithTrials(self._search_candidate_generator, [])
                })
            if self._replay_state.replay_is_importing_towers(my_id):
                output.update({
                    base_tower_generator.REPLAY_GENERATOR:
                    GeneratorWithTrials(self._replay_generator, [])
                })

            return _return_generators(output)

        # Real search from here on.
        # First: user suggestions. No ensembling in suggestions.
        if my_id <= len(self._phoenix_spec.user_suggestions):
            logging.info("user suggestions mode")
            output.update({
                base_tower_generator.SEARCH_GENERATOR:
                GeneratorWithTrials(self._search_candidate_generator, [])
            })
            return _return_generators(output)

        # Second: Handle non-adaptive search
        if trial_utils.is_nonadaptive_ensemble_search(ensemble_spec):
            logging.info("non adaptive ensembling mode")
            pool_size = ensemble_spec.nonadaptive_search.minimal_pool_size
            search_trials = [t for t in all_trials if t.id <= pool_size]
            # Pool too small, continue searching
            if my_id <= pool_size:
                output.update({
                    base_tower_generator.SEARCH_GENERATOR:
                    GeneratorWithTrials(self._search_candidate_generator,
                                        search_trials)
                })
                return _return_generators(output)
            # Pool hit critical mass, start ensembling.
            else:
                output.update({
                    base_tower_generator.PRIOR_GENERATOR:
                    GeneratorWithTrials(self._prior_candidate_generator,
                                        search_trials)
                })
                return _return_generators(output)

        # Third: Adaptive / Residual ensemble search
        if (trial_utils.is_adaptive_ensemble_search(ensemble_spec)
                or trial_utils.is_residual_ensemble_search(ensemble_spec)):
            logging.info("adaptive/residual ensembling mode")
            increase_every = ensemble_spec.adaptive_search.increase_width_every
            pool_size = my_id // increase_every * increase_every
            ensembling_trials = [
                trial for trial in all_trials if trial.id <= pool_size
            ]
            search_trials = [
                trial for trial in all_trials if trial.id > pool_size
            ]
            if ensembling_trials:
                output.update({
                    base_tower_generator.SEARCH_GENERATOR:
                    GeneratorWithTrials(self._search_candidate_generator,
                                        search_trials),
                    base_tower_generator.PRIOR_GENERATOR:
                    GeneratorWithTrials(self._prior_candidate_generator,
                                        ensembling_trials)
                })
                return _return_generators(output)
            else:
                output.update({
                    base_tower_generator.SEARCH_GENERATOR:
                    GeneratorWithTrials(self._search_candidate_generator,
                                        search_trials)
                })
                return _return_generators(output)

        # Fourth: Intermixed Search.
        if trial_utils.is_intermixed_ensemble_search(ensemble_spec):
            logging.info("intermix ensemble search mode")
            n = ensemble_spec.intermixed_search.try_ensembling_every
            search_trials = [t for t in all_trials if t.id % n != 0]
            if my_id % n != 0:
                output.update({
                    base_tower_generator.SEARCH_GENERATOR:
                    GeneratorWithTrials(self._search_candidate_generator,
                                        search_trials)
                })
                if (trial_utils.get_trial_mode(
                        ensemble_spec, distillation_spec,
                        my_id) == trial_utils.TrialMode.DISTILLATION):
                    output.update({
                        base_tower_generator.PRIOR_GENERATOR:
                        GeneratorWithTrials(self._prior_candidate_generator,
                                            all_trials)
                    })
                return _return_generators(output)
            else:
                output.update({
                    base_tower_generator.PRIOR_GENERATOR:
                    GeneratorWithTrials(self._prior_candidate_generator,
                                        search_trials)
                })
                return _return_generators(output)

        # No ensembling
        output.update({
            base_tower_generator.SEARCH_GENERATOR:
            GeneratorWithTrials(self._search_candidate_generator, all_trials)
        })
        return _return_generators(output)
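Two pieces of integer arithmetic drive the trial splits above: the adaptive/residual branch floors the trial id to the nearest multiple of `increase_width_every` to fix the ensembling pool, and the intermixed branch ensembles only on trials whose id is a multiple of `try_ensembling_every`. Both in isolation:

increase_width_every = 5
for my_id in (4, 5, 9, 10, 11):
    pool_size = my_id // increase_width_every * increase_width_every
    print(my_id, "->", pool_size)  # 4->0, 5->5, 9->5, 10->10, 11->10

try_ensembling_every = 3
trial_ids = range(1, 10)
ensembling_ids = [t for t in trial_ids if t % try_ensembling_every == 0]
print(ensembling_ids)  # [3, 6, 9]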
Code Example #5
  def first_time_chief_generate(self, features, input_layer_fn, trial_mode,
                                shared_input_tensor, shared_lengths,
                                logits_dimension, hparams, run_config,
                                is_training, trials):
    dropout_rate = getattr(hparams, "dropout_rate", None)
    my_id = architecture_utils.DirectoryHandler.get_trial_id(
        run_config.model_dir, self._phoenix_spec)
    create_new_architecture_fn = functools.partial(
        self._create_new_architecture,
        features=features,
        input_layer_fn=input_layer_fn,
        shared_input_tensor=shared_input_tensor,
        run_config=run_config,
        my_id=my_id,
        hparams=hparams,
        is_training=is_training,
        shared_lengths=shared_lengths,
        logits_dimension=logits_dimension,
        dropout_rate=dropout_rate,
        trials=trials)

    # First, try out user suggestions.
    if my_id <= len(self._phoenix_spec.user_suggestions):
      return create_new_architecture_fn(
          architecture=self._get_user_suggestion(my_id), prev_trial=-1)

    if trial_mode == trial_utils.TrialMode.ENSEMBLE_SEARCH:

      # Non-adaptive ensemble search.
      if trial_utils.is_nonadaptive_ensemble_search(self._ensemble_spec):
        # Done searching if we've hit critical mass.
        architecture_utils.set_number_of_towers(self.generator_name(), 0)
        return [], []

      # Adaptive and residual ensemble search.
      elif (trial_utils.is_adaptive_ensemble_search(self._ensemble_spec) or
            trial_utils.is_residual_ensemble_search(self._ensemble_spec)):
        every = self._ensemble_spec.adaptive_search.increase_width_every
        relevant_trials = trials
        if every:
          relevant_trials = [
              trial for trial in trials if trial.id >= my_id // every * every
          ]
        architecture, prev_trial = self._search_algorithm.get_suggestion(
            relevant_trials, hparams, my_id, run_config.model_dir)
        return create_new_architecture_fn(
            architecture=architecture, prev_trial=prev_trial)

      # Intermixed ensemble search.
      elif trial_utils.is_intermixed_ensemble_search(self._ensemble_spec):
        every = self._ensemble_spec.intermixed_search.try_ensembling_every

        # Do not search if this is a non-exploration trial.
        if my_id % every == 0:
          architecture_utils.set_number_of_towers(self.generator_name(), 0)
          return [], []

        # Search if this is an exploration trial.
        relevant_trials = [trial for trial in trials if trial.id % every != 0]
        architecture, prev_trial = self._search_algorithm.get_suggestion(
            relevant_trials, hparams, my_id, run_config.model_dir)
        return create_new_architecture_fn(
            architecture=architecture, prev_trial=prev_trial)

      else:
        raise ValueError("Unknown ensemble search type '{}'".format(
            self._ensemble_spec.ensemble_search_type))

    if (trial_mode == trial_utils.TrialMode.DISTILLATION and
        trial_utils.is_intermixed_ensemble_search(self._ensemble_spec)):
      relevant_trials = trial_utils.get_intermixed_trials(
          trials, self._ensemble_spec.intermixed_search.try_ensembling_every,
          len(self._phoenix_spec.user_suggestions))
      best_trial = self._metadata.get_best_k(trials=relevant_trials, k=1)
      if best_trial is not None:
        model_dir = architecture_utils.DirectoryHandler.trial_dir(best_trial)
        assert architecture_utils.get_number_of_towers(
            model_dir, self.generator_name()) == 1
        tower_name = self.generator_name() + "_0"
        tower_spec = architecture_utils.import_tower(
            phoenix_spec=self._phoenix_spec,
            features=features,
            input_layer_fn=input_layer_fn,
            shared_input_tensor=shared_input_tensor,
            original_tower_name=tower_name,
            new_tower_name=tower_name,
            model_directory=model_dir,
            new_model_directory=run_config.model_dir,
            is_training=is_training,
            logits_dimension=logits_dimension,
            shared_lengths=shared_lengths,
            force_snapshot=False,
            force_freeze=False,
            allow_auxiliary_head=self._allow_auxiliary_head)
        architecture_utils.set_number_of_towers(self.generator_name(), 1)
        return [tower_spec.logits_spec], [tower_spec.architecture]

    # If no ensembling search method is specified, or this is a distillation
    # trial without intermixed ensemble_search, get a new tower based on the
    # architecture search algorithm.
    # This will serve as the student model if distillation occurs on this trial.
    architecture, prev_trial = self._search_algorithm.get_suggestion(
        trials, hparams, my_id, run_config.model_dir)
    return create_new_architecture_fn(
        architecture=architecture, prev_trial=prev_trial)
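`functools.partial` above freezes every keyword argument that is identical across the branches, so each branch only supplies the two values that vary: `architecture` and `prev_trial`. The same idea in miniature (names hypothetical):

import functools

def create_architecture(architecture, prev_trial, features, is_training):
    return (architecture, prev_trial, features, is_training)

create_fn = functools.partial(
    create_architecture, features={"x": 1}, is_training=True)

# Each call site now passes only the arguments that differ per branch.
print(create_fn(architecture=[1, 2], prev_trial=-1))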
Code Example #6
File: phoenix.py  Project: alex7772942/TRPP
  def __init__(self,
               phoenix_spec,
               input_layer_fn,
               study_owner,
               study_name,
               head=None,
               logits_dimension=None,
               label_vocabulary=None,
               loss_fn=None,
               metric_fn=None,
               predictions_fn=None,
               metadata=None):
    """Constructs a Phoenix instance.

    Args:
      phoenix_spec: A `PhoenixSpec` proto with the spec for the run.
      input_layer_fn: A function that converts feature Tensors to an input layer.
        See learning.autolx.model_search.data.Provider.get_input_layer_fn
        for details.
      study_owner: A string holding the ldap of the study owner. We use tuner
        platforms to train the various architectures. This field specifies the
        study owner.
      study_name: A string holding the study name.
      head: A head to use with Phoenix for creating the loss and eval metrics.
        If no head is given, Phoenix falls back to using the loss_fn and
        metric_fn. N.B.: Phoenix creates its own EstimatorSpec so everything
          besides the loss and eval metrics returned by head will be ignored.
      logits_dimension: An int holding the dimension of the output. Must be
        provided if head is None. Will be ignored if head is not None.
      label_vocabulary: List or tuple with labels vocabulary. Needed only if the
        labels are of type string. This list is used by the loss function if
        loss_fn is not provided. It is also used in the metric function to
        create the accuracy metric ops. Use only with multiclass classification
        problems.
      loss_fn: A function to compute the loss. Ignored if `head` is not None.
        Must accept as inputs a `labels` Tensor, a `logits` Tensor, and
        optionally a `weights` Tensor. `weights` must either be rank 0 or have
        the same rank as labels. If None, Phoenix defaults to using softmax
        cross-entropy.
      metric_fn: Metrics for Tensorboard. Ignored if `head` is not None.
        metric_fn takes `label` and `predictions` as input, and outputs a
        dictionary of (tensor, update_op) tuples. `label` is a Tensor (in the
        single task case) or a dict of Tensors (in the case of multi-task, where
        the key of the dicts correspond to the task names). `predictions` is a
        dict of Tensors. In the single task case, it consists of `predictions`,
        `probabilities`, and `log_probabilities`. In the multi-task case, it
        consists of the same keys as that of the single task case, but also
        those corresponding to each task (e.g., predictions/task_name_1). See
        `metric_fns` for more detail. If `metric_fn` is None, it will include a
        metric for the number of parameters, accuracy (if logits_dimension >=
        2), and AUC metrics (if logits_dimension == 2).
      predictions_fn: A function to convert eval logits to the
        `predictions` dictionary passed to metric_fn. If `None`, defaults to
        computing 'predictions', 'probabilities', and 'log_probabilities'.
      metadata: An object that implements metadata api in
        learning.adanets.phoenix.metadata.Metadata
    """

    # Check Phoenix preconditions and fail early if any of them are broken.
    if phoenix_spec.multi_task_spec:
      # TODO(b/172564129): Add support for head and custom loss_fns in
      # multi-task.
      assert not head, "head is not supported for multi-task."
    if head:
      msg = "Do not specify {} when using head as head already contains it."
      assert not logits_dimension, msg.format("logits_dimension")
      assert not label_vocabulary, msg.format("label_vocabulary")
      assert not loss_fn, msg.format("loss_fn")
      assert not metric_fn, msg.format("metric_fn")

    # Check ensemble search / distillation preconditions.
    ensemble_spec = phoenix_spec.ensemble_spec
    distillation_spec = phoenix_spec.distillation_spec
    if trial_utils.has_distillation(
        distillation_spec) and trial_utils.has_ensemble_search(
            ensemble_spec
        ) and not trial_utils.is_intermixed_ensemble_search(ensemble_spec):
      ensemble_search_spec = (
          ensemble_spec.nonadaptive_search
          if trial_utils.is_nonadaptive_ensemble_search(ensemble_spec) else
          ensemble_spec.adaptive_search)
      if (distillation_spec.minimal_pool_size ==
          ensemble_search_spec.minimal_pool_size):
        logging.warning("minimal_pool_size is the same for ensemble spec and "
                        "distillation spec, so distillation will be ignored.")

    self._phoenix_spec = phoenix_spec
    self._input_layer_fn = input_layer_fn
    self._ensembler = ensembler.Ensembler(phoenix_spec)
    self._distiller = distillation.Distiller(phoenix_spec.distillation_spec)
    self._study_owner = study_owner
    self._study_name = study_name
    self._head = head
    self._logits_dimension = (
        self._head.logits_dimension if head else logits_dimension)
    self._label_vocabulary = label_vocabulary
    if self._label_vocabulary:
      assert self._logits_dimension == len(self._label_vocabulary)

    self._loss_fn = loss_fn or loss_fns.make_multi_class_loss_fn(
        label_vocabulary=label_vocabulary)

    self._user_specified_metric_fn = metric_fn

    self._predictions_fn = (predictions_fn or _default_predictions_fn)

    if metadata is None:
      self._metadata = ml_metadata_db.MLMetaData(phoenix_spec, study_name,
                                                 study_owner)
    else:
      self._metadata = metadata
    self._task_manager = task_manager.TaskManager(phoenix_spec)
    self._controller = controller.InProcessController(
        phoenix_spec=phoenix_spec, metadata=self._metadata)
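The constructor fails fast on conflicting arguments: a `head` already carries the logits dimension, label vocabulary, loss, and metrics, so supplying any of them alongside it is rejected before any state is built. The guard pattern in isolation (a simplified, hypothetical reduction of the checks above):

def check_head_args(head=None, logits_dimension=None, loss_fn=None):
    if head:
        msg = "Do not specify {} when using head as head already contains it."
        assert not logits_dimension, msg.format("logits_dimension")
        assert not loss_fn, msg.format("loss_fn")
        # Fall back to the dimension the head carries.
        return head.logits_dimension
    return logits_dimension

print(check_head_args(logits_dimension=10))  # 10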