def _preprocess_fn(self, features, labels, mode):
    """Flattens inner and sequence dimensions around the base preprocessor.

    The base preprocessor only operates on flat batches. This method:
    flattens the meta-batch with meta_tfdata.flatten_batch_examples,
    preprocesses the condition and inference splits separately, and
    unflattens the results again.

    Args:
      features: Meta features exposing a `condition` sub-structure (with
        `features` and `labels`) and an `inference` sub-structure (with
        `features`).
      labels: Labels that accompany the inference split, or None.
      mode: (ModeKeys) Specifies if this is training, evaluation or
        prediction. Must not be None.

    Returns:
      A (features, labels) tuple. `features.condition` and
      `features.inference` are replaced by their preprocessed, unflattened
      versions. `labels` is replaced likewise whenever the base preprocessor
      returns labels for the inference split.

    Raises:
      ValueError: If `mode` is None, or if the static size of dimension 1 of
        the first condition or inference feature is unknown.
    """
    if mode is None:
        raise ValueError('The mode should never be None.')

    # Dimension 1 of the first tensor in each split is the samples-per-task
    # count. It must be known statically because the preprocessed flat batch
    # is unflattened with it afterwards.
    first_condition_tensor = tuple(features.condition.features.values())[0]
    first_inference_tensor = tuple(features.inference.features.values())[0]
    condition_samples = first_condition_tensor.get_shape().as_list()[1]
    inference_samples = first_inference_tensor.get_shape().as_list()[1]
    if condition_samples is None:
        raise ValueError('num_condition_samples_per_task cannot be None.')
    if inference_samples is None:
        raise ValueError('num_inference_samples_per_task cannot be None.')

    flat_features = meta_tfdata.flatten_batch_examples(features)
    flat_labels = (
        meta_tfdata.flatten_batch_examples(labels)
        if labels is not None else None)

    preprocess = self._base_preprocessor.preprocess
    # Condition split: its labels are stored inside the features structure.
    flat_features.condition.features, flat_features.condition.labels = (
        preprocess(
            features=flat_features.condition.features,
            labels=flat_features.condition.labels,
            mode=mode))
    # Inference split: preprocessed together with the top-level labels.
    flat_features.inference.features, flat_labels = preprocess(
        features=flat_features.inference.features,
        labels=flat_labels,
        mode=mode)

    # Use the counts captured before preprocessing, since the preprocessor may
    # introduce new tensors or reshape existing ones.
    features.condition = meta_tfdata.unflatten_batch_examples(
        flat_features.condition, condition_samples)
    features.inference = meta_tfdata.unflatten_batch_examples(
        flat_features.inference, inference_samples)
    if flat_labels is not None:
        labels = meta_tfdata.unflatten_batch_examples(
            flat_labels, inference_samples)
    return features, labels
def model_eval_fn(self,
                  features,
                  labels,
                  inference_outputs,
                  train_loss,
                  train_outputs,
                  mode,
                  config=None,
                  params=None):
    """The eval model implementation.

    Flattens the inference output, the inference features and the labels
    with meta_tfdata.flatten_batch_examples. It then delegates to the base
    model's model_eval_fn, which expects flat batches.

    Args:
      features: This is the first item returned from the input_fn and parsed
        by tensorspec_utils.validate_and_pack. A spec_structure which fulfills
        the requirements of the self.get_feature_specification.
      labels: This is the second item returned from the input_fn and parsed
        by tensorspec_utils.validate_and_pack. A spec_structure which fulfills
        the requirements of the self.get_feature_specification.
      inference_outputs: A dict containing the output tensors of
        model_inference_fn.
      train_loss: The final loss from model_train_fn.
      train_outputs: A dict containing the output tensors (dict) of
        model_train_fn.
      mode: (ModeKeys) Specifies if this is training, evaluation or
        prediction.
      config: (Optional tf.estimator.RunConfig or contrib_tpu.RunConfig) Will
        receive what is passed to Estimator in config parameter, or the
        default config (tf.estimator.RunConfig). Allows updating things in
        your model_fn based on configuration such as num_ps_replicas, or
        model_dir.
      params: An optional dict of hyper parameters that will be passed into
        input_fn and model_fn. Keys are names of parameters, values are basic
        python types. There are reserved keys for TPUEstimator, including
        'batch_size'.

    Returns:
      The eval_metrics determined by the base_model.
    """
    flatten = meta_tfdata.flatten_batch_examples
    flat_outputs = flatten(inference_outputs.full_inference_output)
    flat_features = flatten(features.inference.features)
    flat_labels = flatten(labels)
    return self._base_model.model_eval_fn(
        features=flat_features,
        labels=flat_labels,
        inference_outputs=flat_outputs,
        train_loss=train_loss,
        train_outputs=train_outputs,
        mode=mode,
        config=config,
        params=params)
