def test_from_spec(self):
  """Checks from_spec copies a spec and honors name/batch_size overrides."""
  base = utils.ExtendedTensorSpec((1, 2), tf.int32)
  copied = utils.ExtendedTensorSpec.from_spec(base)
  self.assertEqual(base, copied)

  # A name passed to from_spec must replace the original name.
  base = utils.ExtendedTensorSpec((1, 2), tf.int32, name='spec_1')
  copied = utils.ExtendedTensorSpec.from_spec(base, name='spec_2')
  # Like plain TensorSpec, equality ignores the name, so the two specs still
  # compare equal even though their names differ.
  self.assertEqual(base, copied)
  self.assertEqual(base.name, 'spec_1')
  self.assertEqual(copied.name, 'spec_2')

  # A positive batch_size prepends a fixed leading dimension.
  batched = utils.ExtendedTensorSpec.from_spec(base, batch_size=16)
  self.assertNotEqual(base, batched)
  self.assertEqual(base.shape, batched.shape[1:])
  self.assertEqual(batched.shape[0].value, 16)

  # batch_size=-1 prepends a leading dimension of unknown size.
  batched = utils.ExtendedTensorSpec.from_spec(base, batch_size=-1)
  self.assertEqual(batched.shape[1:], base.shape)
  self.assertIsNone(batched.shape[0].value)

  # The is_sequence flag must survive batching.
  base = utils.ExtendedTensorSpec((1, 2), tf.int32, is_sequence=True)
  batched = utils.ExtendedTensorSpec.from_spec(base, batch_size=-1)
  self.assertEqual(batched.shape[1:], base.shape)
  self.assertTrue(batched.is_sequence)
def _get_feature_specification(self):
  """Returns a spec with a required 'action' and an optional 'velocity'."""
  action_spec = tensorspec_utils.ExtendedTensorSpec(
      name='action', shape=(1,), dtype=tf.float32)
  velocity_spec = tensorspec_utils.ExtendedTensorSpec(
      name='velocity', shape=(1,), dtype=tf.float32, is_optional=True)
  spec = tensorspec_utils.TensorSpecStruct()
  spec.action = action_spec
  spec.velocity = velocity_spec
  return spec
def _get_label_specification(self):
  """Returns a spec with a required 'target' and an optional 'proxy'."""
  target_spec = tensorspec_utils.ExtendedTensorSpec(
      name='target', shape=(1,), dtype=tf.float32)
  proxy_spec = tensorspec_utils.ExtendedTensorSpec(
      name='proxy', shape=(1,), dtype=tf.float32, is_optional=True)
  spec = tensorspec_utils.TensorSpecStruct()
  spec.target = target_spec
  spec.proxy = proxy_spec
  return spec
def test_shape_compatibility(self):
  """Checks is_compatible_with for specs of varying rank and definiteness."""
  unknown = tf.placeholder(tf.int64)
  partial = tf.placeholder(tf.int64, shape=[None, 1])
  full = tf.placeholder(tf.int64, shape=[2, 3])
  rank3 = tf.placeholder(tf.int64, shape=[4, 5, 6])
  tensors = (unknown, partial, full, rank3)

  # Each row holds a spec shape and its expected compatibility with
  # (unknown, partial, full, rank3), in that order.
  expectations = (
      (None, (True, True, True, True)),
      ([2, None], (True, True, True, False)),
      ([2, 3], (True, False, True, False)),
      ([4, 5, 6], (True, False, False, True)),
  )
  for spec_shape, expected_row in expectations:
    desc = utils.ExtendedTensorSpec(spec_shape, tf.int64)
    for tensor, is_compatible in zip(tensors, expected_row):
      if is_compatible:
        self.assertTrue(desc.is_compatible_with(tensor))
      else:
        self.assertFalse(desc.is_compatible_with(tensor))
