def testShapeEquals(self):
  """Exercises `tensor_util.ShapeEquals` on a 2x2 tensor proto.

  The proto's shape is compared against matching shapes (expected to be
  reported equal) and against mismatching shapes (expected to be reported
  unequal).
  """
  proto = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2])

  # The matching shape spelled as a list and as a tuple.
  for same_shape in ([2, 2], (2, 2)):
    self.assertTrue(tensor_util.ShapeEquals(proto, same_shape))

  # The matching shape as the object built by make_tensor_shape_proto.
  shape_proto = tensor_util.make_tensor_shape_proto([2, 2])
  self.assertTrue(tensor_util.ShapeEquals(proto, shape_proto))

  # Shapes that differ from [2, 2], including one of a different rank.
  for other_shape in ([5, 3], [1, 4], [4]):
    self.assertFalse(tensor_util.ShapeEquals(proto, other_shape))
def testShapeEquals(self):
  """Checks `tensor_util.ShapeEquals` for equal and unequal shapes.

  A tensor proto created with shape [2, 2] must compare equal to that shape
  given as a list, a tuple, or the result of `make_tensor_shape_proto`, and
  must compare unequal to every other shape tried here.
  """
  tensor = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2])

  def has_shape(candidate):
    # Thin wrapper so each assertion below reads as a single question.
    return tensor_util.ShapeEquals(tensor, candidate)

  self.assertTrue(has_shape([2, 2]))
  self.assertTrue(has_shape((2, 2)))
  self.assertTrue(has_shape(tensor_util.make_tensor_shape_proto([2, 2])))

  for wrong_shape in ([5, 3], [1, 4], [4]):
    self.assertFalse(has_shape(wrong_shape))
def testConvertFromProto(self):
  """Checks building a `TensorShape` from a shape proto.

  For fully specified dimension lists, both the `TensorShape` constructor and
  `tensor_shape.as_shape` must yield a shape equal to the one built directly
  from the list. For a proto created with a -1 dimension, the converted shape
  is expected to have `None` in that position, to compare unequal to
  `TensorShape([None, 37, 42])`, and still to be compatible with it.
  """
  for dims in ([], [1, 37, 42]):
    proto = tensor_util.make_tensor_shape_proto(dims)
    self.assertEqual(tensor_shape.TensorShape(dims),
                     tensor_shape.TensorShape(proto))
    self.assertEqual(tensor_shape.TensorShape(dims),
                     tensor_shape.as_shape(proto))

  unknown_dim_proto = tensor_util.make_tensor_shape_proto([-1, 37, 42])
  partial_proto_shape = tensor_shape.as_shape(unknown_dim_proto)
  partial_shape = tensor_shape.TensorShape([None, 37, 42])

  # The two partial shapes are asserted unequal, but compatible (last check).
  self.assertNotEqual(partial_proto_shape, partial_shape)
  for index, expected_value in enumerate((None, 37, 42)):
    self.assertEqual(partial_proto_shape[index].value, expected_value)
  self.assertTrue(partial_shape.is_compatible_with(partial_proto_shape))
