def create_maml_feature_spec(feature_spec, label_spec):
    """Creates a MAML meta feature spec from existing base_model specs.

    The feature and label specs are copied with batch_size=-1 (presumably an
    unspecified batch dimension -- confirm against utils.copy_tensorspec) and
    flattened into a `condition` sub-structure. The feature spec is copied
    once more into an `inference` sub-structure. Each copy gets its own name
    prefix ('condition_features', 'condition_labels', 'inference_features'),
    so the condition and inference parts map to different inputs.

    Args:
        feature_spec: A hierarchy of TensorSpecs(subclasses) or Tensors.
        label_spec: A hierarchy of TensorSpecs(subclasses) or Tensors.

    Returns:
        A TSpecStructure with .condition (holding .features and .labels) and
        .inference (holding .features) access.
    """
    def _flat_prefixed_copy(spec, prefix):
        # Copy with batch_size=-1 under the given prefix, then flatten.
        return utils.flatten_spec_structure(
            utils.copy_tensorspec(spec, batch_size=-1, prefix=prefix))

    condition_part = TSpecStructure()
    condition_part.features = _flat_prefixed_copy(
        feature_spec, 'condition_features')
    condition_part.labels = _flat_prefixed_copy(label_spec, 'condition_labels')

    inference_part = TSpecStructure()
    inference_part.features = _flat_prefixed_copy(
        feature_spec, 'inference_features')

    meta_spec = TSpecStructure()
    meta_spec.condition = condition_part
    meta_spec.inference = inference_part
    return meta_spec
def _create_meta_spec(tensor_spec, spec_type, num_train_samples_per_task,
                      num_val_samples_per_task):
    """Creates a flattened TrainValPair spec from an existing spec.

    The input spec is copied twice: once with batch size
    num_train_samples_per_task under the 'train' prefix, and once with batch
    size num_val_samples_per_task under the 'val' prefix. A prefix renames a
    spec to '<prefix>/<name>', as exercised by test_copy_none_name. All copied
    specs are forced to be non-optional. A boolean `val_mode` spec named
    'val_mode/<spec_type>' is added alongside them.

    Args:
        tensor_spec: A dict, (named)tuple, list or a hierarchy thereof filled
            by TensorSpecs(subclasses) or Tensors.
        spec_type: A string ['features', 'labels'] specifying which spec type
            we alter in order to introduce the corresponding val_mode.
        num_train_samples_per_task: Number of training samples to expect per
            task batch element.
        num_val_samples_per_task: Number of val examples to expect per task
            batch element.

    Raises:
        ValueError: If the spec_type is not in ['features', 'labels'].

    Returns:
        The flattened TrainValPair (train, val, val_mode) spec structure.
    """
    if spec_type not in ['features', 'labels']:
        raise ValueError('We only support spec_type "features" or "labels" '
                         'but received {}.'.format(spec_type))

    def _required_copy(batch_size, prefix):
        # Flattened, prefixed copy of tensor_spec with every entry made
        # non-optional.
        copied = utils.flatten_spec_structure(
            utils.copy_tensorspec(
                tensor_spec, batch_size=batch_size, prefix=prefix))
        for spec_key, spec_value in copied.items():
            copied[spec_key] = utils.ExtendedTensorSpec.from_spec(
                spec_value, is_optional=False)
        return copied

    # The train part is also required for inference, so it cannot be optional.
    train_part = _required_copy(num_train_samples_per_task, 'train')
    # The val part cannot be optional either: the inputs to a while loop have
    # to be the same for every step of the loop.
    val_part = _required_copy(num_val_samples_per_task, 'val')

    # val_mode is a scalar when no train sample count is given, else shape (1,).
    val_mode_shape = () if num_train_samples_per_task is None else (1, )
    val_mode = TensorSpec(shape=val_mode_shape,
                          dtype=tf.bool,
                          name='val_mode/{}'.format(spec_type))
    return utils.flatten_spec_structure(
        TrainValPair(train=train_part, val=val_part, val_mode=val_mode))
def get_label_specification(self, mode):
    """Returns the 'action_world' label spec batched by the episode length.

    Args:
        mode: Unused.
    """
    del mode
    action = TensorSpec(
        shape=(self._action_size,), dtype=tf.float32, name='action_world')
    labels = tensorspec_utils.TensorSpecStruct(action=action)
    return tensorspec_utils.copy_tensorspec(
        labels, batch_size=self._episode_length)