def test_stack_intratask_episodes(self):
  """Checks stack_intra_task_episodes output shapes for image and action."""
  batch_size = 2
  num_samples_in_task = 3

  feature_spec = TSpec()
  feature_spec.image = utils.ExtendedTensorSpec(
      shape=_DEFAULT_IN_IMAGE_SHAPE,
      dtype=tf.uint8,
      is_optional=False,
      data_format='jpeg',
      name='state/image')
  feature_spec.action = utils.ExtendedTensorSpec(
      shape=_DEFAULT_ACTION_SHAPE,
      dtype=tf.float32,
      is_optional=False,
      name='state/action')

  metaexample_spec = preprocessors.create_metaexample_spec(
      feature_spec, num_samples_in_task, 'condition')
  random_tensors = utils.make_random_numpy(metaexample_spec, batch_size)
  stacked = preprocessors.stack_intra_task_episodes(
      random_tensors, num_samples_in_task)

  # Every stacked feature is expected to gain a
  # (batch_size, num_samples_in_task) prefix.
  leading_dims = (batch_size, num_samples_in_task)
  self.assertEqual(stacked.image.shape,
                   leading_dims + _DEFAULT_IN_IMAGE_SHAPE)
  self.assertEqual(stacked.action.shape,
                   leading_dims + _DEFAULT_ACTION_SHAPE)
def _test_multi_record_input_generator(self, input_generator,
                                       is_dataset=False):
  """Checks that an input generator can read features/labels from 2 datasets.

  All features and the 'reward' label come from dataset 'd1'; 'reward_2'
  reads a record with the same name ('reward') from dataset 'd2'. The
  assertions below expect a batch size of 2.

  Args:
    input_generator: The input generator under test. It receives the feature
      and label specs defined here.
    is_dataset: Not read by this helper; kept for caller compatibility.
  """
  del is_dataset  # Unused.
  feature_spec = tensorspec_utils.TensorSpecStruct()
  feature_spec.state = tensorspec_utils.ExtendedTensorSpec(
      shape=(64, 64, 3), dtype=tf.uint8, name='state/image',
      data_format='jpeg', dataset_key='d1')
  # `(2,)`, not `(2)`: the latter is just the int 2, not a tuple.
  feature_spec.action = tensorspec_utils.ExtendedTensorSpec(
      shape=(2,), dtype=tf.float32, name='pose', dataset_key='d1')

  label_spec = tensorspec_utils.TensorSpecStruct()
  label_spec.reward = tensorspec_utils.ExtendedTensorSpec(
      shape=(), dtype=tf.float32, name='reward', dataset_key='d1')
  # Same record name as `reward`, but read from the second dataset.
  label_spec.reward_2 = tensorspec_utils.ExtendedTensorSpec(
      shape=(), dtype=tf.float32, name='reward', dataset_key='d2')

  input_generator.set_feature_specifications(feature_spec, feature_spec)
  input_generator.set_label_specifications(label_spec, label_spec)
  # These are graph tensors from the iterator, not numpy arrays.
  features, labels = input_generator.create_dataset_input_fn(
      mode=tf.estimator.ModeKeys.TRAIN)().make_one_shot_iterator(
      ).get_next()
  features = tensorspec_utils.validate_and_pack(
      feature_spec, features, ignore_batch=True)
  labels = tensorspec_utils.validate_and_pack(
      label_spec, labels, ignore_batch=True)

  self.assertAllEqual([2, 64, 64, 3], features.state.shape)
  self.assertAllEqual([2, 2], features.action.shape)
  self.assertAllEqual((2,), labels.reward.shape)
  self.assertAllEqual((2,), labels.reward_2.shape)
def test_create_metaexample_spec(self):
  """Checks the keys and names produced by create_metaexample_spec."""
  num_samples_in_task = 3
  feature_spec = TSpec()
  feature_spec.image = utils.ExtendedTensorSpec(
      shape=_DEFAULT_IN_IMAGE_SHAPE,
      dtype=tf.uint8,
      is_optional=False,
      data_format='jpeg',
      name='state/image')
  feature_spec.action = utils.ExtendedTensorSpec(
      shape=_DEFAULT_ACTION_SHAPE,
      dtype=tf.float32,
      is_optional=False,
      name='state/action')

  metaexample_spec = preprocessors.create_metaexample_spec(
      feature_spec, num_samples_in_task, 'condition')
  flat_feature_spec = utils.flatten_spec_structure(feature_spec)
  meta_keys = list(metaexample_spec.keys())

  # Expect one meta key per (flat feature key, sample index) pair.
  self.assertLen(
      meta_keys, num_samples_in_task * len(list(flat_feature_spec.keys())))
  for key in flat_feature_spec:
    key_prefix = six.ensure_str(key)
    for sample_idx in range(num_samples_in_task):
      meta_key = key_prefix + '/{:d}'.format(sample_idx)
      self.assertIn(meta_key, meta_keys)
      meta_name = six.ensure_str(metaexample_spec[meta_key].name)
      self.assertTrue(meta_name.startswith('condition_ep'))
