Exemplo n.º 1
0
def create_maml_feature_spec(feature_spec, label_spec):
    """Create a meta feature spec from existing base_model specs.

    The resulting structure has two parts:
      * .condition holds copies of both the feature spec (prefix
        'condition_features') and the label spec (prefix 'condition_labels').
      * .inference holds a copy of the feature spec only (prefix
        'inference_features').
    Every copy is made with batch_size=-1 and then flattened.

    Args:
      feature_spec: A hierarchy of TensorSpecs(subclasses) or Tensors.
      label_spec: A hierarchy of TensorSpecs(subclasses) or Tensors.

    Returns:
      An instance of TSpecStructure representing a valid meta learning
      tensor_spec with .condition (.features, .labels) and .inference
      (.features) access.
    """

    def _copy_and_flatten(spec, prefix):
        # NOTE(review): batch_size=-1 presumably marks the batch dimension as
        # unknown/variable; confirm against utils.copy_tensorspec.
        return utils.flatten_spec_structure(
            utils.copy_tensorspec(spec, batch_size=-1, prefix=prefix))

    condition_spec = TSpecStructure()
    condition_spec.features = _copy_and_flatten(feature_spec,
                                                'condition_features')
    condition_spec.labels = _copy_and_flatten(label_spec, 'condition_labels')

    # The inference part only needs features; labels are not copied here.
    inference_spec = TSpecStructure()
    inference_spec.features = _copy_and_flatten(feature_spec,
                                                'inference_features')

    meta_feature_spec = TSpecStructure()
    meta_feature_spec.condition = condition_spec
    meta_feature_spec.inference = inference_spec
    return meta_feature_spec
Exemplo n.º 2
0
def _create_meta_spec(tensor_spec, spec_type, num_train_samples_per_task,
                      num_val_samples_per_task):
    """Build a flattened TrainValPair meta spec from a single base spec.

    The base spec is copied twice. The 'train' copy uses a batch size of
    num_train_samples_per_task, and the 'val' copy uses a batch size of
    num_val_samples_per_task. Every spec in both copies is rebuilt as a
    non-optional ExtendedTensorSpec. A boolean spec named
    'val_mode/<spec_type>' is added next to them.

    Args:
      tensor_spec: A dict, (named)tuple, list or a hierarchy thereof filled by
        TensorSpecs(subclasses) or Tensors.
      spec_type: A string ['features', 'labels'] specifying which spec type we
        alter in order to introduce the corresponding val_mode.
      num_train_samples_per_task: Number of training samples to expect per task
        batch element.
      num_val_samples_per_task: Number of val examples to expect per task batch
        element.

    Raises:
      ValueError: If the spec_type is not in ['features', 'labels'].

    Returns:
      An instance of TensorSpecStruct representing a valid
      meta learning tensor_spec with .train and .val access.
    """
    if spec_type not in ['features', 'labels']:
        raise ValueError('We only support spec_type "features" or "labels" '
                         'but received {}.'.format(spec_type))

    def _non_optional_copy(batch_size, prefix):
        # The train part is needed at inference time. The val part feeds a
        # while loop, whose inputs must be identical at every step. Neither
        # copy may therefore contain optional specs.
        flat = utils.flatten_spec_structure(
            utils.copy_tensorspec(tensor_spec,
                                  batch_size=batch_size,
                                  prefix=prefix))
        for spec_key, spec_value in flat.items():
            flat[spec_key] = utils.ExtendedTensorSpec.from_spec(
                spec_value, is_optional=False)
        return flat

    train_part = _non_optional_copy(num_train_samples_per_task, 'train')
    val_part = _non_optional_copy(num_val_samples_per_task, 'val')

    # val_mode is a scalar when the train sample count is unset, else (1,).
    mode_shape = () if num_train_samples_per_task is None else (1, )
    mode_spec = TensorSpec(shape=mode_shape,
                           dtype=tf.bool,
                           name='val_mode/{}'.format(spec_type))
    return utils.flatten_spec_structure(
        TrainValPair(train=train_part, val=val_part, val_mode=mode_spec))
Exemplo n.º 3
0
 def get_label_specification(self, mode):
   """Return the label spec: a single float32 'action' entry.

   Args:
     mode: Unused; the spec is the same for every mode.

   Returns:
     A copy of a TensorSpecStruct whose 'action' spec has shape
     (self._action_size,) and name 'action_world', copied with
     batch_size=self._episode_length.
   """
   del mode  # Unused.
   label_struct = tensorspec_utils.TensorSpecStruct(
       action=TensorSpec(shape=(self._action_size,),
                         dtype=tf.float32,
                         name='action_world'))
   return tensorspec_utils.copy_tensorspec(label_struct,
                                           batch_size=self._episode_length)
Exemplo n.º 4
0
 def _episode_label_specification(self, mode):
     """Returns the label spec for a single episode.

     The single 'action' entry is a flat float32 vector holding
     self._num_waypoints * self._action_size values, copied with
     batch_size=self._episode_length.
     """
     del mode  # Unused; the spec is the same for every mode.
     flat_action_dim = self._num_waypoints * self._action_size
     label_struct = tensorspec_utils.TensorSpecStruct(
         action=TensorSpec(shape=(flat_action_dim, ),
                           dtype=tf.float32,
                           name='action_world'))
     return tensorspec_utils.copy_tensorspec(label_struct,
                                             batch_size=self._episode_length)