def _episode_label_specification(self, mode):
    """Returns the label spec for a single episode.

    The single 'action_world' entry has size
    self._num_waypoints * self._action_size and is batched by
    self._episode_length.

    Args:
        mode: Unused.
    """
    del mode
    flat_action_size = self._num_waypoints * self._action_size
    action = TensorSpec(shape=(flat_action_size, ),
                        dtype=tf.float32,
                        name='action_world')
    labels = tensorspec_utils.TensorSpecStruct(action=action)
    return tensorspec_utils.copy_tensorspec(
        labels, batch_size=self._episode_length)
def test_copy_none_name(self):
    """Prefixed copies rename specs, including specs that had no name."""
    original = utils.TensorSpecStruct()
    original.none_name = utils.ExtendedTensorSpec(shape=(1,), dtype=tf.float32)
    original.with_name = utils.ExtendedTensorSpec(
        shape=(2,), dtype=tf.float32, name='with_name')
    copied = utils.copy_tensorspec(original, prefix='test')
    # Spec equality does not check the name, so names are verified separately.
    utils.assert_equal(original, copied)
    expected_names = (
        ('none_name', 'test/'),
        ('with_name', 'test/with_name'),
    )
    for attribute, expected_name in expected_names:
        self.assertEqual(getattr(copied, attribute).name, expected_name)
def _episode_feature_specification(self, mode):
    """Returns the feature spec for a single episode.

    It holds a single 'full_state_pose' entry of size self._obs_size, batched
    by self._episode_length.

    Args:
        mode: Unused.
    """
    del mode
    pose = TensorSpec(shape=(self._obs_size, ),
                      dtype=tf.float32,
                      name='full_state_pose')
    features = tensorspec_utils.TensorSpecStruct(full_state_pose=pose)
    return tensorspec_utils.copy_tensorspec(
        features, batch_size=self._episode_length)
def get_in_feature_specification(self, mode):
    """See base class.

    Copies the model's feature spec and replaces its `image` entry with a
    uint8 spec. In that spec, axes -3 and -2 (H, W) are overwritten with
    self._src_img_res. The result is returned flattened.
    """
    spec = tensorspec_utils.copy_tensorspec(
        self._model_feature_specification_fn(mode))
    image_shape = spec.image.shape.as_list()
    # Overwrite the H, W dimensions.
    image_shape[-3:-1] = self._src_img_res
    spec.image = TensorSpec.from_spec(
        spec.image, shape=image_shape, dtype=tf.uint8)
    return tensorspec_utils.flatten_spec_structure(spec)
def create_maml_label_spec(label_spec):
    """Creates the meta label spec from an existing base_model label spec.

    The label spec is copied with batch_size=-1 under the 'meta_labels'
    prefix and then flattened.

    Args:
        label_spec: A hierarchy of TensorSpecs(subclasses) or Tensors.

    Returns:
        An instance of TensorSpecStruct representing a valid meta learning
        tensor_spec for computing the outer loss.
    """
    meta_labels = utils.copy_tensorspec(
        label_spec, batch_size=-1, prefix='meta_labels')
    return utils.flatten_spec_structure(meta_labels)
def get_feature_specification(self, mode):
    """Returns image and gripper-pose feature specs batched by episode length.

    Args:
        mode: Unused.
    """
    del mode
    image = TensorSpec(shape=(100, 100, 3),
                       dtype=tf.float32,
                       name='image0',
                       data_format='jpeg')
    gripper_pose = TensorSpec(
        shape=(14,), dtype=tf.float32, name='world_pose_gripper')
    features = tensorspec_utils.TensorSpecStruct(
        image=image, gripper_pose=gripper_pose)
    return tensorspec_utils.copy_tensorspec(
        features, batch_size=self._episode_length)
def get_in_feature_specification(self, mode):
    """See base class.

    Copies the model's feature spec. Outside PREDICT mode, any
    'original_image' entry is removed. If an 'image' entry exists, it is
    replaced by a uint8 spec whose axes -3 and -2 (H, W) are set to
    self._src_img_res. The result is returned flattened.
    """
    spec = tensorspec_utils.copy_tensorspec(
        self._model_feature_specification_fn(mode))
    # original_image is not parsed from the data: preprocess_fn adds it to
    # satisfy the model's inputs.
    if mode != PREDICT and 'original_image' in spec:
        del spec['original_image']
    if 'image' not in spec:
        return tensorspec_utils.flatten_spec_structure(spec)
    image_shape = spec.image.shape.as_list()
    # Overwrite the H, W dimensions.
    image_shape[-3:-1] = self._src_img_res
    spec.image = TensorSpec.from_spec(
        spec.image, shape=image_shape, dtype=tf.uint8)
    return tensorspec_utils.flatten_spec_structure(spec)
def test_copy(self, collection_type):
    """An unprefixed copy of a spec collection equals the original."""
    original = self._make_tensorspec_collection(collection_type)
    utils.assert_equal(original, utils.copy_tensorspec(original))