def test_sequence_parsing(self, batch_size):
  """Parses sequence examples and checks static and runtime shapes/values.

  Args:
    batch_size: Number of serialized examples batched before parsing.
  """
  file_pattern = os.path.join(FLAGS.test_tmpdir, 'test.tfrecord')
  sequence_length = 3
  # Reuse the file if an earlier run already wrote it.
  if not os.path.exists(file_pattern):
    self._write_test_sequence_examples(sequence_length, file_pattern)
  dataset = tfdata.parallel_read(file_patterns=file_pattern)

  # Features.
  state_spec_1 = tensorspec_utils.ExtendedTensorSpec(
      shape=TEST_IMAGE_SHAPE, dtype=tf.uint8, is_sequence=True,
      name='image_sequence_feature', data_format='JPEG')
  # `(2,)`, not `(2)`: the latter is just the int 2, not a tuple.
  state_spec_2 = tensorspec_utils.ExtendedTensorSpec(
      shape=(2,), dtype=tf.float32, is_sequence=True,
      name='sequence_feature')
  feature_tspec = PoseEnvFeature(state=state_spec_1, action=state_spec_2)
  feature_tspec = tensorspec_utils.add_sequence_length_specs(feature_tspec)

  # Labels.
  reward_spec = tensorspec_utils.ExtendedTensorSpec(
      shape=(), dtype=tf.int64, is_sequence=False, name='context_feature')
  label_tspec = PoseEnvLabel(reward=reward_spec)
  label_tspec = tensorspec_utils.add_sequence_length_specs(label_tspec)

  dataset = dataset.batch(batch_size, drop_remainder=True)
  dataset = tfdata.serialized_to_parsed(dataset, feature_tspec, label_tspec)
  features, labels = dataset.make_one_shot_iterator().get_next()

  # Check static tensor shapes; the sequence dimension is unknown (None).
  self.assertAllEqual([batch_size, None] + TEST_IMAGE_SHAPE,
                      features.state.shape.as_list())
  self.assertAllEqual([batch_size, None, 2],
                      features.action.shape.as_list())
  self.assertAllEqual([batch_size],
                      features.state_length.shape.as_list())
  self.assertAllEqual([batch_size],
                      features.action_length.shape.as_list())
  self.assertAllEqual([batch_size], labels.reward.shape.as_list())

  with self.session() as session:
    features_, labels_ = session.run([features, labels])

  # Step i of the written sequence is expected to hold TEST_IMAGE * i; check
  # the first batch element for every step (not a hard-coded 3).
  for step in range(sequence_length):
    self.assertAllEqual(TEST_IMAGE * step, features_.state[0, step])

  # Check that the runtime numpy shapes are fully defined.
  self.assertAllEqual([batch_size, sequence_length] + TEST_IMAGE_SHAPE,
                      features_.state.shape)
  self.assertAllEqual([sequence_length] * batch_size,
                      features_.state_length)
  self.assertAllEqual([batch_size, sequence_length, 2],
                      features_.action.shape)
  self.assertAllEqual([batch_size], labels_.reward.shape)
def test_copy_none_name(self):
  """Checks copy_tensorspec prefixes names, including for unnamed specs."""
  spec = utils.TensorSpecStruct()
  spec.none_name = utils.ExtendedTensorSpec(shape=(1,), dtype=tf.float32)
  spec.with_name = utils.ExtendedTensorSpec(
      shape=(2,), dtype=tf.float32, name='with_name')

  copied = utils.copy_tensorspec(spec, prefix='test')

  # Spec equality does not check the name, so the renamed copy still matches.
  utils.assert_equal(spec, copied)
  # A spec without a name ends up with just the prefix and a slash.
  self.assertEqual(copied.none_name.name, 'test/')
  self.assertEqual(copied.with_name.name, 'test/with_name')