def _parse_single_sequence_example_raw(serialized,
                                       context_sparse_keys=None,
                                       context_sparse_types=None,
                                       context_dense_keys=None,
                                       context_dense_types=None,
                                       context_dense_defaults=None,
                                       context_dense_shapes=None,
                                       feature_list_sparse_keys=None,
                                       feature_list_sparse_types=None,
                                       feature_list_dense_keys=None,
                                       feature_list_dense_types=None,
                                       feature_list_dense_shapes=None,
                                       feature_list_dense_defaults=None,
                                       debug_name=None,
                                       name=None):
  """Parses a single `SequenceExample` proto.

  Args:
    serialized: A scalar (0-D Tensor) of type string, a single binary
      serialized `SequenceExample` proto.
    context_sparse_keys: A list of string keys in the `SequenceExample`'s
      features.  The results for these keys will be returned as
      `SparseTensor` objects.
    context_sparse_types: A list of `DTypes`, the same length as
      `context_sparse_keys`.  Only `tf.float32` (`FloatList`),
      `tf.int64` (`Int64List`), and `tf.string` (`BytesList`) are supported.
    context_dense_keys: A list of string keys in the examples' features.
      The results for these keys will be returned as `Tensor`s.
    context_dense_types: A list of DTypes, same length as
      `context_dense_keys`.  Only `tf.float32` (`FloatList`),
      `tf.int64` (`Int64List`), and `tf.string` (`BytesList`) are supported.
    context_dense_defaults: A dict mapping string keys to `Tensor`s.
      The keys of the dict must match the context_dense_keys of the feature.
    context_dense_shapes: A list of tuples, same length as
      `context_dense_keys`.  The shape of the data for each context_dense
      feature referenced by `context_dense_keys`.  Required for any input
      tensors identified by `context_dense_keys` whose shapes are anything
      other than `[]` or `[1]`.
    feature_list_sparse_keys: A list of string keys in the `SequenceExample`'s
      feature_lists.  The results for these keys will be returned as
      `SparseTensor` objects.
    feature_list_sparse_types: A list of `DTypes`, same length as
      `feature_list_sparse_keys`.  Only `tf.float32` (`FloatList`),
      `tf.int64` (`Int64List`), and `tf.string` (`BytesList`) are supported.
    feature_list_dense_keys: A list of string keys in the `SequenceExample`'s
      feature_lists.  The results for these keys will be returned as
      `Tensor`s.
    feature_list_dense_types: A list of `DTypes`, same length as
      `feature_list_dense_keys`.  Only `tf.float32` (`FloatList`),
      `tf.int64` (`Int64List`), and `tf.string` (`BytesList`) are supported.
    feature_list_dense_shapes: A list of tuples, same length as
      `feature_list_dense_keys`.  The shape of the data for each
      `FeatureList` feature referenced by `feature_list_dense_keys`.
    feature_list_dense_defaults: A dict mapping key strings to values.
      The only currently allowed value is `None`.  Any key appearing
      in this dict with value `None` is allowed to be missing from the
      `SequenceExample`.  If missing, the key is treated as zero-length.
    debug_name: A scalar (0-D Tensor) of strings (optional), the name of
      the serialized proto.
    name: A name for this operation (optional).

  Returns:
    A tuple of two `dict`s, each mapping keys to `Tensor`s and
    `SparseTensor`s.  The first dict contains the context key/values.
    The second dict contains the feature_list key/values.

  Raises:
    ValueError: If context_sparse and context_dense key sets intersect,
      if feature_list_sparse and feature_list_dense key sets intersect,
      if no keys are provided at all, if input lengths do not match up, or
      if a value in feature_list_dense_defaults is not None.
    TypeError: if feature_list_dense_defaults is not either None or a dict.
  """
  with ops.op_scope([serialized], name, "ParseSingleSequenceExample"):
    # Replace every omitted argument with an empty container (or, for the
    # shape lists, one `[]` shape per dense key).
    context_dense_defaults = (
        {} if context_dense_defaults is None else context_dense_defaults)
    context_sparse_keys = (
        [] if context_sparse_keys is None else context_sparse_keys)
    context_sparse_types = (
        [] if context_sparse_types is None else context_sparse_types)
    context_dense_keys = (
        [] if context_dense_keys is None else context_dense_keys)
    context_dense_types = (
        [] if context_dense_types is None else context_dense_types)
    context_dense_shapes = (
        [[]] * len(context_dense_keys)
        if context_dense_shapes is None else context_dense_shapes)
    feature_list_sparse_keys = (
        [] if feature_list_sparse_keys is None else feature_list_sparse_keys)
    feature_list_sparse_types = (
        [] if feature_list_sparse_types is None
        else feature_list_sparse_types)
    feature_list_dense_keys = (
        [] if feature_list_dense_keys is None else feature_list_dense_keys)
    feature_list_dense_types = (
        [] if feature_list_dense_types is None else feature_list_dense_types)
    feature_list_dense_shapes = (
        [[]] * len(feature_list_dense_keys)
        if feature_list_dense_shapes is None else feature_list_dense_shapes)
    feature_list_dense_defaults = (
        dict() if feature_list_dense_defaults is None
        else feature_list_dense_defaults)
    debug_name = "" if debug_name is None else debug_name

    # Internal
    feature_list_dense_missing_assumed_empty = []

    num_context_dense = len(context_dense_keys)
    num_feature_list_dense = len(feature_list_dense_keys)
    num_context_sparse = len(context_sparse_keys)
    num_feature_list_sparse = len(feature_list_sparse_keys)

    # Every per-key list must line up with its key list.
    if len(context_dense_shapes) != num_context_dense:
      raise ValueError(
          "len(context_dense_shapes) != len(context_dense_keys): %d vs. %d"
          % (len(context_dense_shapes), num_context_dense))
    if len(context_dense_types) != num_context_dense:
      raise ValueError(
          "len(context_dense_types) != len(context_dense_keys): %d vs. %d"
          % (len(context_dense_types), num_context_dense))
    if len(feature_list_dense_shapes) != num_feature_list_dense:
      raise ValueError(
          "len(feature_list_dense_shapes) != len(feature_list_dense_keys): "
          "%d vs. %d" % (len(feature_list_dense_shapes),
                         num_feature_list_dense))
    if len(feature_list_dense_types) != num_feature_list_dense:
      raise ValueError(
          "len(feature_list_dense_types) != len(feature_list_dense_keys): "
          "%d vs. %d" % (len(feature_list_dense_types),
                         num_feature_list_dense))
    if len(context_sparse_types) != num_context_sparse:
      raise ValueError(
          "len(context_sparse_types) != len(context_sparse_keys): %d vs. %d"
          % (len(context_sparse_types), num_context_sparse))
    if len(feature_list_sparse_types) != num_feature_list_sparse:
      raise ValueError(
          "len(feature_list_sparse_types) != len(feature_list_sparse_keys): "
          "%d vs. %d"
          % (len(feature_list_sparse_types), num_feature_list_sparse))

    if (num_context_dense + num_context_sparse
        + num_feature_list_dense + num_feature_list_sparse) == 0:
      raise ValueError(
          "Must provide at least one context_sparse key, context_dense key, "
          "feature_list_sparse key, or feature_list_dense key")

    # A key may be requested as dense or as sparse, never both.
    context_overlap = set(context_dense_keys).intersection(
        set(context_sparse_keys))
    if context_overlap:
      raise ValueError(
          "context_dense and context_sparse keys must not intersect; "
          "intersection: %s" % context_overlap)
    feature_list_overlap = set(feature_list_dense_keys).intersection(
        set(feature_list_sparse_keys))
    if feature_list_overlap:
      raise ValueError(
          "feature_list_dense and feature_list_sparse keys must not intersect; "
          "intersection: %s" % feature_list_overlap)

    if not isinstance(feature_list_dense_defaults, dict):
      raise TypeError("feature_list_dense_defaults must be a dict")
    # Only `None` is accepted as a default; each such key is forwarded to the
    # op as "missing assumed empty".
    for k, v in feature_list_dense_defaults.items():
      if v is not None:
        raise ValueError("Value feature_list_dense_defaults[%s] must be None"
                         % k)
      feature_list_dense_missing_assumed_empty.append(k)

    # Build one default tensor per context dense key.  Note that
    # context_dense_types is only consumed here; it is not passed to the op
    # call below.
    context_dense_defaults_vec = []
    for i, key in enumerate(context_dense_keys):
      default_value = context_dense_defaults.get(key)
      if default_value is None:
        # No default supplied: use an empty constant of the declared dtype.
        default_value = constant_op.constant([], dtype=context_dense_types[i])
      elif not isinstance(default_value, ops.Tensor):
        # Derive a tensor name from the key, replacing characters outside
        # the allowed set with underscores.
        key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
        default_value = ops.convert_to_tensor(
            default_value, dtype=context_dense_types[i], name=key_name)
        default_value = array_ops.reshape(
            default_value, context_dense_shapes[i])

      context_dense_defaults_vec.append(default_value)

    # Shapes given as Python lists/tuples are converted to shape protos;
    # anything else is passed through untouched.
    context_dense_shapes = [tensor_util.make_tensor_shape_proto(shape)
                            if isinstance(shape, (list, tuple)) else shape
                            for shape in context_dense_shapes]
    feature_list_dense_shapes = [tensor_util.make_tensor_shape_proto(shape)
                                 if isinstance(shape, (list, tuple)) else shape
                                 for shape in feature_list_dense_shapes]

    # pylint: disable=protected-access
    outputs = gen_parsing_ops._parse_single_sequence_example(
        serialized=serialized,
        debug_name=debug_name,
        context_dense_defaults=context_dense_defaults_vec,
        context_sparse_keys=context_sparse_keys,
        context_sparse_types=context_sparse_types,
        context_dense_keys=context_dense_keys,
        context_dense_shapes=context_dense_shapes,
        feature_list_sparse_keys=feature_list_sparse_keys,
        feature_list_sparse_types=feature_list_sparse_types,
        feature_list_dense_keys=feature_list_dense_keys,
        feature_list_dense_types=feature_list_dense_types,
        feature_list_dense_shapes=feature_list_dense_shapes,
        feature_list_dense_missing_assumed_empty=(
            feature_list_dense_missing_assumed_empty),
        name=name)
    # pylint: enable=protected-access

    (context_sparse_indices, context_sparse_values,
     context_sparse_shapes, context_dense_values,
     feature_list_sparse_indices, feature_list_sparse_values,
     feature_list_sparse_shapes, feature_list_dense_values) = outputs

    # Reassemble each (indices, values, shape) triple into a SparseTensor.
    context_sparse_tensors = [
        ops.SparseTensor(ix, val, shape) for (ix, val, shape)
        in zip(context_sparse_indices,
               context_sparse_values,
               context_sparse_shapes)]

    feature_list_sparse_tensors = [
        ops.SparseTensor(ix, val, shape) for (ix, val, shape)
        in zip(feature_list_sparse_indices,
               feature_list_sparse_values,
               feature_list_sparse_shapes)]

    # Sparse keys come first, then dense keys, matching the order in which
    # the corresponding tensors are concatenated.
    context_output = dict(
        zip(context_sparse_keys + context_dense_keys,
            context_sparse_tensors + context_dense_values))
    feature_list_output = dict(
        zip(feature_list_sparse_keys + feature_list_dense_keys,
            feature_list_sparse_tensors + feature_list_dense_values))

    return (context_output, feature_list_output)