Exemplo n.º 5
0
 def test_copy_none_name(self):
   """Copying with a prefix names an unnamed spec exactly '<prefix>/'."""
   source = utils.TensorSpecStruct()
   source.none_name = utils.ExtendedTensorSpec(shape=(1,), dtype=tf.float32)
   source.with_name = utils.ExtendedTensorSpec(
       shape=(2,), dtype=tf.float32, name='with_name')
   copied = utils.copy_tensorspec(source, prefix='test')
   # assert_equal ignores spec names, so the prefixed copy still matches.
   utils.assert_equal(source, copied)
   self.assertEqual(copied.none_name.name, 'test/')
   self.assertEqual(copied.with_name.name, 'test/with_name')
Exemplo n.º 6
0
 def _episode_feature_specification(self, mode):
     """Returns the feature spec for a single episode.

     The struct holds one float32 'full_state_pose' spec of shape
     (self._obs_size,), copied with batch_size=self._episode_length.
     """
     del mode  # Unused; the spec is the same for every mode.
     pose = TensorSpec(shape=(self._obs_size, ),
                       dtype=tf.float32,
                       name='full_state_pose')
     return tensorspec_utils.copy_tensorspec(
         tensorspec_utils.TensorSpecStruct(full_state_pose=pose),
         batch_size=self._episode_length)
 def get_in_feature_specification(self, mode):
     """See base class.

     Starts from a copy of the model's feature spec. The 'image' entry is
     replaced by a uint8 spec whose H, W dimensions (the third- and
     second-to-last axes) are overwritten with self._src_img_res. The
     result is flattened.
     """
     feature_spec = tensorspec_utils.copy_tensorspec(
         self._model_feature_specification_fn(mode))
     model_image = feature_spec.image
     source_shape = model_image.shape.as_list()
     source_shape[-3:-1] = self._src_img_res  # Overwrite the H, W dimensions.
     feature_spec.image = TensorSpec.from_spec(model_image,
                                               shape=source_shape,
                                               dtype=tf.uint8)
     return tensorspec_utils.flatten_spec_structure(feature_spec)
Exemplo n.º 8
0
def create_maml_label_spec(label_spec):
    """Create a meta label spec from an existing base_model label spec.

    The label spec is copied with batch_size=-1 and the 'meta_labels' prefix,
    then flattened.

    Args:
      label_spec: A hierarchy of TensorSpecs(subclasses) or Tensors.

    Returns:
      An instance of TensorSpecStruct representing a valid
      meta learning tensor_spec for computing the outer loss.
    """
    # NOTE(review): batch_size=-1 presumably marks the batch dimension as
    # unknown/variable; confirm against utils.copy_tensorspec.
    return utils.flatten_spec_structure(
        utils.copy_tensorspec(label_spec, batch_size=-1, prefix='meta_labels'))
Exemplo n.º 9
0
 def get_feature_specification(self, mode):
   """Return the feature spec: an image and a gripper pose.

   The 'image' spec has shape (100, 100, 3), name 'image0' and data_format
   'jpeg'. The 'gripper_pose' spec has shape (14,) and name
   'world_pose_gripper'. Both are float32, and the struct is copied with
   batch_size=self._episode_length.

   Args:
     mode: Unused; the spec is the same for every mode.
   """
   del mode  # Unused.
   image = TensorSpec(shape=(100, 100, 3),
                      dtype=tf.float32,
                      name='image0',
                      data_format='jpeg')
   gripper_pose = TensorSpec(shape=(14,),
                             dtype=tf.float32,
                             name='world_pose_gripper')
   feature_struct = tensorspec_utils.TensorSpecStruct(
       image=image, gripper_pose=gripper_pose)
   return tensorspec_utils.copy_tensorspec(feature_struct,
                                           batch_size=self._episode_length)
Exemplo n.º 10
0
  def get_in_feature_specification(self, mode):
    """See base class.

    Starts from a copy of the model's feature spec. Outside PREDICT mode the
    'original_image' entry, if present, is removed. If an 'image' entry
    exists, it is replaced by a uint8 spec whose H, W dimensions are
    overwritten with self._src_img_res. The result is flattened.
    """
    feature_spec = tensorspec_utils.copy_tensorspec(
        self._model_feature_specification_fn(mode))

    # 'original_image' is added in preprocess_fn to satisfy the model's
    # inputs, so it must not be parsed from the input data. It is kept in
    # PREDICT mode.
    if mode != PREDICT and 'original_image' in feature_spec:
      del feature_spec['original_image']

    if 'image' not in feature_spec:
      return tensorspec_utils.flatten_spec_structure(feature_spec)

    source_shape = feature_spec.image.shape.as_list()
    source_shape[-3:-1] = self._src_img_res  # Overwrite the H, W dimensions.
    feature_spec.image = TensorSpec.from_spec(
        feature_spec.image, shape=source_shape, dtype=tf.uint8)
    return tensorspec_utils.flatten_spec_structure(feature_spec)
Exemplo n.º 11
0
 def test_copy(self, collection_type):
     """A plain copy of a spec collection compares equal to its source."""
     original = self._make_tensorspec_collection(collection_type)
     utils.assert_equal(original, utils.copy_tensorspec(original))