def get_label_specification(self, mode):
  """See base class documentation."""
  del mode
  spec_structure = tensorspec_utils.TensorSpecStruct()
  # Only the multi-dataset configuration tags the label with a dataset key.
  extra_kwargs = {'dataset_key': 'dataset1'} if self._multi_dataset else {}
  spec_structure.y = tensorspec_utils.ExtendedTensorSpec(
      shape=(1,), dtype=tf.float32, name='valid_position', **extra_kwargs)
  return spec_structure
def test_pack_flat_sequence_to_spec_structure_ensure_order(self):
  """Checks packed output is ordered a, b, c although the spec is b, a, c."""
  test_spec = utils.TensorSpecStruct()
  for spec_name in ('b', 'a', 'c'):
    setattr(test_spec, spec_name,
            utils.ExtendedTensorSpec(
                shape=(1,), dtype=tf.float32, name=spec_name))

  placeholders = utils.make_placeholders(test_spec)
  packed = utils.pack_flat_sequence_to_spec_structure(
      test_spec, placeholders)

  packed_keys = list(packed.keys())
  packed_values = list(packed.values())
  for pos, expected_name in enumerate(['a', 'b', 'c']):
    self.assertEqual(packed_keys[pos], expected_name)
    self.assertEqual(packed_values[pos].op.name, expected_name)
def get_in_feature_specification(self, mode):
  """Returns the input feature spec: a JPEG image and an action."""
  del mode
  image_spec = utils.ExtendedTensorSpec(
      shape=_DEFAULT_IN_IMAGE_SHAPE,
      dtype=tf.uint8,
      is_optional=False,
      data_format='jpeg',
      name='state/image')
  action_spec = utils.ExtendedTensorSpec(
      shape=_DEFAULT_ACTION_SHAPE,
      dtype=tf.float32,
      is_optional=False,
      name='state/action')
  feature_spec = TSpec()
  feature_spec.image = image_spec
  feature_spec.action = action_spec
  return feature_spec
def test_pad_or_clip_tensor_to_spec_shape(self, input_data, expected_output):
  """Checks densified varlen data is padded or clipped to the spec shape.

  Args:
    input_data: Examples written to a temporary TFRecord and parsed back as
      an int64 VarLenFeature under the 'varlen' key.
    expected_output: Expected dense tensor after padding/clipping.
  """
  varlen_spec = utils.ExtendedTensorSpec(
      shape=(3,), dtype=tf.int64, name='varlen', varlen_default_value=3.0)
  record_path = os.path.join(
      self.create_tempdir().full_path, 'size_two.tfrecord')
  self._write_test_examples(input_data, record_path)

  def parse_fn(example):
    return tf.parse_example(example, {'varlen': tf.VarLenFeature(tf.int64)})

  dataset = tf.data.TFRecordDataset(filenames=tf.constant([record_path]))
  dataset = dataset.batch(len(input_data), drop_remainder=True)
  dataset = dataset.map(parse_fn)
  sparse_tensors = dataset.make_one_shot_iterator().get_next()['varlen']

  # varlen_default_value was given as 3.0; cast it to the spec's dtype so it
  # can fill the int64 dense tensor.
  fill_value = tf.cast(
      tf.constant(varlen_spec.varlen_default_value), dtype=varlen_spec.dtype)
  dense = tf.sparse.to_dense(sparse_tensors, fill_value)
  tensor = utils.pad_or_clip_tensor_to_spec_shape(dense, varlen_spec)

  with self.session() as sess:
    np_tensor = sess.run(tensor)
  self.assertAllEqual(np_tensor, np.array(expected_output))
def test_varlen_default_value_raise(self):
  """Checks that a varlen default on a rank-2 shape raises ValueError."""
  # Only rank 1 tensors are supported for varlen, so (3, 2) must be rejected.
  with self.assertRaises(ValueError):
    utils.ExtendedTensorSpec(
        shape=(3, 2),
        dtype=tf.int64,
        name='varlen',
        varlen_default_value=3.0)
