Example #1
  def _preprocess_fn(self, features, labels, mode):
    """Flattens inner and sequence dimensions."""
    if mode is None:
      raise ValueError('The mode should never be None.')

    condition_feature = list(features.condition.features.values())[0]
    inference_feature = list(features.inference.features.values())[0]
    # In order to unflatten the flattened examples later, we need to keep
    # track of the original shapes.
    num_condition_samples_per_task = (
        condition_feature.get_shape().as_list()[1])
    num_inference_samples_per_task = (
        inference_feature.get_shape().as_list()[1])

    if num_condition_samples_per_task is None:
      raise ValueError('num_condition_samples_per_task cannot be None.')
    if num_inference_samples_per_task is None:
      raise ValueError('num_inference_samples_per_task cannot be None.')

    flat_features = meta_tfdata.flatten_batch_examples(features)

    flat_labels = None
    # The original preprocessor can only operate on the flattened data.
    if labels is not None:
      flat_labels = meta_tfdata.flatten_batch_examples(labels)

    # We invoke our original preprocessor on the flat batch.
    flat_features.condition.features, flat_features.condition.labels = (
        self._base_preprocessor.preprocess(
            features=flat_features.condition.features,
            labels=flat_features.condition.labels,
            mode=mode))
    (flat_features.inference.features,
     flat_labels) = self._base_preprocessor.preprocess(
         features=flat_features.inference.features,
         labels=flat_labels,
         mode=mode)

    # We need to unflatten with num_*_samples_per_task since the preprocessor
    # might introduce new tensors or reshape existing tensors.
    features.condition = meta_tfdata.unflatten_batch_examples(
        flat_features.condition, num_condition_samples_per_task)
    features.inference = meta_tfdata.unflatten_batch_examples(
        flat_features.inference, num_inference_samples_per_task)
    if flat_labels is not None:
      labels = meta_tfdata.unflatten_batch_examples(
          flat_labels, num_inference_samples_per_task)

    return features, labels
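
The round trip above hinges on meta_tfdata merging the leading [num_tasks, num_samples_per_task] dimensions into a single batch dimension and later restoring them. Below is a minimal, self-contained sketch of that idea for a single tensor; the real helpers also traverse nested spec structures, and flatten_examples/unflatten_examples are hypothetical stand-ins, not the meta_tfdata API.

import tensorflow as tf

def flatten_examples(tensor):
  # Hypothetical stand-in for meta_tfdata.flatten_batch_examples:
  # [num_tasks, num_samples_per_task, ...] -> [num_tasks * num_samples, ...].
  shape = tensor.get_shape().as_list()
  return tf.reshape(tensor, [-1] + shape[2:])

def unflatten_examples(tensor, num_samples_per_task):
  # Hypothetical stand-in for meta_tfdata.unflatten_batch_examples:
  # [num_tasks * num_samples, ...] -> [num_tasks, num_samples_per_task, ...].
  shape = tensor.get_shape().as_list()
  return tf.reshape(tensor, [-1, num_samples_per_task] + shape[1:])

x = tf.ones([4, 3, 8])                  # 4 tasks, 3 samples per task.
flat = flatten_examples(x)              # shape [12, 8]
restored = unflatten_examples(flat, 3)  # shape [4, 3, 8]

This is also why the snippet insists that num_*_samples_per_task is statically known: without it, the unflatten step cannot rebuild the per-task dimension.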
Example #2
  def model_eval_fn(self,
                    features,
                    labels,
                    inference_outputs,
                    train_loss,
                    train_outputs,
                    mode,
                    config=None,
                    params=None):
    """The eval model implementation.

    Args:
      features: This is the first item returned from the input_fn and parsed
        by tensorspec_utils.validate_and_pack. A spec_structure which fulfills
        the requirements of self.get_feature_specification.
      labels: This is the second item returned from the input_fn and parsed
        by tensorspec_utils.validate_and_pack. A spec_structure which fulfills
        the requirements of self.get_label_specification.
      inference_outputs: A dict containing the output tensors of
        model_inference_fn.
      train_loss: The final loss from model_train_fn.
      train_outputs: A dict containing the output tensors (dict) of
        model_train_fn.
      mode: (ModeKeys) Specifies if this is training, evaluation or prediction.
      config: (Optional tf.estimator.RunConfig or contrib_tpu.RunConfig) Will
        receive what is passed to Estimator in the config parameter, or the
        default config (tf.estimator.RunConfig). Allows updating things in your
        model_fn based on configuration such as num_ps_replicas or model_dir.
      params: An optional dict of hyper parameters that will be passed into
        input_fn and model_fn. Keys are names of parameters, values are basic
        Python types. There are reserved keys for TPUEstimator, including
        'batch_size'.

    Returns:
      The eval_metrics determined by the base_model.
    """
    inference_output_flat_batch = meta_tfdata.flatten_batch_examples(
        inference_outputs.full_inference_output)
    inference_features_flat_batch = meta_tfdata.flatten_batch_examples(
        features.inference.features)
    labels_flat_batch = meta_tfdata.flatten_batch_examples(labels)
    return self._base_model.model_eval_fn(
        features=inference_features_flat_batch,
        labels=labels_flat_batch,
        inference_outputs=inference_output_flat_batch,
        train_loss=train_loss,
        train_outputs=train_outputs,
        mode=mode,
        config=config,
        params=params)
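
Eval is pure delegation: once every [num_tasks, num_samples_per_task, ...] structure is flattened, any metric written for ordinary [batch_size, ...] tensors applies unchanged. A short sketch of that property, with illustrative shapes and a stand-in metric not taken from the source:

import tensorflow as tf

def mean_squared_error(labels, predictions):
  # A metric that only understands flat [batch_size, dim] tensors.
  return tf.reduce_mean(tf.square(labels - predictions))

num_tasks, num_samples, dim = 4, 3, 8
labels = tf.random.uniform([num_tasks, num_samples, dim])
predictions = tf.random.uniform([num_tasks, num_samples, dim])

# Merging the task dimensions turns the meta-batch into an ordinary batch of
# 12 examples, so the base metric needs no changes.
flat_labels = tf.reshape(labels, [-1, dim])
flat_predictions = tf.reshape(predictions, [-1, dim])
loss = mean_squared_error(flat_labels, flat_predictions)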
Example #3
  def _preprocess_fn(self, features, labels, mode):
    """See base class."""
    if mode is None:
      raise ValueError('The mode should never be None.')

    # The shapes needed to unflatten later are already known
    # (self._num_train_samples_per_task and self._num_val_samples_per_task),
    # so we can flatten immediately.
    features = meta_tfdata.flatten_batch_examples(features)

    # The original preprocessor can only operate on the flattened data.
    if labels is not None:
      labels = meta_tfdata.flatten_batch_examples(labels)

      # We invoke our original preprocessor on the flat batch.
      features.train, labels.train = self._base_preprocessor.preprocess(
          features=features.train, labels=labels.train, mode=mode)
      features.val, labels.val = self._base_preprocessor.preprocess(
          features=features.val, labels=labels.val, mode=mode)
    else:
      # We invoke our original preprocessor on the flat batch.
      features.train, _ = self._base_preprocessor.preprocess(
          features=features.train, labels=None, mode=mode)
      features.val, _ = self._base_preprocessor.preprocess(
          features=features.val, labels=None, mode=mode)

    # We need to unflatten with num_*_samples_per_task since the preprocessor
    # might introduce new tensors or reshape existing tensors.
    features.train = meta_tfdata.unflatten_batch_examples(
        features.train, self._num_train_samples_per_task)
    features.val = meta_tfdata.unflatten_batch_examples(
        features.val, self._num_val_samples_per_task)
    features.val_mode = tf.reshape(features.val_mode, [-1, 1])
    if labels is not None:
      labels.train = meta_tfdata.unflatten_batch_examples(
          labels.train, self._num_train_samples_per_task)
      labels.val = meta_tfdata.unflatten_batch_examples(
          labels.val, self._num_val_samples_per_task)
      labels.val_mode = tf.reshape(labels.val_mode, [-1, 1])

    return features, labels
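
Unlike Example #1, this preprocessor works on per-task train/val sub-structures, the usual MAML-style arrangement where train samples drive the inner-loop adaptation and val samples produce the outer-loop loss. A minimal sketch of such a per-task split, assuming each task contributes num_train + num_val samples (all names illustrative):

import tensorflow as tf

num_tasks, num_train, num_val, dim = 4, 2, 1, 8
samples = tf.random.uniform([num_tasks, num_train + num_val, dim])

# Per task: the first num_train samples adapt the model (inner loop); the
# remaining num_val samples evaluate the adapted model (outer loop).
train = samples[:, :num_train, :]  # shape [4, 2, 8]
val = samples[:, num_train:, :]    # shape [4, 1, 8]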
Example #4
  def model_train_fn(self,
                     features,
                     labels,
                     inference_outputs,
                     mode,
                     config=None,
                     params=None):
    """The training model implementation.

    Args:
      features: This is the first item returned from the input_fn and parsed
        by tensorspec_utils.validate_and_pack. A spec_structure which fulfills
        the requirements of self.get_feature_specification.
      labels: This is the second item returned from the input_fn and parsed
        by tensorspec_utils.validate_and_pack. A spec_structure which fulfills
        the requirements of self.get_label_specification.
      inference_outputs: A dict containing the output tensors of
        model_inference_fn.
      mode: (ModeKeys) Specifies if this is training, evaluation or prediction.
      config: (Optional tf.estimator.RunConfig or contrib_tpu.RunConfig) Will
        receive what is passed to Estimator in the config parameter, or the
        default config (tf.estimator.RunConfig). Allows updating things in your
        model_fn based on configuration such as num_ps_replicas or model_dir.
      params: An optional dict of hyper parameters that will be passed into
        input_fn and model_fn. Keys are names of parameters, values are basic
        Python types. There are reserved keys for TPUEstimator, including
        'batch_size'.

    Returns:
      The loss and optionally train_outputs of the base model.
    """
    # The base model assumes data in the format [batch_size] + shape rather
    # than [num_tasks, num_samples_per_task] + shape, so we flatten the data
    # before calling any of its functions.
    condition_features_flat_batch = meta_tfdata.flatten_batch_examples(
        features.condition.features)
    condition_labels_flat_batch = meta_tfdata.flatten_batch_examples(
        features.condition.labels)
    for inner_loop_step in range(self._num_inner_loop_steps + 1):
      condition_output_flat_batch = meta_tfdata.flatten_batch_examples(
          inference_outputs['full_condition_outputs/output_{}'.format(
              inner_loop_step)])
      with tf.variable_scope('inner_loop_step_{}'.format(inner_loop_step)):
        self._base_model.add_summaries(
            features=condition_features_flat_batch,
            labels=condition_labels_flat_batch,
            inference_outputs=condition_output_flat_batch,
            train_loss=None,
            train_outputs=None,
            mode=mode,
            params=params)

    inference_output_flat_batch = meta_tfdata.flatten_batch_examples(
        inference_outputs.full_inference_output)
    inference_features_flat_batch = meta_tfdata.flatten_batch_examples(
        features.inference.features)
    labels_flat_batch = meta_tfdata.flatten_batch_examples(labels)

    with tf.variable_scope('unconditioned_inference'):
      uncondition_output_flat_batch = meta_tfdata.flatten_batch_examples(
          inference_outputs.full_inference_output_unconditioned)
      self._base_model.add_summaries(
          features=condition_features_flat_batch,
          labels=condition_labels_flat_batch,
          inference_outputs=uncondition_output_flat_batch,
          train_loss=None,
          train_outputs=None,
          mode=mode,
          params=params)

    if params is None:
      params = {}
    params['is_outer_loss'] = True
    return self._base_model.model_train_fn(
        features=inference_features_flat_batch,
        labels=labels_flat_batch,
        inference_outputs=inference_output_flat_batch,
        config=config,
        mode=mode,
        params=params)
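
The final lines mutate params so the base model can tell the outer-loop (meta) loss apart from the inner-loop losses that were only summarized above. A hedged sketch of how a base model might consume that flag; the 'is_outer_loss' key matches the snippet, but the loss and the TF1-style summary are illustrative:

import tensorflow as tf

def base_model_train_fn(labels, inference_outputs, params=None):
  params = params or {}
  # Hypothetical stand-in for the base model's training loss.
  loss = tf.reduce_mean(tf.square(labels - inference_outputs))
  if params.get('is_outer_loss', False):
    # Only the outer-loop loss drives the meta-update; inner-loop losses were
    # used for adaptation and summaries only.
    tf.summary.scalar('outer_loss', loss)
  return loss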