def _preprocess_fn(self, features, labels, mode):
    """Runs the base preprocessor on the flattened `train` and `val` splits.

    See base class for the general contract. The steps are:
    1. Flatten features (and labels, when given) with
       meta_tfdata.flatten_batch_examples.
    2. Pass the `train` and `val` splits separately through the base
       preprocessor.
    3. Unflatten the splits using self._num_train_samples_per_task and
       self._num_val_samples_per_task.
    4. Reshape `val_mode` to [-1, 1].

    Args:
      features: Meta features with `train`, `val` and `val_mode` entries.
      labels: Meta labels with the same entries, or None.
      mode: (ModeKeys) Specifies if this is training, evaluation or
        prediction. Must not be None.

    Returns:
      The (features, labels) tuple. `labels` stays None if it was None.

    Raises:
      ValueError: If `mode` is None.
    """
    if mode is None:
        raise ValueError('The mode should never be None.')

    preprocess = self._base_preprocessor.preprocess
    has_labels = labels is not None

    # The base preprocessor only operates on flat batches.
    features = meta_tfdata.flatten_batch_examples(features)
    if has_labels:
        labels = meta_tfdata.flatten_batch_examples(labels)

    for split in ('train', 'val'):
        split_features = getattr(features, split)
        split_labels = getattr(labels, split) if has_labels else None
        new_features, new_labels = preprocess(
            features=split_features, labels=split_labels, mode=mode)
        setattr(features, split, new_features)
        if has_labels:
            setattr(labels, split, new_labels)
        # Without labels, any labels returned by the preprocessor are dropped.

    def _unflatten(struct):
        # Unflatten with the configured per-split sample counts, since the
        # preprocessor may introduce new tensors or reshape existing ones.
        struct.train = meta_tfdata.unflatten_batch_examples(
            struct.train, self._num_train_samples_per_task)
        struct.val = meta_tfdata.unflatten_batch_examples(
            struct.val, self._num_val_samples_per_task)
        struct.val_mode = tf.reshape(struct.val_mode, [-1, 1])

    _unflatten(features)
    if labels is not None:
        _unflatten(labels)
    return features, labels
def model_train_fn(self,
                   features,
                   labels,
                   inference_outputs,
                   mode,
                   config=None,
                   params=None):
    """The training model implementation.

    Emits base-model summaries for every inner-loop step and for the
    unconditioned inference output. It then computes the loss by calling the
    base model's model_train_fn on the flattened inference split, with
    'is_outer_loss' set to True in the params it receives.

    Args:
      features: This is the first item returned from the input_fn and parsed
        by tensorspec_utils.validate_and_pack. A spec_structure which fulfills
        the requirements of the self.get_feature_specification.
      labels: This is the second item returned from the input_fn and parsed
        by tensorspec_utils.validate_and_pack. A spec_structure which fulfills
        the requirements of the self.get_feature_specification.
      inference_outputs: A dict containing the output tensors of
        model_inference_fn.
      mode: (ModeKeys) Specifies if this is training, evaluation or
        prediction.
      config: (Optional tf.estimator.RunConfig or contrib_tpu.RunConfig) Will
        receive what is passed to Estimator in config parameter, or the
        default config (tf.estimator.RunConfig). Allows updating things in
        your model_fn based on configuration such as num_ps_replicas, or
        model_dir.
      params: An optional dict of hyper parameters that will be passed into
        input_fn and model_fn. Keys are names of parameters, values are basic
        python types. There are reserved keys for TPUEstimator, including
        'batch_size'. This dict is not modified. The base model receives a
        shallow copy with 'is_outer_loss' set to True.

    Returns:
      The loss and optionally train_outputs of the base model.
    """
    # Since the base model assumes data in the format [batch_size] + shape
    # and not [num_tasks, num_samples_per_task, shape] we need to flatten
    # the data prior to calling into the base model.
    condition_features_flat_batch = meta_tfdata.flatten_batch_examples(
        features.condition.features)
    condition_labels_flat_batch = meta_tfdata.flatten_batch_examples(
        features.condition.labels)

    # range(n + 1) covers steps 0..num_inner_loop_steps inclusive.
    for inner_loop_step in range(self._num_inner_loop_steps + 1):
        condition_output_flat_batch = meta_tfdata.flatten_batch_examples(
            inference_outputs['full_condition_outputs/output_{}'.format(
                inner_loop_step)])
        with tf.variable_scope('inner_loop_step_{}'.format(inner_loop_step)):
            self._base_model.add_summaries(
                features=condition_features_flat_batch,
                labels=condition_labels_flat_batch,
                inference_outputs=condition_output_flat_batch,
                train_loss=None,
                train_outputs=None,
                mode=mode,
                params=params)

    inference_output_flat_batch = meta_tfdata.flatten_batch_examples(
        inference_outputs.full_inference_output)
    inference_features_flat_batch = meta_tfdata.flatten_batch_examples(
        features.inference.features)
    labels_flat_batch = meta_tfdata.flatten_batch_examples(labels)

    with tf.variable_scope('unconditioned_inference'):
        uncondition_output_flat_batch = meta_tfdata.flatten_batch_examples(
            inference_outputs.full_inference_output_unconditioned)
        # NOTE(review): this pairs the unconditioned *inference* output with
        # the *condition* features/labels. If
        # full_inference_output_unconditioned is computed on the inference
        # split, inference_features_flat_batch / labels_flat_batch look like
        # the intended arguments. Confirm against the code producing it.
        self._base_model.add_summaries(
            features=condition_features_flat_batch,
            labels=condition_labels_flat_batch,
            inference_outputs=uncondition_output_flat_batch,
            train_loss=None,
            train_outputs=None,
            mode=mode,
            params=params)

    # Copy instead of mutating the caller's dict. The same params object is
    # forwarded to add_summaries above and to other model functions, so
    # writing 'is_outer_loss' into it in place would leak the flag into them.
    outer_params = dict(params) if params is not None else {}
    outer_params['is_outer_loss'] = True
    return self._base_model.model_train_fn(
        features=inference_features_flat_batch,
        labels=labels_flat_batch,
        inference_outputs=inference_output_flat_batch,
        config=config,
        mode=mode,
        params=outer_params)