def get_out_feature_specification(self, mode):
  """Returns the output feature spec.

  'image' and 'action' are required. 'original_image' is optional and,
  unlike the others, is created without an explicit name.
  """
  del mode
  image_spec = utils.ExtendedTensorSpec(
      shape=_DEFAULT_OUT_IMAGE_SHAPE,
      dtype=tf.float32,
      is_optional=False,
      name='state/image')
  original_image_spec = utils.ExtendedTensorSpec(
      shape=_DEFAULT_IN_IMAGE_SHAPE, dtype=tf.float32, is_optional=True)
  action_spec = utils.ExtendedTensorSpec(
      shape=_DEFAULT_ACTION_SHAPE,
      dtype=tf.float32,
      is_optional=False,
      name='state/action')
  feature_spec = TSpec()
  feature_spec.image = image_spec
  feature_spec.original_image = original_image_spec
  feature_spec.action = action_spec
  return feature_spec
def test_images_decoding(self, np_data_type, tf_data_type):
  """Writes one single-channel PNG, parses it, and checks shape and dtype.

  Args:
    np_data_type: numpy dtype passed to the PNG encoder and expected on the
      parsed result.
    tf_data_type: TF dtype declared in the feature spec for the image.
  """
  file_pattern = os.path.join(self.create_tempdir().full_path,
                              'test.tfrecord')
  image_width = 640
  image_height = 512
  maxval = np.iinfo(np_data_type).max  # Largest value np_data_type can hold.
  # NOTE(review): uniform samples above the int32 range (possible when
  # np_data_type is np.uint32) are not representable after astype(np.int32);
  # confirm this is intended.
  image_np = np.random.uniform(
      size=(image_height, image_width), high=maxval).astype(np.int32)
  png_encoded_image = image.numpy_to_image_string(image_np, 'png',
                                                  np_data_type)
  test_data = [[png_encoded_image]]
  self._write_test_images_examples(test_data, file_pattern)
  feature_spec = tensorspec_utils.TensorSpecStruct()
  feature_spec.images = tensorspec_utils.ExtendedTensorSpec(
      shape=(image_height, image_width, 1), dtype=tf_data_type,
      name='image/encoded', data_format='png')
  dataset = tfdata.parallel_read(file_patterns=file_pattern)
  dataset = dataset.batch(1, drop_remainder=True)
  if np_data_type == np.uint32:
    # For uint32 the parse step is expected to raise.
    with self.assertRaises(tf.errors.InvalidArgumentError):
      dataset = tfdata.serialized_to_parsed(dataset, feature_spec, None)
  else:
    dataset = tfdata.serialized_to_parsed(dataset, feature_spec, None)
  # NOTE(review): confirm whether the lines below belong inside the `else`
  # branch. In the uint32 branch, `dataset` has not been parsed when they run.
  features = dataset.make_one_shot_iterator().get_next()
  # Check tensor shapes.
  self.assertAllEqual(
      [1, image_height, image_width, 1],
      features.images.get_shape().as_list())
  with self.session() as session:
    np_features = session.run(features)
    self.assertEqual(np_features['images'].dtype, np_data_type)
def get_feature_specification(self, mode):
  """See base class documentation."""
  del mode
  measured_position = tensorspec_utils.ExtendedTensorSpec(
      shape=(3,), dtype=tf.float32, name='measured_position')
  spec_structure = tensorspec_utils.TensorSpecStruct()
  spec_structure.x = measured_position
  return spec_structure