def _parse_example_raw(serialized,
                       names=None,
                       sparse_keys=None,
                       sparse_types=None,
                       dense_keys=None,
                       dense_types=None,
                       dense_defaults=None,
                       dense_shapes=None,
                       name=None):
  """Parses `Example` protos.

  Args:
    serialized: A vector (1-D Tensor) of strings, a batch of binary
      serialized `Example` protos.
    names: A vector (1-D Tensor) of strings (optional), the names of
      the serialized protos.
    sparse_keys: A list of string keys in the examples' features.
      The results for these keys will be returned as `SparseTensor` objects.
    sparse_types: A list of `DTypes` of the same length as `sparse_keys`.
      Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
      and `tf.string` (`BytesList`) are supported.
    dense_keys: A list of string keys in the examples' features.
      The results for these keys will be returned as `Tensor`s.
    dense_types: A list of DTypes of the same length as `dense_keys`.
      Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
      and `tf.string` (`BytesList`) are supported.
    dense_defaults: A dict mapping string keys to `Tensor`s.
      The keys of the dict must match the dense_keys of the feature.
    dense_shapes: A list of tuples with the same length as `dense_keys`.
      The shape of the data for each dense feature referenced by `dense_keys`.
      Required for any input tensors identified by `dense_keys` whose shapes
      are anything other than `[]` or `[1]`.
    name: A name for this operation (optional).

  Returns:
    A `dict` mapping keys to `Tensor`s and `SparseTensor`s.

  Raises:
    ValueError: If sparse and dense key sets intersect, if no keys are
      provided at all, or if input lengths do not match up.
  """
  with ops.op_scope([serialized, names], name, "ParseExample"):
    # Replace every omitted argument with an empty container (or, for
    # dense_shapes, one `[]` shape per dense key).
    names = [] if names is None else names
    dense_defaults = {} if dense_defaults is None else dense_defaults
    sparse_keys = [] if sparse_keys is None else sparse_keys
    sparse_types = [] if sparse_types is None else sparse_types
    dense_keys = [] if dense_keys is None else dense_keys
    dense_types = [] if dense_types is None else dense_types
    dense_shapes = (
        [[]] * len(dense_keys) if dense_shapes is None else dense_shapes)

    num_dense = len(dense_keys)
    num_sparse = len(sparse_keys)

    # Every per-key list must line up with its key list.
    if len(dense_shapes) != num_dense:
      raise ValueError("len(dense_shapes) != len(dense_keys): %d vs. %d"
                       % (len(dense_shapes), num_dense))
    if len(dense_types) != num_dense:
      raise ValueError("len(dense_types) != len(dense_keys): %d vs. %d"
                       % (len(dense_types), num_dense))
    if len(sparse_types) != num_sparse:
      raise ValueError("len(sparse_types) != len(sparse_keys): %d vs. %d"
                       % (len(sparse_types), num_sparse))
    if num_dense + num_sparse == 0:
      raise ValueError("Must provide at least one sparse key or dense key")

    # A key may be requested as dense or as sparse, never both.
    overlap = set(dense_keys).intersection(set(sparse_keys))
    if overlap:
      raise ValueError(
          "Dense and sparse keys must not intersect; intersection: %s" %
          overlap)

    # Build one default tensor per dense key.  Note that dense_types is only
    # consumed here; it is not passed to the op call below.
    dense_defaults_vec = []
    for i, key in enumerate(dense_keys):
      default_value = dense_defaults.get(key)
      if default_value is None:
        # No default supplied: use an empty constant of the declared dtype.
        default_value = constant_op.constant([], dtype=dense_types[i])
      elif not isinstance(default_value, ops.Tensor):
        # Derive a tensor name from the key, replacing characters outside
        # the allowed set with underscores.
        key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
        default_value = ops.convert_to_tensor(
            default_value, dtype=dense_types[i], name=key_name)
        default_value = array_ops.reshape(default_value, dense_shapes[i])

      dense_defaults_vec.append(default_value)

    # Shapes given as Python lists/tuples are converted to shape protos;
    # anything else is passed through untouched.
    dense_shapes = [tensor_util.make_tensor_shape_proto(shape)
                    if isinstance(shape, (list, tuple)) else shape
                    for shape in dense_shapes]

    # pylint: disable=protected-access
    outputs = gen_parsing_ops._parse_example(
        serialized=serialized,
        names=names,
        dense_defaults=dense_defaults_vec,
        sparse_keys=sparse_keys,
        sparse_types=sparse_types,
        dense_keys=dense_keys,
        dense_shapes=dense_shapes,
        name=name)
    # pylint: enable=protected-access

    (sparse_indices, sparse_values, sparse_shapes, dense_values) = outputs

    # Reassemble each (indices, values, shape) triple into a SparseTensor.
    sparse_tensors = [ops.SparseTensor(ix, val, shape) for (ix, val, shape)
                      in zip(sparse_indices, sparse_values, sparse_shapes)]

    # Sparse keys come first, then dense keys, matching the order in which
    # the corresponding tensors are concatenated.
    return dict(
        zip(sparse_keys + dense_keys, sparse_tensors + dense_values))