def _create_average_ensemble_logits(self, ensemble_logits,
                                    search_logits_specs):
  """Bundles together averaged logits."""
  logits = tf.add_n(ensemble_logits) / len(ensemble_logits)
  ensemble_logits_spec = architecture_utils.LogitsSpec(logits=logits)
  if (trial_utils.is_nonadaptive_ensemble_search(self._ensemble_spec) or
      trial_utils.is_intermixed_ensemble_search(self._ensemble_spec)):
    return EnsembleLogits(
        train_logits_specs=[], eval_logits_spec=ensemble_logits_spec)
  if trial_utils.is_adaptive_ensemble_search(self._ensemble_spec):
    return EnsembleLogits(
        train_logits_specs=search_logits_specs,
        eval_logits_spec=ensemble_logits_spec)
  if trial_utils.is_residual_ensemble_search(self._ensemble_spec):
    return EnsembleLogits(
        train_logits_specs=[ensemble_logits_spec],
        eval_logits_spec=ensemble_logits_spec)
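# A minimal standalone sketch (not part of this module) of the averaging done
# in _create_average_ensemble_logits; tensor values are illustrative only:
#
#   import tensorflow.compat.v2 as tf
#   tower_logits = [tf.constant([[1.0, 3.0]]), tf.constant([[3.0, 1.0]])]
#   averaged = tf.add_n(tower_logits) / len(tower_logits)
#   # averaged == [[2.0, 2.0]]; every tower contributes equally, and the
#   # output keeps the shape of a single tower's logits.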
def first_time_chief_generate(self, features, input_layer_fn, trial_mode,
                              shared_input_tensor, shared_lengths,
                              logits_dimension, hparams, run_config,
                              is_training, trials):
  """Creates the prior for the ensemble."""
  my_id = architecture_utils.DirectoryHandler.get_trial_id(
      run_config.model_dir, self._phoenix_spec)
  prior_build_args = dict(
      features=features,
      input_layer_fn=input_layer_fn,
      shared_input_tensor=shared_input_tensor,
      shared_lengths=shared_lengths,
      is_training=is_training,
      trials=trials,
      logits_dimension=logits_dimension,
      my_id=my_id,
      my_model_dir=run_config.model_dir)
  if trial_mode == trial_utils.TrialMode.DISTILLATION:
    return self.build_priors_distillation(**prior_build_args)
  if trial_utils.is_nonadaptive_ensemble_search(
      self._phoenix_spec.ensemble_spec):
    return self.build_priors_nonadaptively(**prior_build_args)
  if trial_utils.is_adaptive_ensemble_search(
      self._phoenix_spec.ensemble_spec):
    return self.build_priors_adaptively(**prior_build_args)
  if trial_utils.is_residual_ensemble_search(
      self._phoenix_spec.ensemble_spec):
    return self.build_priors_adaptively(**prior_build_args)
  if trial_utils.is_intermixed_ensemble_search(
      self._phoenix_spec.ensemble_spec):
    return self.build_priors_intermixed(**prior_build_args)

  # No ensemble spec or distillation spec was specified.
  architecture_utils.set_number_of_towers(self.generator_name(), 0)
  return [], []
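# Routing summary for the method above (informal; mirrors the branches in
# order): a DISTILLATION trial builds distillation priors; otherwise the
# ensemble spec selects non-adaptive, adaptive/residual (both share
# build_priors_adaptively), or intermixed prior building; with no ensemble
# or distillation spec at all, the generator registers zero towers and
# returns empty logits/architecture lists.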
def _create_weighted_ensemble_logits(self, ensemble_logits,
                                     search_logits_specs, logits_dimension):
  """Bundles together weighted logits."""
  logits = tf.keras.layers.Dense(units=logits_dimension)(
      tf.concat(ensemble_logits, axis=-1))
  ensemble_logits_spec = architecture_utils.LogitsSpec(logits=logits)
  if (trial_utils.is_nonadaptive_ensemble_search(self._ensemble_spec) or
      trial_utils.is_intermixed_ensemble_search(self._ensemble_spec)):
    return EnsembleLogits(
        train_logits_specs=[ensemble_logits_spec],
        eval_logits_spec=ensemble_logits_spec)
  if trial_utils.is_adaptive_ensemble_search(self._ensemble_spec):
    return EnsembleLogits(
        train_logits_specs=[ensemble_logits_spec] + search_logits_specs,
        eval_logits_spec=ensemble_logits_spec)
  if trial_utils.is_residual_ensemble_search(self._ensemble_spec):
    return EnsembleLogits(
        train_logits_specs=[ensemble_logits_spec],
        eval_logits_spec=ensemble_logits_spec)
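# A minimal standalone sketch (not part of this module) of the weighted
# combination in _create_weighted_ensemble_logits: candidate logits are
# concatenated along the last axis and mixed back down to logits_dimension
# by a learned Dense layer. Shapes below are illustrative:
#
#   import tensorflow.compat.v2 as tf
#   tower_logits = [tf.ones([4, 10]), tf.ones([4, 10])]   # two towers
#   combined = tf.concat(tower_logits, axis=-1)           # shape [4, 20]
#   weighted = tf.keras.layers.Dense(units=10)(combined)  # shape [4, 10]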
def get_generators(self, my_id, all_trials):
  """Determines which generators to run."""
  output = {}
  ensemble_spec = self._phoenix_spec.ensemble_spec
  distillation_spec = self._phoenix_spec.distillation_spec
  logging.info("trial id: %d", my_id)

  # Handle replay.
  if self._replay_state.is_replay():
    if self._replay_state.replay_is_training_a_tower(my_id):
      output.update({
          base_tower_generator.SEARCH_GENERATOR:
              GeneratorWithTrials(self._search_candidate_generator, [])
      })
    if self._replay_state.replay_is_importing_towers(my_id):
      output.update({
          base_tower_generator.REPLAY_GENERATOR:
              GeneratorWithTrials(self._replay_generator, [])
      })
    return _return_generators(output)

  # Real search from here on.
  # First: user suggestions. No ensembling in suggestions.
  if my_id <= len(self._phoenix_spec.user_suggestions):
    logging.info("user suggestions mode")
    output.update({
        base_tower_generator.SEARCH_GENERATOR:
            GeneratorWithTrials(self._search_candidate_generator, [])
    })
    return _return_generators(output)

  # Second: handle non-adaptive search.
  if trial_utils.is_nonadaptive_ensemble_search(ensemble_spec):
    logging.info("non adaptive ensembling mode")
    pool_size = ensemble_spec.nonadaptive_search.minimal_pool_size
    search_trials = [t for t in all_trials if t.id <= pool_size]
    # Pool too small; continue searching.
    if my_id <= pool_size:
      output.update({
          base_tower_generator.SEARCH_GENERATOR:
              GeneratorWithTrials(self._search_candidate_generator,
                                  search_trials)
      })
      return _return_generators(output)
    # Pool hit critical mass; start ensembling.
    else:
      output.update({
          base_tower_generator.PRIOR_GENERATOR:
              GeneratorWithTrials(self._prior_candidate_generator,
                                  search_trials)
      })
      return _return_generators(output)

  # Third: adaptive / residual ensemble search.
  if (trial_utils.is_adaptive_ensemble_search(ensemble_spec) or
      trial_utils.is_residual_ensemble_search(ensemble_spec)):
    logging.info("adaptive/residual ensembling mode")
    increase_every = ensemble_spec.adaptive_search.increase_width_every
    pool_size = my_id // increase_every * increase_every
    ensembling_trials = [
        trial for trial in all_trials if trial.id <= pool_size
    ]
    search_trials = [
        trial for trial in all_trials if trial.id > pool_size
    ]
    if ensembling_trials:
      output.update({
          base_tower_generator.SEARCH_GENERATOR:
              GeneratorWithTrials(self._search_candidate_generator,
                                  search_trials),
          base_tower_generator.PRIOR_GENERATOR:
              GeneratorWithTrials(self._prior_candidate_generator,
                                  ensembling_trials)
      })
      return _return_generators(output)
    else:
      output.update({
          base_tower_generator.SEARCH_GENERATOR:
              GeneratorWithTrials(self._search_candidate_generator,
                                  search_trials)
      })
      return _return_generators(output)

  # Fourth: intermixed search.
  if trial_utils.is_intermixed_ensemble_search(ensemble_spec):
    logging.info("intermix ensemble search mode")
    n = ensemble_spec.intermixed_search.try_ensembling_every
    search_trials = [t for t in all_trials if t.id % n != 0]
    if my_id % n != 0:
      output.update({
          base_tower_generator.SEARCH_GENERATOR:
              GeneratorWithTrials(self._search_candidate_generator,
                                  search_trials)
      })
      if (trial_utils.get_trial_mode(ensemble_spec, distillation_spec, my_id)
          == trial_utils.TrialMode.DISTILLATION):
        output.update({
            base_tower_generator.PRIOR_GENERATOR:
                GeneratorWithTrials(self._prior_candidate_generator,
                                    all_trials)
        })
      return _return_generators(output)
    else:
      output.update({
          base_tower_generator.PRIOR_GENERATOR:
              GeneratorWithTrials(self._prior_candidate_generator,
                                  search_trials)
      })
      return _return_generators(output)

  # No ensembling.
  output.update({
      base_tower_generator.SEARCH_GENERATOR:
          GeneratorWithTrials(self._search_candidate_generator, all_trials)
  })
  return _return_generators(output)
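# Worked example for the branches above (illustrative numbers, not produced
# by this module). Adaptive/residual search with increase_width_every = 10
# and my_id = 27:
#
#   pool_size = 27 // 10 * 10  # == 20
#
# so trials 1..20 become ensembling_trials (PRIOR_GENERATOR) and trials
# 21..27 remain search_trials (SEARCH_GENERATOR). Intermixed search with
# try_ensembling_every = 5 runs the search generator on every trial whose
# id is not a multiple of 5, and tries ensembling on trials 5, 10, 15, ...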
def first_time_chief_generate(self, features, input_layer_fn, trial_mode,
                              shared_input_tensor, shared_lengths,
                              logits_dimension, hparams, run_config,
                              is_training, trials):
  dropout_rate = getattr(hparams, "dropout_rate", None)
  my_id = architecture_utils.DirectoryHandler.get_trial_id(
      run_config.model_dir, self._phoenix_spec)
  create_new_architecture_fn = functools.partial(
      self._create_new_architecture,
      features=features,
      input_layer_fn=input_layer_fn,
      shared_input_tensor=shared_input_tensor,
      run_config=run_config,
      my_id=my_id,
      hparams=hparams,
      is_training=is_training,
      shared_lengths=shared_lengths,
      logits_dimension=logits_dimension,
      dropout_rate=dropout_rate,
      trials=trials)

  # First, try out user suggestions.
  if my_id <= len(self._phoenix_spec.user_suggestions):
    return create_new_architecture_fn(
        architecture=self._get_user_suggestion(my_id), prev_trial=-1)

  if trial_mode == trial_utils.TrialMode.ENSEMBLE_SEARCH:
    # Non-adaptive ensemble search.
    if trial_utils.is_nonadaptive_ensemble_search(self._ensemble_spec):
      # Done searching if we've hit critical mass.
      architecture_utils.set_number_of_towers(self.generator_name(), 0)
      return [], []
    # Adaptive and residual ensemble search.
    elif (trial_utils.is_adaptive_ensemble_search(self._ensemble_spec) or
          trial_utils.is_residual_ensemble_search(self._ensemble_spec)):
      every = self._ensemble_spec.adaptive_search.increase_width_every
      relevant_trials = trials
      if every:
        relevant_trials = [
            trial for trial in trials if trial.id >= my_id // every * every
        ]
      architecture, prev_trial = self._search_algorithm.get_suggestion(
          relevant_trials, hparams, my_id, run_config.model_dir)
      return create_new_architecture_fn(
          architecture=architecture, prev_trial=prev_trial)
    # Intermixed ensemble search.
    elif trial_utils.is_intermixed_ensemble_search(self._ensemble_spec):
      every = self._ensemble_spec.intermixed_search.try_ensembling_every
      # Do not search if this is a non-exploration trial.
      if my_id % every == 0:
        architecture_utils.set_number_of_towers(self.generator_name(), 0)
        return [], []
      # Search if this is an exploration trial.
      relevant_trials = [
          trial for trial in trials if trial.id % every != 0
      ]
      architecture, prev_trial = self._search_algorithm.get_suggestion(
          relevant_trials, hparams, my_id, run_config.model_dir)
      return create_new_architecture_fn(
          architecture=architecture, prev_trial=prev_trial)
    else:
      raise ValueError("Unknown ensemble search type '{}'".format(
          self._ensemble_spec.ensemble_search_type))

  if (trial_mode == trial_utils.TrialMode.DISTILLATION and
      trial_utils.is_intermixed_ensemble_search(self._ensemble_spec)):
    relevant_trials = trial_utils.get_intermixed_trials(
        trials, self._ensemble_spec.intermixed_search.try_ensembling_every,
        len(self._phoenix_spec.user_suggestions))
    best_trial = self._metadata.get_best_k(trials=relevant_trials, k=1)
    if best_trial is not None:
      model_dir = architecture_utils.DirectoryHandler.trial_dir(best_trial)
      assert architecture_utils.get_number_of_towers(
          model_dir, self.generator_name()) == 1
      tower_name = self.generator_name() + "_0"
      tower_spec = architecture_utils.import_tower(
          phoenix_spec=self._phoenix_spec,
          features=features,
          input_layer_fn=input_layer_fn,
          shared_input_tensor=shared_input_tensor,
          original_tower_name=tower_name,
          new_tower_name=tower_name,
          model_directory=model_dir,
          new_model_directory=run_config.model_dir,
          is_training=is_training,
          logits_dimension=logits_dimension,
          shared_lengths=shared_lengths,
          force_snapshot=False,
          force_freeze=False,
          allow_auxiliary_head=self._allow_auxiliary_head)
      architecture_utils.set_number_of_towers(self.generator_name(), 1)
      return [tower_spec.logits_spec], [tower_spec.architecture]

  # If no ensembling search method is specified, or if this is a distillation
  # trial without intermixed ensemble search, get a new tower based on the
  # architecture search algorithm. This tower will serve as the student model
  # if distillation occurs on this trial.
  architecture, prev_trial = self._search_algorithm.get_suggestion(
      trials, hparams, my_id, run_config.model_dir)
  return create_new_architecture_fn(
      architecture=architecture, prev_trial=prev_trial)
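# Worked example of the intermixed exploration rule in the ensemble-search
# branch above (illustrative numbers): with try_ensembling_every = 5, trials
# 5, 10, 15, ... are non-exploration trials and build no new tower, while
# every other trial runs architecture search over the previous exploration
# trials (those with id % 5 != 0).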
def __init__(self,
             phoenix_spec,
             input_layer_fn,
             study_owner,
             study_name,
             head=None,
             logits_dimension=None,
             label_vocabulary=None,
             loss_fn=None,
             metric_fn=None,
             predictions_fn=None,
             metadata=None):
  """Constructs a Phoenix instance.

  Args:
    phoenix_spec: A `PhoenixSpec` proto with the spec for the run.
    input_layer_fn: A function that converts feature Tensors to an input
      layer. See learning.autolx.model_search.data.Provider.get_input_layer_fn
      for details.
    study_owner: A string holding the ldap of the study owner. We use tuner
      platforms to train the various architectures; this field specifies the
      study owner.
    study_name: A string holding the study name.
    head: A head to use with Phoenix for creating the loss and eval metrics.
      If no head is given, Phoenix falls back to using the loss_fn and
      metric_fn. N.B.: Phoenix creates its own EstimatorSpec, so everything
      besides the loss and eval metrics returned by head will be ignored.
    logits_dimension: An int holding the dimension of the output. Must be
      provided if head is None. Ignored if head is not None.
    label_vocabulary: List or tuple with the labels vocabulary. Needed only
      if the labels are of type string. This list is used by the loss
      function if loss_fn is not provided. It is also used in the metric
      function to create the accuracy metric ops. Use only with multiclass
      classification problems.
    loss_fn: A function to compute the loss. Ignored if `head` is not None.
      Must accept as inputs a `labels` Tensor, a `logits` Tensor, and
      optionally a `weights` Tensor. `weights` must either be rank 0 or have
      the same rank as labels. If None, Phoenix defaults to using softmax
      cross-entropy.
    metric_fn: Metrics for TensorBoard. Ignored if `head` is not None.
      metric_fn takes `label` and `predictions` as input, and outputs a
      dictionary of (tensor, update_op) tuples. `label` is a Tensor (in the
      single-task case) or a dict of Tensors (in the multi-task case, where
      the keys of the dict correspond to the task names). `predictions` is a
      dict of Tensors. In the single-task case, it consists of `predictions`,
      `probabilities`, and `log_probabilities`. In the multi-task case, it
      consists of the same keys as the single-task case, plus those
      corresponding to each task (e.g., predictions/task_name_1). See
      `metric_fns` for more detail. If `metric_fn` is None, it will include a
      metric for the number of parameters, an accuracy metric (if
      logits_dimension >= 2), and AUC metrics (if logits_dimension == 2).
    predictions_fn: A function to convert eval logits to the `predictions`
      dictionary passed to metric_fn. If `None`, defaults to computing
      'predictions', 'probabilities', and 'log_probabilities'.
    metadata: An object that implements the metadata api in
      learning.adanets.phoenix.metadata.Metadata.
  """
  # Check Phoenix preconditions and fail early if any of them are broken.
  if phoenix_spec.multi_task_spec:
    # TODO(b/172564129): Add support for head and custom loss_fns in
    # multi-task.
    assert not head, "head is not supported for multi-task."
  if head:
    msg = "Do not specify {} when using head as head already contains it."
    assert not logits_dimension, msg.format("logits_dimension")
    assert not label_vocabulary, msg.format("label_vocabulary")
    assert not loss_fn, msg.format("loss_fn")
    assert not metric_fn, msg.format("metric_fn")

  # Check ensemble search / distillation preconditions.
  ensemble_spec = phoenix_spec.ensemble_spec
  distillation_spec = phoenix_spec.distillation_spec
  if (trial_utils.has_distillation(distillation_spec) and
      trial_utils.has_ensemble_search(ensemble_spec) and
      not trial_utils.is_intermixed_ensemble_search(ensemble_spec)):
    ensemble_search_spec = (
        ensemble_spec.nonadaptive_search
        if trial_utils.is_nonadaptive_ensemble_search(ensemble_spec)
        else ensemble_spec.adaptive_search)
    if (distillation_spec.minimal_pool_size ==
        ensemble_search_spec.minimal_pool_size):
      logging.warning("minimal_pool_size is the same for ensemble spec and "
                      "distillation spec, so distillation will be ignored.")

  self._phoenix_spec = phoenix_spec
  self._input_layer_fn = input_layer_fn
  self._ensembler = ensembler.Ensembler(phoenix_spec)
  self._distiller = distillation.Distiller(phoenix_spec.distillation_spec)
  self._study_owner = study_owner
  self._study_name = study_name
  self._head = head
  self._logits_dimension = (
      self._head.logits_dimension if head else logits_dimension)
  self._label_vocabulary = label_vocabulary
  if self._label_vocabulary:
    assert self._logits_dimension == len(self._label_vocabulary)
  self._loss_fn = loss_fn or loss_fns.make_multi_class_loss_fn(
      label_vocabulary=label_vocabulary)
  self._user_specified_metric_fn = metric_fn
  self._predictions_fn = predictions_fn or _default_predictions_fn
  if metadata is None:
    self._metadata = ml_metadata_db.MLMetaData(phoenix_spec, study_name,
                                               study_owner)
  else:
    self._metadata = metadata
  self._task_manager = task_manager.TaskManager(phoenix_spec)
  self._controller = controller.InProcessController(
      phoenix_spec=phoenix_spec, metadata=self._metadata)
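# Hypothetical construction sketch (names and values below are placeholders,
# not defaults of this module):
#
#   phoenix_instance = Phoenix(
#       phoenix_spec=my_phoenix_spec,      # a PhoenixSpec proto
#       input_layer_fn=my_input_layer_fn,  # see the data Provider docs
#       study_owner="someuser",
#       study_name="my_study",
#       logits_dimension=10)               # required when head is None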