def test_varlen_images_feature_spec_raises(self, batch_size):
  """Checks parsing fails at run time when a varlen image has the wrong size.

  Args:
    batch_size: Number of examples batched before parsing.
  """
  file_pattern = os.path.join(self.create_tempdir().full_path,
                              'test.tfrecord')
  image_height, image_width = 512, 640
  padded_varlen_size = 3
  maxval = 255  # Maximum value for byte-encoded image.

  valid_pixels = np.random.uniform(
      size=(image_height, image_width), high=maxval).astype(np.int32)
  # Twice the spec's height and width, so it cannot match the spec shape.
  oversized_pixels = np.ones((1024, 1280)) * 255
  valid_png = image.numpy_to_image_string(valid_pixels, 'png')
  oversized_png = image.numpy_to_image_string(oversized_pixels, 'png')
  # Every example contains at least one oversized image.
  test_data = [[oversized_png], [valid_png, oversized_png]]
  self._write_test_varlen_images_examples(test_data, file_pattern)

  feature_spec = tensorspec_utils.TensorSpecStruct()
  feature_spec.varlen_images = tensorspec_utils.ExtendedTensorSpec(
      shape=(padded_varlen_size, image_height, image_width, 1),
      dtype=tf.uint8,
      name='varlen_images',
      data_format='png',
      varlen_default_value=0)

  dataset = tfdata.parallel_read(file_patterns=file_pattern)
  dataset = dataset.batch(batch_size, drop_remainder=True)
  dataset = tfdata.serialized_to_parsed(dataset, feature_spec, None)
  features = dataset.make_one_shot_iterator().get_next()

  # The static shape is fine; the failure only surfaces when evaluated.
  self.assertAllEqual(
      [None, padded_varlen_size, image_height, image_width, 1],
      features.varlen_images.get_shape().as_list())
  with self.session() as session:
    with self.assertRaises(tf.errors.InvalidArgumentError):
      session.run(features)
def test_pad_sparse_tensor_to_spec_shape_raises(self):
  """Checks padding fails when an example is longer than the spec allows."""
  varlen_spec = utils.ExtendedTensorSpec(
      shape=(3,), dtype=tf.int64, name='varlen', varlen_default_value=3.0)
  # The spec's maximum length is 3, but this example has length 4.
  test_data = [[1, 2, 3, 4]]
  record_path = os.path.join(
      self.create_tempdir().full_path, 'size_two.tfrecord')
  self._write_test_examples(test_data, record_path)

  def parse_fn(example):
    return tf.parse_example(example, {'varlen': tf.VarLenFeature(tf.int64)})

  dataset = (
      tf.data.TFRecordDataset(filenames=tf.constant([record_path]))
      .batch(len(test_data), drop_remainder=True)
      .map(parse_fn))
  sparse_tensors = dataset.make_one_shot_iterator().get_next()['varlen']
  padded = utils.pad_sparse_tensor_to_spec_shape(sparse_tensors, varlen_spec)

  with self.session() as sess:
    with self.assertRaises(tf.errors.InvalidArgumentError):
      sess.run(padded)
def get_feature_specification(self, mode):
  """See base class documentation."""
  del mode
  spec_structure = tensorspec_utils.TensorSpecStruct()

  def _position_spec(**kwargs):
    # All position features share shape, dtype and record name.
    return tensorspec_utils.ExtendedTensorSpec(
        shape=(3,), dtype=tf.float32, name='measured_position', **kwargs)

  if not self._multi_dataset:
    spec_structure.x = _position_spec()
    return spec_structure
  # Multi-dataset: the same record is read from two different datasets.
  spec_structure.x1 = _position_spec(dataset_key='dataset1')
  spec_structure.x2 = _position_spec(dataset_key='dataset2')
  return spec_structure
def test_repr(self):
  """Checks repr output for a fully specified and a mostly default spec."""
  named = utils.ExtendedTensorSpec(
      [1], tf.float32, name='beep', is_optional=True, data_format='jpeg')
  expected_named = (
      "ExtendedTensorSpec(shape=(1,), dtype=tf.float32, name='beep', "
      "is_optional=True, is_sequence=False, is_extracted=False, "
      "data_format='jpeg', dataset_key='')")
  self.assertEqual(repr(named), expected_named)

  # Unknown dimensions print as '?'.
  unnamed = utils.ExtendedTensorSpec([1, None], tf.int32, is_sequence=True)
  expected_unnamed = (
      "ExtendedTensorSpec(shape=(1, ?), dtype=tf.int32, name=None, "
      "is_optional=False, is_sequence=True, is_extracted=False, "
      "data_format=None, dataset_key='')")
  self.assertEqual(repr(unnamed), expected_unnamed)
def test_varlen_default_value_raise(self):
  """Checks that invalid ranks for varlen specs raise ValueError."""
  with self.assertRaises(ValueError):
    # Without images, varlen only supports rank 1 tensors; this is rank 4.
    utils.ExtendedTensorSpec(shape=(3, 2, 4, 1), dtype=tf.int64,
                             name='varlen', varlen_default_value=3.0)
  with self.assertRaises(ValueError):
    # With images (data_format='png'), varlen only supports rank 4 tensors;
    # this is rank 1. `(3,)` is the 1-tuple; `(3)` would be the int 3.
    utils.ExtendedTensorSpec(shape=(3,), dtype=tf.int64, name='varlen',
                             varlen_default_value=3.0, data_format='png')
