Example #1
    def pack_features(self, state, prev_episode_data, timestep):
        """Combines current state and conditioning data into MetaExample spec.

        See create_metaexample_spec for an example of the spec layout.

        If prev_episode_data does not contain enough episodes to fill
        num_condition_samples_per_task, we pad it with dummy episodes whose
        reward=0.5 so that no inner gradients are applied.

        Args:
          state: VRGripperObservation containing image and pose.
          prev_episode_data: A list of episode data, each of which is a list
            of tuples containing transition data. Each transition tuple takes
            the form (obs, action, rew, new_obs, done, debug).
          timestep: Current episode timestep.

        Returns:
          TensorSpecStruct containing conditioning (features, labels) and
            inference (features) keys.

        Raises:
          ValueError: If no demonstration is provided.
        """
        meta_features = tensorspec_utils.TensorSpecStruct()
        meta_features['inference/features/state/0'] = state

        def pack_condition_features(episode_data, idx, dummy_values=False):
            """Pack previous episode data into condition_ep* features/labels.

            Args:
              episode_data: List of (obs, action, rew, new_obs, done, debug)
                tuples.
              idx: Index of the conditioning episode: 0 for the demo, 1 for
                the first trial, etc.
              dummy_values: If True (no real episode is available yet), zero
                out the reward so that this episode contributes no inner-loop
                gradient.
            """
            transition = episode_data[0]
            meta_features['condition/features/state/%d' % idx] = transition[0]
            reward = np.array([transition[2]])
            reward = 2 * reward - 1
            if dummy_values:
                # success_weight of 0. = no gradients in inner loop for this batch.
                reward = np.array([0.])
            meta_features['condition/labels/target_pose/%d' %
                          idx] = transition[1]
            meta_features['condition/labels/reward/%d' % idx] = reward.astype(
                np.float32)

        if prev_episode_data:
            pack_condition_features(prev_episode_data[0], 0)
        else:
            dummy_labels = self._make_dummy_labels()
            dummy_episode = [(state, dummy_labels.target_pose,
                              dummy_labels.reward)]
            pack_condition_features(dummy_episode, 0, dummy_values=True)
        return nest.map_structure(lambda x: np.expand_dims(x, 0),
                                  meta_features)
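
For orientation, a minimal sketch of the flat key layout pack_features produces for a single conditioning episode. The key names come from the code above; the state shape, pose size, and import path are assumptions.

import numpy as np
from tensor2robot.utils import tensorspec_utils

state = np.zeros((64, 64, 3), dtype=np.uint8)  # assumed observation shape
meta_features = tensorspec_utils.TensorSpecStruct()
meta_features['condition/features/state/0'] = state
meta_features['condition/labels/target_pose/0'] = np.zeros(7, np.float32)
meta_features['condition/labels/reward/0'] = np.array([0.], np.float32)
meta_features['inference/features/state/0'] = state
# pack_features finally prepends a batch dimension of 1 to every entry via
# np.expand_dims, e.g. (64, 64, 3) -> (1, 64, 64, 3).
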
Example #2
  def test_tensor_spec_struct_assignment(self):
    test_flat_ordered_dict = utils.TensorSpecStruct()

    # We cannot assign an empty ordered dict.
    with self.assertRaises(ValueError):
      test_flat_ordered_dict.should_raise = (
          utils.TensorSpecStruct())

    # Invalid data types for assignment
    # TODO(T2R_CONTRIBUTORS): Determine which type is not supported by pytype.
    # for should_raise in ['1', 1, 1.0, {}]:
    #   with self.assertRaises(ValueError):
    #     test_flat_ordered_dict.should_raise = should_raise

    sub_data = utils.TensorSpecStruct()
    sub_data.data = np.ones(1)
    test_flat_ordered_dict.sub_data = sub_data
    # Now we can also extend.
    test_flat_ordered_dict.sub_data.additional = np.zeros(1)
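
If this behavior is unfamiliar: TensorSpecStruct stores nested structures under flat, '/'-separated keys while still supporting attribute access. A hedged sketch of what the test above implies; the exact key order is an assumption based on insertion order.

import numpy as np
from tensor2robot.utils import tensorspec_utils as utils

spec = utils.TensorSpecStruct()
sub_data = utils.TensorSpecStruct()
sub_data.data = np.ones(1)
spec.sub_data = sub_data
spec.sub_data.additional = np.zeros(1)
# Both access paths should refer to the same entries:
print(list(spec.keys()))  # expected: ['sub_data/data', 'sub_data/additional']
print(spec['sub_data/additional'])  # same value as spec.sub_data.additional
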
Example #3
 def _make_dummy_labels(self):
     """Helper function to make dummy labels for pack_labels."""
     label_spec = self._base_model.get_label_specification(
         tf.estimator.ModeKeys.TRAIN)
     reward_shape = tuple(label_spec.reward.shape)
     pose_shape = tuple(label_spec.target_pose.shape)
     dummy_reward = np.zeros(reward_shape).astype(np.float32)
     dummy_pose = np.zeros(pose_shape).astype(np.float32)
     return tensorspec_utils.TensorSpecStruct(reward=dummy_reward,
                                              target_pose=dummy_pose)
Example #4
 def get_in_feature_specification(self, mode):
     """See base class."""
     feature_spec = tensorspec_utils.TensorSpecStruct()
     feature_spec['state'] = TensorSpec(
         shape=self._model_feature_specification_fn(mode).state.shape,
         dtype=tf.uint8,
         name=self._model_feature_specification_fn(mode).state.name,
         data_format=self._model_feature_specification_fn(
             mode).state.data_format)
     return feature_spec
Example #5
  def test_tensor_spec_struct_init(self):
    flat_ordered_dict_with_attributes = utils.TensorSpecStruct(
        REFERENCE_FLAT_ORDERED_DICT)
    self.assertDictEqual(flat_ordered_dict_with_attributes,
                         REFERENCE_FLAT_ORDERED_DICT)

    # Ensure we see the right subset of the data.
    self.assertEqual(
        list(flat_ordered_dict_with_attributes.train.keys()),
        ['images', 'actions'])
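
REFERENCE_FLAT_ORDERED_DICT and the specs T1/T2/O4/O6 are fixtures defined elsewhere in the test module. A hypothetical stand-in, reconstructed from the keys the surrounding tests reference, to make the subset assertion readable:

import collections
import tensorflow as tf
from tensor2robot.utils import tensorspec_utils as utils

# Hypothetical stand-ins; the real fixtures live in the test module.
T1 = utils.ExtendedTensorSpec(shape=(64, 64, 3), dtype=tf.float32, name='images')
T2 = utils.ExtendedTensorSpec(shape=(2,), dtype=tf.float32, name='actions')
O4 = utils.ExtendedTensorSpec(shape=(64, 64, 3), dtype=tf.float32,
                              name='images', is_optional=True)
O6 = utils.ExtendedTensorSpec(shape=(2,), dtype=tf.float32,
                              name='actions', is_optional=True)
REFERENCE_FLAT_ORDERED_DICT = collections.OrderedDict([
    ('train/images', T1), ('train/actions', T2),
    ('test/images', T1), ('test/actions', T2),
    ('optional/images', O4), ('optional/actions', O6),
])
# utils.TensorSpecStruct(REFERENCE_FLAT_ORDERED_DICT).train then exposes only
# the 'train/...' entries, with keys ['images', 'actions'].
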
Example #6
 def test_copy_none_name(self):
   spec = utils.TensorSpecStruct()
   spec.none_name = utils.ExtendedTensorSpec(shape=(1,), dtype=tf.float32)
   spec.with_name = utils.ExtendedTensorSpec(
       shape=(2,), dtype=tf.float32, name='with_name')
   spec_copy = utils.copy_tensorspec(spec, prefix='test')
   # Spec equality does not check the name
   utils.assert_equal(spec, spec_copy)
   self.assertEqual(spec_copy.none_name.name, 'test/')
   self.assertEqual(spec_copy.with_name.name, 'test/with_name')
Example #7
 def _get_label_specification(self):
     spec = tensorspec_utils.TensorSpecStruct()
     spec.target = tensorspec_utils.ExtendedTensorSpec(name='target',
                                                       shape=(1, ),
                                                       dtype=tf.float32)
     spec.proxy = tensorspec_utils.ExtendedTensorSpec(name='proxy',
                                                      shape=(1, ),
                                                      dtype=tf.float32,
                                                      is_optional=True)
     return spec
Example #8
 def _episode_label_specification(self, mode):
     """Returns the label spec for a single episode."""
     del mode
     action_spec = TensorSpec(shape=(self._num_waypoints *
                                     self._action_size, ),
                              dtype=tf.float32,
                              name='action_world')
     tspec = tensorspec_utils.TensorSpecStruct(action=action_spec)
     return tensorspec_utils.copy_tensorspec(
         tspec, batch_size=self._episode_length)
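
A hedged sketch of what the copy_tensorspec(..., batch_size=...) call is expected to do, based on its use with _episode_length here and in Example #18: it copies the struct and prepends the batch size as a leading dimension of every spec. The shapes and episode length below are illustrative.

import tensorflow as tf
from tensor2robot.utils import tensorspec_utils

action_spec = tensorspec_utils.ExtendedTensorSpec(
    shape=(8,), dtype=tf.float32, name='action_world')
tspec = tensorspec_utils.TensorSpecStruct(action=action_spec)
batched = tensorspec_utils.copy_tensorspec(tspec, batch_size=40)
print(batched.action.shape)  # expected: (40, 8)
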
Example #9
 def _get_feature_specification(self):
     spec = tensorspec_utils.TensorSpecStruct()
     spec.action = tensorspec_utils.ExtendedTensorSpec(name='action',
                                                       shape=(1, ),
                                                       dtype=tf.float32)
     spec.velocity = tensorspec_utils.ExtendedTensorSpec(name='velocity',
                                                         shape=(1, ),
                                                         dtype=tf.float32,
                                                         is_optional=True)
     return spec
Example #10
  def test_compress_decompress_fn(self):
    batch_size = 5
    base_dir = 'tensor2robot'
    file_pattern = os.path.join(FLAGS.test_srcdir, base_dir,
                                'test_data/pose_env_test_data.tfrecord')
    state_spec = TSPEC(
        shape=(64, 64, 3),
        dtype=tf.uint8,
        name='state/image',
        data_format='jpeg')
    action_spec = TSPEC(shape=(2,), dtype=tf.bfloat16, name='pose')
    reward_spec = TSPEC(shape=(), dtype=tf.float32, name='reward')
    feature_spec = tensorspec_utils.TensorSpecStruct(
        state=state_spec, action=action_spec)
    label_spec = tensorspec_utils.TensorSpecStruct(reward=reward_spec)

    dataset = tfdata.parallel_read(file_patterns=file_pattern)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = tfdata.serialized_to_parsed(dataset, feature_spec, label_spec)
    features, _ = dataset.make_one_shot_iterator().get_next()
    # Check tensor shapes.
    self.assertAllEqual((batch_size,) + feature_spec.state.shape,
                        features.state.get_shape().as_list())
    with self.session() as session:
      original_features = session.run(features)

    dataset = tfdata.parallel_read(file_patterns=file_pattern)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = tfdata.serialized_to_parsed(dataset, feature_spec, label_spec)
    dataset = dataset.map(
        tfdata.create_compress_fn(feature_spec, label_spec, quality=100))
    dataset = dataset.map(tfdata.create_decompress_fn(feature_spec, label_spec))
    features, _ = dataset.make_one_shot_iterator().get_next()
    # Check tensor shapes.
    self.assertAllEqual((batch_size,) + feature_spec.state.shape,
                        features.state.get_shape().as_list())
    with self.session() as session:
      compressed_decompressed_features = session.run(features)
    ref_state = original_features.state.astype(np.float32) / 255
    state = compressed_decompressed_features.state.astype(np.float32) / 255
    np.testing.assert_almost_equal(ref_state, state, decimal=1)
Example #11
 def test_tensor_spec_struct_attribute_errors(self):
     flat_ordered_dict_with_attributes = utils.TensorSpecStruct(
         REFERENCE_FLAT_ORDERED_DICT)
     # These attributes do not exist.
     with self.assertRaises(AttributeError):
         _ = flat_ordered_dict_with_attributes['optional_typo']
     with self.assertRaises(AttributeError):
         _ = flat_ordered_dict_with_attributes.optional_typo
     self.assertDictEqual(
         flat_ordered_dict_with_attributes['optional'].to_dict(),
         flat_ordered_dict_with_attributes.optional.to_dict())
Example #12
 def _episode_feature_specification(self, mode):
     """Returns the feature spec for a single episode."""
     del mode
     full_state_pose_spec = TensorSpec(shape=(self._obs_size, ),
                                       dtype=tf.float32,
                                       name='full_state_pose')
     spec = tensorspec_utils.TensorSpecStruct(
         full_state_pose=full_state_pose_spec)
     spec = tensorspec_utils.copy_tensorspec(
         spec, batch_size=self._episode_length)
     return spec
Example #13
    def pack_state_to_feature_spec(self, state_params):
        """Packs the state feature spec from the state.

        Args:
          state_params: Instance of state_spec_class.

        Returns:
          feature_spec: An instance of self.feature_spec_class. This contains
            features for the state.
        """
        feature_spec = tensorspec_utils.TensorSpecStruct(state=state_params)
        return feature_spec
Example #14
    def get_feature_specification(self, mode):
        """Gets model inputs (including context) for inference.

        Args:
          mode: The mode for feature specifications.

        Returns:
          feature_spec: A named tuple with fields for the state.
        """
        del mode
        return tensorspec_utils.TensorSpecStruct(
            state=self.state_specification)
Example #15
 def test_init_with_attributes(self):
     train = utils.TensorSpecStruct(images=T1, actions=T2)
     flat_nested_optional_spec = utils.flatten_spec_structure(
         mock_nested_optional_spec)
     utils.assert_equal(train, flat_nested_optional_spec.train)
     alternative_dict = {'o6': O6, 'o4': O4}
     hierarchy = utils.TensorSpecStruct(
         nested_optional_spec=mock_nested_optional_spec,
         alternative=alternative_dict)
     utils.assert_equal(hierarchy.nested_optional_spec,
                        flat_nested_optional_spec)
     self.assertDictEqual(hierarchy.alternative.to_dict(), alternative_dict)
     self.assertCountEqual(list(hierarchy.alternative.keys()), ['o4', 'o6'])
     self.assertCountEqual(list(hierarchy.keys()), [
         'nested_optional_spec/train/images',
         'nested_optional_spec/train/actions',
         'nested_optional_spec/test/images',
         'nested_optional_spec/test/actions',
         'nested_optional_spec/optional/images',
         'nested_optional_spec/optional/actions', 'alternative/o6',
         'alternative/o4'
     ])
Example #16
 def get_feature_specification(
     self, mode):
   tspec = tensorspec_utils.TensorSpecStruct()
   tspec.pregrasp_image = TensorSpec(
       shape=self._scene_size + (3,), dtype=tf.float32, name='image',
       data_format='jpeg')
   tspec.postgrasp_image = TensorSpec(
       shape=self._scene_size + (3,), dtype=tf.float32, name='postgrasp_image',
       data_format='jpeg')
   tspec.goal_image = TensorSpec(
       shape=self._goal_size + (3,), dtype=tf.float32, name='present_image',
       data_format='jpeg')
   return tspec
Example #17
    def pack_state_action_to_feature_spec(self, state_params, action_params):
        """Gets a feature spec namedtuple from the state and action.

        Args:
          state_params: Instance of state_spec_class.
          action_params: Instance of action_spec_class.

        Returns:
          feature_spec: An instance of self.feature_spec_class. This contains
            features for both the action and state.
        """
        return tensorspec_utils.TensorSpecStruct(state=state_params,
                                                 action=action_params)
Example #18
 def get_feature_specification(self, mode):
   del mode
   image_spec = TensorSpec(
       shape=(100, 100, 3),
       dtype=tf.float32,
       name='image0',
       data_format='jpeg')
   gripper_pose_spec = TensorSpec(
       shape=(14,), dtype=tf.float32, name='world_pose_gripper')
   tspec = tensorspec_utils.TensorSpecStruct(
       image=image_spec, gripper_pose=gripper_pose_spec)
   return tensorspec_utils.copy_tensorspec(
       tspec, batch_size=self._episode_length)
Example #19
  def get_label_specification(self, mode):
    """See base class documentation."""
    del mode
    spec_structure = tensorspec_utils.TensorSpecStruct()
    if self._multi_dataset:
      spec_structure.y = tensorspec_utils.ExtendedTensorSpec(
          shape=(1,), dtype=tf.float32, name='valid_position',
          dataset_key='dataset1')
    else:
      spec_structure.y = tensorspec_utils.ExtendedTensorSpec(
          shape=(1,), dtype=tf.float32, name='valid_position')
    return spec_structure
Example #20
 def test_pack_flat_sequence_to_spec_structure_ensure_order(self):
   test_spec = utils.TensorSpecStruct()
   test_spec.b = utils.ExtendedTensorSpec(
       shape=(1,), dtype=tf.float32, name='b')
   test_spec.a = utils.ExtendedTensorSpec(
       shape=(1,), dtype=tf.float32, name='a')
   test_spec.c = utils.ExtendedTensorSpec(
       shape=(1,), dtype=tf.float32, name='c')
   placeholders = utils.make_placeholders(test_spec)
   packed_placeholders = utils.pack_flat_sequence_to_spec_structure(
       test_spec, placeholders)
   for pos, order_name in enumerate(['a', 'b', 'c']):
     self.assertEqual(list(packed_placeholders.keys())[pos], order_name)
     self.assertEqual(
         list(packed_placeholders.values())[pos].op.name, order_name)
Example #21
    def test_tensor_spec_struct_deleting_element(self):
        flat_ordered_dict_with_attributes = utils.TensorSpecStruct(
            REFERENCE_FLAT_ORDERED_DICT)
        # Now we show that we can delete items and it will propagate down.
        self.assertEqual(flat_ordered_dict_with_attributes.optional.images, O4)

        del flat_ordered_dict_with_attributes['optional/images']
        with self.assertRaises(AttributeError):
            _ = flat_ordered_dict_with_attributes.optional.images

        # Now we show that we can delete items and it will propagate up.
        self.assertIn('test/actions', flat_ordered_dict_with_attributes)
        test = flat_ordered_dict_with_attributes.test
        del test['actions']
        self.assertNotIn('test/actions', flat_ordered_dict_with_attributes)
Example #22
    def _filter_using_spec(self, tensor_spec_struct, output_spec):
        """Filters all optional tensors from the tensor_spec_struct.

        Args:
          tensor_spec_struct: The instance of TensorSpecStruct which contains
            the preprocessing tensors.
          output_spec: The reference TensorSpecStruct from which we infer
            which tensors should be removed.

        Returns:
          A new instance which contains only the required tensors.
        """
        filtered_spec_struct = tensorspec_utils.TensorSpecStruct()
        for key in output_spec.keys():
            filtered_spec_struct[key] = tensor_spec_struct[key]
        return filtered_spec_struct
Example #23
 def get_feature_specification(
     self, mode):
   """See base class documentation."""
   del mode
   spec_structure = tensorspec_utils.TensorSpecStruct()
   if self._multi_dataset:
     spec_structure.x1 = tensorspec_utils.ExtendedTensorSpec(
         shape=(3,), dtype=tf.float32, name='measured_position',
         dataset_key='dataset1')
     spec_structure.x2 = tensorspec_utils.ExtendedTensorSpec(
         shape=(3,), dtype=tf.float32, name='measured_position',
         dataset_key='dataset2')
   else:
     spec_structure.x = tensorspec_utils.ExtendedTensorSpec(
         shape=(3,), dtype=tf.float32, name='measured_position')
   return spec_structure
Example #24
    def test_tensor_spec_struct_adding_attribute(self):
        flat_ordered_dict_with_attributes = utils.TensorSpecStruct(
            REFERENCE_FLAT_ORDERED_DICT)
        # Now we check that we can change an attribute and it affects the parent.
        flat_ordered_dict_with_attributes.train.addition = O6
        self.assertEqual(list(flat_ordered_dict_with_attributes.train.keys()),
                         ['images', 'actions', 'addition'])

        self.assertEqual(flat_ordered_dict_with_attributes.train.addition,
                         flat_ordered_dict_with_attributes['train/addition'])

        # It propagates to the parent.
        self.assertEqual(
            list(flat_ordered_dict_with_attributes.keys()),
            list(REFERENCE_FLAT_ORDERED_DICT.keys()) + ['train/addition'])
        self.assertEqual(flat_ordered_dict_with_attributes['train/addition'],
                         flat_ordered_dict_with_attributes.train.addition)
Example #25
 def test_varlen_feature_spec(self, batch_size):
   file_pattern = os.path.join(self.create_tempdir().full_path,
                               'test.tfrecord')
   test_data = [[1], [1, 2]]
   self._write_test_varlen_examples(test_data, file_pattern)
   feature_spec = tensorspec_utils.TensorSpecStruct()
   feature_spec.varlen = tensorspec_utils.ExtendedTensorSpec(
       shape=(3,), dtype=tf.int64, name='varlen', varlen_default_value=3.0)
   dataset = tfdata.parallel_read(file_patterns=file_pattern)
   dataset = dataset.batch(batch_size, drop_remainder=True)
   dataset = tfdata.serialized_to_parsed(dataset, feature_spec, None)
   features = dataset.make_one_shot_iterator().get_next()
   # Check tensor shapes.
   self.assertAllEqual([None, 3], features.varlen.get_shape().as_list())
   with self.session() as session:
     np_features = session.run(features)
     self.assertAllEqual(np_features.varlen,
                         np.array([[1, 3, 3], [1, 2, 3]][:batch_size]))
     self.assertAllEqual([batch_size, 3], np_features.varlen.shape)
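
The pad-or-clip semantics this test relies on, restated as a plain numpy sketch. This is a conceptual re-statement of the expected behavior, not the library implementation.

import numpy as np

def pad_or_clip_1d(values, target_len, default_value):
  """Pad with default_value up to target_len; clip anything beyond it."""
  out = np.full((target_len,), default_value, dtype=np.int64)
  n = min(len(values), target_len)
  out[:n] = values[:n]
  return out

print(pad_or_clip_1d([1], 3, 3))     # -> [1 3 3]
print(pad_or_clip_1d([1, 2], 3, 3))  # -> [1 2 3]
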
Example #26
 def test_varlen_images_feature_spec(self, batch_size):
   file_pattern = os.path.join(self.create_tempdir().full_path,
                               'test.tfrecord')
   image_width = 640
   image_height = 512
   padded_varlen_size = 3
   maxval = 255  # Maximum value for byte-encoded image.
   image_np = np.random.uniform(
       size=(image_height, image_width), high=maxval).astype(np.int32)
   png_encoded_image = image.numpy_to_image_string(image_np, 'png')
   test_data = [[png_encoded_image], [png_encoded_image, png_encoded_image]]
   self._write_test_varlen_images_examples(test_data, file_pattern)
   feature_spec = tensorspec_utils.TensorSpecStruct()
   feature_spec.varlen_images = tensorspec_utils.ExtendedTensorSpec(
       shape=(padded_varlen_size, image_height, image_width, 1),
       dtype=tf.uint8,
       name='varlen_images',
       data_format='png',
       varlen_default_value=0)
   dataset = tfdata.parallel_read(file_patterns=file_pattern)
   dataset = dataset.batch(batch_size, drop_remainder=True)
   dataset = tfdata.serialized_to_parsed(dataset, feature_spec, None)
   features = dataset.make_one_shot_iterator().get_next()
   # Check tensor shapes.
   self.assertAllEqual(
       [None, padded_varlen_size, image_height, image_width, 1],
       features.varlen_images.get_shape().as_list())
   with self.session() as session:
     np_features = session.run(features)
     black_image = np.zeros((image_height, image_width))
     self.assertAllEqual(
         np_features.varlen_images,
         np.expand_dims(
             np.stack([
                 np.stack([image_np, black_image, black_image]),
                 np.stack([image_np, image_np, black_image])
             ])[:batch_size, :, :, :],
             axis=-1))
     self.assertAllEqual(
         [batch_size, padded_varlen_size, image_height, image_width, 1],
         np_features.varlen_images.shape)
Example #27
 def test_images_decoding_raises(self):
   file_pattern = os.path.join(self.create_tempdir().full_path,
                               'test.tfrecord')
   image_width = 640
   image_height = 512
    maxval = np.iinfo(np.uint32).max  # Maximum value for a uint32-encoded image.
   image_np = np.random.uniform(
       size=(image_height, image_width), high=maxval).astype(np.int32)
   png_encoded_image = image.numpy_to_image_string(image_np, 'png',
                                                   np.uint32)
   test_data = [[png_encoded_image]]
   self._write_test_images_examples(test_data, file_pattern)
   feature_spec = tensorspec_utils.TensorSpecStruct()
   feature_spec.images = tensorspec_utils.ExtendedTensorSpec(
       shape=(image_height, image_width, 1),
       dtype=tf.uint32,
       name='image/encoded',
       data_format='png')
   dataset = tfdata.parallel_read(file_patterns=file_pattern)
   dataset = dataset.batch(1, drop_remainder=True)
   with self.assertRaises(ValueError):
     tfdata.serialized_to_parsed(dataset, feature_spec, None)
Example #28
    def get_action_specification(self):
        close_gripper_spec = tensorspec_utils.ExtendedTensorSpec(
            shape=(1, ), dtype=tf.float32, name='close_gripper')
        open_gripper_spec = tensorspec_utils.ExtendedTensorSpec(
            shape=(1, ), dtype=tf.float32, name='open_gripper')
        terminate_episode_spec = tensorspec_utils.ExtendedTensorSpec(
            shape=(1, ), dtype=tf.float32, name='terminate_episode')
        gripper_closed_spec = tensorspec_utils.ExtendedTensorSpec(
            shape=(1, ), dtype=tf.float32, name='gripper_closed')
        world_vector_spec = tensorspec_utils.ExtendedTensorSpec(
            shape=(3,), dtype=tf.float32, name='world_vector')
        vertical_rotation_spec = tensorspec_utils.ExtendedTensorSpec(
            shape=(2,), dtype=tf.float32, name='vertical_rotation')
        height_to_bottom_spec = tensorspec_utils.ExtendedTensorSpec(
            shape=(1, ), dtype=tf.float32, name='height_to_bottom')

        return tensorspec_utils.TensorSpecStruct(
            world_vector=world_vector_spec,
            vertical_rotation=vertical_rotation_spec,
            close_gripper=close_gripper_spec,
            open_gripper=open_gripper_spec,
            terminate_episode=terminate_episode_spec,
            gripper_closed=gripper_closed_spec,
            height_to_bottom=height_to_bottom_spec)
Example #29
  def parse_tf_example_fn(*input_values):
    """Maps string tensors (serialized TFExamples) to parsed tensors.

    Args:
      *input_values: A (string tensor,) tuple if mapping from a RecordIODataset
        or TFRecordDataset, or a (key, string tensor) tuple if mapping from a
        SSTableDataset, or (Dict[dataset_key, values],) if mapping from multiple
        datasets.

    Returns:
      features: Collection of tensors conforming to feature_tspec.
      labels: Collection of tensors conforming to label_tspec.
    Raises:
      ValueError: If dtype other than uint8 or uint16 is supplied for image
        specs.
    """
    dict_extracted = _get_sstable_proto_dict(*input_values)

    def parse_wrapper(example, spec_dict):
      """Wrap tf.parse_example to support bfloat16 dtypes.

      This allows models which declare bfloat16 as inputs to not require an
      additional preprocessing step to cast all inputs from float32 to bfloat16.
      Consider this to be analogous to JPEG decoding in the data step.

      Args:
        example: TFExample
        spec_dict: Dictionary of feature name -> tf.FixedLenFeature

      Returns:
        Parsed feature map
      """

      def is_bfloat_feature(value):
        return value.dtype == tf.bfloat16

      def maybe_map_bfloat(value):
        """Maps bfloat16 to float32."""
        if is_bfloat_feature(value):
          if isinstance(value, tf.FixedLenFeature):
            return tf.FixedLenFeature(
                value.shape, tf.float32, default_value=value.default_value)
          elif isinstance(value, tf.VarLenFeature):
            # tf.VarLenFeature only carries a dtype.
            return tf.VarLenFeature(tf.float32)
          else:
            return tf.FixedLenSequenceFeature(
                value.shape, tf.float32, default_value=value.default_value)
        return value

      # Change bfloat features to float32 for parsing.
      new_spec_dict = {
          k: maybe_map_bfloat(v) for k, v in six.iteritems(spec_dict)
      }
      for k, v in six.iteritems(new_spec_dict):
        if v.dtype not in [tf.float32, tf.string, tf.int64]:
          raise ValueError('Feature specification with invalid data type for '
                           'tf.Example parsing: "%s": %s' % (k, v.dtype))

      # Separate new_spec_dict into context and sequence features. In the
      # event that there are no sequence features, the context_features
      # dictionary (containing FixedLenFeatures) is passed to
      # tf.parse_example.
      context_features, sequence_features = {}, {}
      for k, v in six.iteritems(new_spec_dict):
        v = maybe_map_bfloat(v)
        if isinstance(v, tf.FixedLenSequenceFeature):
          sequence_features[k] = v
        elif isinstance(v, tf.FixedLenFeature):
          context_features[k] = v
        elif isinstance(v, tf.VarLenFeature):
          context_features[k] = v
        else:
          raise ValueError(
              'Only FixedLenFeature, VarLenFeature and '
              'FixedLenSequenceFeature are currently supported.')

      # If there are any sequence features, we use parse_sequence_example.
      if sequence_features:
        # Filter out '_length' context features; don't parse them from records.
        for parse_name in sequence_features:
          # Sometimes, the '_length' context feature doesn't exist.
          if parse_name + '_length' in context_features:
            del context_features[parse_name + '_length']
        result, sequence_result, feature_lengths = tf.io.parse_sequence_example(
            example,
            context_features=context_features,
            sequence_features=sequence_features)
        result.update(sequence_result)
        # Augment the parsed tensors with feature length tensors.
        for parse_name, length_tensor in feature_lengths.items():
          result[parse_name + '_length'] = length_tensor
      else:
        result = tf.parse_example(example, context_features)
      to_convert = [
          k for k, v in six.iteritems(spec_dict) if is_bfloat_feature(v)
      ]

      for c in to_convert:
        result[c] = tf.cast(result[c], tf.bfloat16)

      return result

    prepend_keys = lambda d, pre: {pre + k: v for k, v in list(d.items())}
    # Parse each dataset's tensors. Parsed results from parse_wrapper get
    # dataset_key prepended to ensure uniqueness of keys among datasets.
    parsed_tensors = {}
    # {Prepended parsed key : TensorSpecs} for all datasets. Will contain
    # '_length' TensorSpecs that won't actually get parsed. We filter those out
    # before passing to the parse_sequence_example call.
    tensor_spec_dict = {}
    for dataset_key, example_proto in dict_extracted.items():
      # Parsed key to Feature Specs (retained only for this dataset).
      tensor_dict = {}
      sub_feature_tspec = tensorspec_utils.filter_spec_structure_by_dataset(
          feature_tspec, dataset_key)
      feature_dict, feature_tspec_dict = (
          tensorspec_utils.tensorspec_to_feature_dict(
              sub_feature_tspec, decode_images=decode_images))
      tensor_dict.update(feature_dict)
      tensor_spec_dict.update(prepend_keys(feature_tspec_dict, dataset_key))
      if label_tspec is not None:
        sub_label_tspec = tensorspec_utils.filter_spec_structure_by_dataset(
            label_tspec, dataset_key)
        label_dict, label_tspec_dict = (
            tensorspec_utils.tensorspec_to_feature_dict(
                sub_label_tspec, decode_images=decode_images))
        tensor_dict.update(label_dict)
        tensor_spec_dict.update(prepend_keys(label_tspec_dict, dataset_key))
      for key, parsed in parse_wrapper(example_proto, tensor_dict).items():
        parsed_tensors[dataset_key + key] = parsed

    # At this point, all tensors have been parsed into a single flat map.
    # Interpret encoded images.
    def decode_image(key, raw_bytes):
      """Decodes single or batches of JPEG- or PNG-encoded string tensors.

      Args:
        key: String key specified in feature map.
        raw_bytes: String tensor to decode as JPEG or PNG.

      Returns:
        Decoded image tensor with shape specified by tensor spec.
      Raises:
        ValueError: If dtype other than uint8 or uint16 is supplied for image
          specs.
      """
      img_batch_dims = tf.shape(raw_bytes)
      # The spatial + channel dimensions of a single image, assumed to be the
      # last 3 entries of the image feature's tensor spec.
      if len(tensor_spec_dict[key].shape) < 3:
        raise ValueError(
            'Shape of tensor spec for image feature "%s" must '
            'be 3 dimensional (h, w, c), but is %s' %
            (tensor_spec_dict[key].name, tensor_spec_dict[key].shape))
      single_img_dims = tensor_spec_dict[key].shape[-3:]
      num_channels = single_img_dims[2]
      if num_channels not in [1, 3]:
        raise ValueError(
            'Last dimension of shape of tensor spec for image '
            'feature "%s" must be 1 or 3, but the shape is %s' %
            (tensor_spec_dict[key].name, tensor_spec_dict[key].shape))

      # Collapse (possibly multiple) batch dims to a single batch dim for
      # decoding purposes.
      raw_bytes = tf.reshape(raw_bytes, [-1])
      data_type = tensor_spec_dict[key].dtype
      if data_type not in SUPPORTED_PIXEL_ENCODINGS:
        raise ValueError('Decoding an image requires tensorspec.dtype '
                         'to be uint8 or uint16.')

      def _decode_images(image_bytes):
        """Decode single image."""
        def _zero_image():
          return tf.zeros(single_img_dims, dtype=data_type)

        def _tf_decode_image():
          return tf.image.decode_image(
              image_bytes, channels=num_channels, dtype=data_type)

        image = tf.cond(
            tf.equal(image_bytes, ''), _zero_image, _tf_decode_image)
        image.set_shape(single_img_dims)
        return image

      img = tf.map_fn(
          _decode_images, raw_bytes, dtype=data_type, back_prop=False)
      img.set_shape(raw_bytes.shape.concatenate(single_img_dims))

      # Expand the collapsed batch dim back to the original img_batch_dims.
      img = tf.reshape(img, tf.concat([img_batch_dims, single_img_dims], 0))

      return img

    # Convert all sparse tensors to dense tensors.
    for key, val in parsed_tensors.items():
      tensor_spec = tensor_spec_dict[key]
      if tensor_spec.varlen_default_value is not None:
        if tensorspec_utils.is_encoded_image_spec(tensor_spec):
          default_value = ''
        else:
          default_value = tf.cast(
              tf.constant(tensor_spec.varlen_default_value),
              dtype=tensor_spec.dtype)
        parsed_tensors[key] = tf.sparse.to_dense(
            val, default_value=default_value)

    # Ensure that all images are properly decoded.
    for key, val in parsed_tensors.items():
      tensor_spec = tensor_spec_dict[key]
      if tensorspec_utils.is_encoded_image_spec(tensor_spec) and decode_images:
        parsed_tensors[key] = decode_image(key, val)
        if tensor_spec.dtype not in SUPPORTED_PIXEL_ENCODINGS:
          raise ValueError('Encoded images with key {} must be '
                           'specified with uint8 or uint16 dtype.'.format(key))

    # Pad all varlen features to the corresponding spec.
    for key, val in parsed_tensors.items():
      tensor_spec = tensor_spec_dict[key]
      if tensor_spec.varlen_default_value is not None:
        parsed_tensors[key] = tensorspec_utils.pad_or_clip_tensor_to_spec_shape(
            val, tensor_spec)

    # Ensure that we have a consistent ordered mapping despite the underlying
    # spec structure.
    flat_feature_tspec = tensorspec_utils.TensorSpecStruct(
        sorted(tensorspec_utils.flatten_spec_structure(feature_tspec).items()))
    # Using the flat spec structure we can map the same parsed tensor to
    # multiple features or labels. Note that the spec structure ensures the
    # corresponding tensor specs are identical in such cases.
    features = tensorspec_utils.TensorSpecStruct([
        (key, parsed_tensors[value.dataset_key + value.name])
        for key, value in flat_feature_tspec.items()
    ])

    features = tensorspec_utils.validate_and_pack(
        flat_feature_tspec, features, ignore_batch=True)
    if label_tspec is not None:
      # Ensure that we have a consistent ordered mapping despite the underlying
      # spec structure.
      flat_label_tspec = tensorspec_utils.TensorSpecStruct(
          sorted(tensorspec_utils.flatten_spec_structure(label_tspec).items()))
      labels = tensorspec_utils.TensorSpecStruct([
          (key, parsed_tensors[value.dataset_key + value.name])
          for key, value in flat_label_tspec.items()
      ])
      labels = tensorspec_utils.validate_and_pack(
          flat_label_tspec, labels, ignore_batch=True)
      return features, labels
    return features
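
This parser is what the tfdata.serialized_to_parsed calls in the tests above ultimately run. A hedged end-to-end sketch of the happy path, reusing the import aliases from those tests; the file pattern, shapes, and names here are illustrative.

feature_spec = tensorspec_utils.TensorSpecStruct(
    state=tensorspec_utils.ExtendedTensorSpec(
        shape=(64, 64, 3), dtype=tf.uint8, name='state/image',
        data_format='jpeg'))
label_spec = tensorspec_utils.TensorSpecStruct(
    reward=tensorspec_utils.ExtendedTensorSpec(
        shape=(), dtype=tf.float32, name='reward'))
dataset = tfdata.parallel_read(file_patterns='/tmp/example.tfrecord')
dataset = dataset.batch(2, drop_remainder=True)
dataset = tfdata.serialized_to_parsed(dataset, feature_spec, label_spec)
features, labels = dataset.make_one_shot_iterator().get_next()
# features.state: uint8 with shape (2, 64, 64, 3); labels.reward: shape (2,).
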
Example #30
    def inference_network_fn(self,
                             features,
                             labels,
                             mode,
                             config=None,
                             params=None):
        """The inference network implementation.

        Args:
          features: This is the first item returned from the input_fn and
            parsed by tensorspec_utils.validate_and_pack. A spec_structure
            which fulfills the requirements of self.get_feature_specification.
          labels: This is the second item returned from the input_fn and
            parsed by tensorspec_utils.validate_and_pack. A spec_structure
            which fulfills the requirements of self.get_label_specification.
          mode: (ModeKeys) Specifies if this is training, evaluation or
            prediction.
          config: (Optional tf.estimator.RunConfig or contrib_tpu.RunConfig)
            Will receive what is passed to Estimator in the config parameter,
            or the default config (tf.estimator.RunConfig). Allows updating
            things in your model_fn based on configuration such as
            num_ps_replicas or model_dir.
          params: An optional dict of hyper parameters that will be passed
            into input_fn and model_fn. Keys are names of parameters, values
            are basic python types. There are reserved keys for TPUEstimator,
            including 'batch_size'.

        Returns:
          predictions: A dict with output tensors.
        """
        del config

        maml_inner_loop_instance = (
            maml_inner_loop.MAMLInnerLoopGradientDescent())

        def task_learn(inputs_list):
            """Meta-learning for an individual task, for use with map_fn.

            Args:
              inputs_list: A list of [(condition_features, condition_labels),
                ..., (inference_features, inference_labels)] for an
                individual task.

            Returns:
              inference_output: Output of the model on the evaluation data,
                after the weight updates.
              condition_outputs: Outputs of the model on the conditioning
                data at each inner-loop step, before the corresponding
                weight update.
              inner_loss: The inner-loop losses, one per adaptation step.
            """
            # Disable the base model's summary creation in the inner loop of
            # MAML, since summaries are not supported inside while_loop.
            inner_loop_params = copy.deepcopy(params)
            if inner_loop_params is None:
                inner_loop_params = {}
            inner_loop_params['use_summaries'] = False
            inner_loop_params['maml_inner_loop'] = True

            inference_output, condition_outputs, inner_loss = (
                maml_inner_loop_instance.inner_loop(
                    inputs_list=inputs_list,
                    inference_network_fn=self._base_model.inference_network_fn,
                    model_train_fn=self._base_model.model_train_fn,
                    mode=mode,
                    params=inner_loop_params))
            return inference_output, condition_outputs, inner_loss

        # The mapping function must be fed the same structure in every case,
        # so we substitute unused_inference_labels, which, as the name
        # suggests, are not used during the inner loop.
        unused_inference_labels = labels
        if labels is None:
            unused_inference_labels = features.condition.labels
        elems = (
            (features.condition.features,
             features.condition.labels), ) * self._num_inner_loop_steps + (
                 (features.inference.features, unused_inference_labels), )

        # inference_output refers to the output we typically optimize with
        # MAML. condition_output refers to the inner-loop outputs, which
        # could also be further optimized but in standard MAML are assumed to
        # be simple gradient-descent steps of some form. This does NOT
        # currently play well with batch norm due to the use of while_loop.
        inference_output, condition_output, inner_loss = self._map_task_learn(
            task_learn, elems, mode, params)

        if self.use_summaries(params):
            maml_inner_loop_instance.add_parameter_summaries()
            for index, inner_loss_step in enumerate(inner_loss):
                tf.summary.scalar('inner_loss_{}'.format(index),
                                  tf.reduce_mean(inner_loss_step))

        # Note: these are the first-iteration outputs and losses, prior to
        # any adaptation. In total we have num_inner_loop_steps + 1 outputs,
        # since we do one extra forward pass for which we do not compute or
        # apply gradients; that extra pass lets us monitor the effect of the
        # inner loop.
        base_condition_output = condition_output[0]
        unconditioned_inference_output = inference_output[0]
        conditioned_inference_output = inference_output[1]

        predictions = utils.TensorSpecStruct()

        # We keep the full outputs such that we can simply call the
        # model_condition_fn of the base model.
        predictions.full_condition_output = (utils.TensorSpecStruct(
            list(base_condition_output.items())))

        for pos, base_condition_output in enumerate(condition_output):
            predictions['full_condition_outputs/output_{}'.format(pos)] = (
                utils.TensorSpecStruct(list(base_condition_output.items())))

        predictions.full_inference_output_unconditioned = (
            utils.TensorSpecStruct(list(
                unconditioned_inference_output.items())))
        predictions.full_inference_output = (utils.TensorSpecStruct(
            list(conditioned_inference_output.items())))
        if self.use_summaries(params):
            for key, inference in predictions.items():
                tf.summary.histogram(key, inference)
            for key in unconditioned_inference_output.keys():
                delta = (conditioned_inference_output[key] -
                         unconditioned_inference_output[key])
                tf.summary.histogram('delta/{}'.format(key), delta)

        predictions = self._select_inference_output(predictions)
        if 'condition_output' not in predictions:
            raise ValueError(
                'The required condition_output is not in predictions {}.'.
                format(list(predictions.keys())))
        if 'inference_output' not in predictions:
            raise ValueError(
                'The required inference_output is not in predictions {}.'.
                format(list(predictions.keys())))
        return predictions
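
For readers unfamiliar with MAML, a minimal numpy sketch of what the inner loop performs conceptually: one gradient-descent step per condition tuple, followed by a final forward pass on the inference tuple. This is purely illustrative and not the MAMLInnerLoopGradientDescent implementation.

import numpy as np

def inner_loop(theta, condition_tuples, inference_features,
               loss_and_grad_fn, forward_fn, lr=0.01):
  """Adapt theta on condition data, then run the final inference pass."""
  inner_losses = []
  for features, labels in condition_tuples:  # num_inner_loop_steps tuples
    loss, grad = loss_and_grad_fn(theta, features, labels)
    inner_losses.append(loss)
    theta = theta - lr * grad                # one inner gradient-descent step
  # One extra forward pass with the adapted parameters; no gradient is
  # applied, matching the num_inner_loop_steps + 1 passes described above.
  return forward_fn(theta, inference_features), inner_losses

# Toy usage: fit y = theta * x with a squared loss.
grad_fn = lambda t, x, y: (((t * x - y) ** 2).mean(),
                           (2 * (t * x - y) * x).mean())
out, losses = inner_loop(np.array(1.0), [(np.ones(4), 2 * np.ones(4))] * 5,
                         np.ones(2), grad_fn, lambda t, x: t * x, lr=0.1)
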