def test_is_optional(self, is_optional):
  """Checks is_optional is stored, copied by from_spec, and overridable."""
  original = utils.ExtendedTensorSpec(
      shape=[1], dtype=np.float32, is_optional=is_optional)
  self.assertEqual(original.is_optional, is_optional)

  copied = utils.ExtendedTensorSpec.from_spec(original)
  self.assertEqual(copied.is_optional, is_optional)

  flipped = not is_optional
  overridden = utils.ExtendedTensorSpec.from_spec(
      original, is_optional=flipped)
  self.assertEqual(overridden.is_optional, flipped)
def test_data_format(self, data_format):
  """Checks data_format is stored, copied by from_spec, and overridable."""
  original = utils.ExtendedTensorSpec(
      shape=[1], dtype=np.float32, data_format=data_format)
  self.assertEqual(original.data_format, data_format)

  copied = utils.ExtendedTensorSpec.from_spec(original)
  self.assertEqual(copied.data_format, data_format)

  replacement = 'NO_FORMAT'
  overridden = utils.ExtendedTensorSpec.from_spec(
      original, data_format=replacement)
  self.assertEqual(overridden.data_format, replacement)
def get_out_label_specification(self, mode):
  """Returns the output label spec: a single required 'reward'."""
  del mode
  reward_spec = utils.ExtendedTensorSpec(
      shape=_DEFAULT_REWARD_SHAPE,
      dtype=tf.float32,
      is_optional=False,
      name='reward')
  label_spec = TSpec()
  label_spec.reward = reward_spec
  return label_spec
def test_repr(self):
  """Checks repr output, including varlen_default_value, for two specs."""
  image_spec = utils.ExtendedTensorSpec(
      [1, 512, 640, 3], tf.float32, name='beep', is_optional=True,
      data_format='jpeg', varlen_default_value=1)
  expected_image_repr = (
      'ExtendedTensorSpec(shape=(1, 512, 640, 3), dtype=tf.float32, '
      "name='beep', is_optional=True, is_sequence=False, is_extracted=False, "
      "data_format='jpeg', dataset_key='', varlen_default_value=1)")
  self.assertEqual(repr(image_spec), expected_image_repr)

  # Unknown dimensions print as '?'; unset optional fields print as None.
  sequence_spec = utils.ExtendedTensorSpec(
      [1, None], tf.int32, is_sequence=True)
  expected_sequence_repr = (
      "ExtendedTensorSpec(shape=(1, ?), dtype=tf.int32, name=None, "
      "is_optional=False, is_sequence=True, is_extracted=False, "
      "data_format=None, dataset_key='', varlen_default_value=None)")
  self.assertEqual(repr(sequence_spec), expected_sequence_repr)
def test_varlen_feature_spec(self, batch_size):
  """Checks varlen int features are padded with varlen_default_value.

  Args:
    batch_size: Number of examples batched before parsing.
  """
  test_data = [[1], [1, 2]]
  # Each example padded on the right with the default (3) to length 3.
  expected_padded = [[1, 3, 3], [1, 2, 3]]
  file_pattern = os.path.join(self.create_tempdir().full_path,
                              'test.tfrecord')
  self._write_test_varlen_examples(test_data, file_pattern)

  feature_spec = tensorspec_utils.TensorSpecStruct()
  feature_spec.varlen = tensorspec_utils.ExtendedTensorSpec(
      shape=(3,), dtype=tf.int64, name='varlen', varlen_default_value=3.0)

  dataset = tfdata.parallel_read(file_patterns=file_pattern)
  dataset = dataset.batch(batch_size, drop_remainder=True)
  dataset = tfdata.serialized_to_parsed(dataset, feature_spec, None)
  features = dataset.make_one_shot_iterator().get_next()

  # Statically, only the padded length is known.
  self.assertAllEqual([None, 3], features.varlen.get_shape().as_list())
  with self.session() as session:
    np_features = session.run(features)
  self.assertAllEqual(np_features.varlen,
                      np.array(expected_padded[:batch_size]))
  self.assertAllEqual([batch_size, 3], np_features.varlen.shape)
def test_images_decoding_raises(self):
  """Checks serialized_to_parsed raises ValueError for a uint32 PNG feature."""
  file_pattern = os.path.join(self.create_tempdir().full_path,
                              'test.tfrecord')
  image_height, image_width = 512, 640
  maxval = np.iinfo(np.uint32).max  # Largest value a uint32 can hold.
  # NOTE(review): samples above the int32 range are not representable after
  # astype(np.int32). The test only checks the ValueError below, so the pixel
  # values appear not to matter; confirm.
  pixels = np.random.uniform(
      size=(image_height, image_width), high=maxval).astype(np.int32)
  encoded = image.numpy_to_image_string(pixels, 'png', np.uint32)
  self._write_test_images_examples([[encoded]], file_pattern)

  feature_spec = tensorspec_utils.TensorSpecStruct()
  feature_spec.images = tensorspec_utils.ExtendedTensorSpec(
      shape=(image_height, image_width, 1),
      dtype=tf.uint32,
      name='image/encoded',
      data_format='png')

  dataset = tfdata.parallel_read(file_patterns=file_pattern)
  dataset = dataset.batch(1, drop_remainder=True)
  with self.assertRaises(ValueError):
    tfdata.serialized_to_parsed(dataset, feature_spec, None)
def test_pad_image_tensor_to_spec_shape(self):
  """Checks a 2-image sequence is padded to 3 with varlen_default_value."""
  varlen_spec = utils.ExtendedTensorSpec(
      shape=(3, 2, 2, 1),
      dtype=tf.uint8,
      name='varlen',
      data_format='png',
      varlen_default_value=3.0)

  def _constant_image(value):
    # A 2x2 single-channel image filled with `value`.
    return [[[value]] * 2] * 2

  test_data = [[_constant_image(1), _constant_image(2)]]
  prepadded = tf.convert_to_tensor(test_data, dtype=varlen_spec.dtype)
  padded = utils.pad_or_clip_tensor_to_spec_shape(prepadded, varlen_spec)

  with self.session() as sess:
    np_tensor = sess.run(padded)
  # The missing third image is filled with the default value (3).
  expected = np.array(
      [[_constant_image(1), _constant_image(2), _constant_image(3)]])
  self.assertAllEqual(np_tensor, expected)
def get_action_specification(self):
  """Returns the action spec: gripper/termination flags and motion vectors.

  Every entry is a float32 spec named after its field. All are 1-element
  except `world_vector` (3) and `vertical_rotation` (2).
  """
  close_gripper_spec = tensorspec_utils.ExtendedTensorSpec(
      shape=(1,), dtype=tf.float32, name='close_gripper')
  open_gripper_spec = tensorspec_utils.ExtendedTensorSpec(
      shape=(1,), dtype=tf.float32, name='open_gripper')
  terminate_episode_spec = tensorspec_utils.ExtendedTensorSpec(
      shape=(1,), dtype=tf.float32, name='terminate_episode')
  gripper_closed_spec = tensorspec_utils.ExtendedTensorSpec(
      shape=(1,), dtype=tf.float32, name='gripper_closed')
  # Shapes are written as 1-tuples; `(3)`/`(2)` would just be ints.
  world_vector_spec = tensorspec_utils.ExtendedTensorSpec(
      shape=(3,), dtype=tf.float32, name='world_vector')
  vertical_rotation_spec = tensorspec_utils.ExtendedTensorSpec(
      shape=(2,), dtype=tf.float32, name='vertical_rotation')
  height_to_bottom_spec = tensorspec_utils.ExtendedTensorSpec(
      shape=(1,), dtype=tf.float32, name='height_to_bottom')
  return tensorspec_utils.TensorSpecStruct(
      world_vector=world_vector_spec,
      vertical_rotation=vertical_rotation_spec,
      close_gripper=close_gripper_spec,
      open_gripper=open_gripper_spec,
      terminate_episode=terminate_episode_spec,
      gripper_closed=gripper_closed_spec,
      height_to_bottom=height_to_bottom_spec)