def getExpectedConfig(self, op_type): expected = example_parser_configuration_pb2.ExampleParserConfiguration( ) if op_type == 'ParseExampleV2': text_format.Parse(EXPECTED_CONFIG_V2, expected) else: text_format.Parse(EXPECTED_CONFIG_V1, expected) return expected
def testBasic(self): golden_config = example_parser_configuration_pb2.ExampleParserConfiguration( ) text_format.Parse(BASIC_PROTO, golden_config) with session.Session() as sess: examples = array_ops.placeholder(dtypes.string, shape=[1]) feature_to_type = { 'x': parsing_ops.FixedLenFeature([1], dtypes.float32, 33.0), 'y': parsing_ops.VarLenFeature(dtypes.string) } _ = parsing_ops.parse_example(examples, feature_to_type) parse_example_op = sess.graph.get_operation_by_name( 'ParseExample/ParseExample') config = extract_example_parser_configuration( parse_example_op, sess) self.assertProtoEquals(golden_config, config)
def extract_example_parser_configuration(parse_example_op, sess): """Returns an ExampleParserConfig proto. Args: parse_example_op: A ParseExample `Operation` sess: A tf.Session needed to obtain some configuration values. Returns: A ExampleParserConfig proto. Raises: ValueError: If attributes are inconsistent. """ config = example_parser_configuration_pb2.ExampleParserConfiguration() num_sparse = parse_example_op.get_attr("Nsparse") num_dense = parse_example_op.get_attr("Ndense") total_features = num_dense + num_sparse sparse_types = parse_example_op.get_attr("sparse_types") dense_types = parse_example_op.get_attr("Tdense") dense_shapes = parse_example_op.get_attr("dense_shapes") if len(sparse_types) != num_sparse: raise ValueError("len(sparse_types) attribute does not match " "Nsparse attribute (%d vs %d)" % (len(sparse_types), num_sparse)) if len(dense_types) != num_dense: raise ValueError("len(dense_types) attribute does not match " "Ndense attribute (%d vs %d)" % (len(dense_types), num_dense)) if len(dense_shapes) != num_dense: raise ValueError("len(dense_shapes) attribute does not match " "Ndense attribute (%d vs %d)" % (len(dense_shapes), num_dense)) # Skip over the serialized input, and the names input. fetch_list = parse_example_op.inputs[2:] # Fetch total_features key names and num_dense default values. if len(fetch_list) != (total_features + num_dense): raise ValueError( "len(fetch_list) does not match total features + num_dense" "(%d vs %d" % (len(fetch_list), (total_features + num_dense))) fetched = sess.run(fetch_list) if len(fetched) != len(fetch_list): raise ValueError("len(fetched) does not match len(fetch_list)" "(%d vs %d" % (len(fetched), len(fetch_list))) # Fetch indices. sparse_keys_start = 0 dense_keys_start = sparse_keys_start + num_sparse dense_def_start = dense_keys_start + num_dense # Output tensor indices. sparse_indices_start = 0 sparse_values_start = num_sparse sparse_shapes_start = sparse_values_start + num_sparse dense_values_start = sparse_shapes_start + num_sparse # Dense features. for i in range(num_dense): key = fetched[dense_keys_start + i] feature_config = config.feature_map[key] # Convert the default value numpy array fetched from the session run # into a TensorProto. fixed_config = feature_config.fixed_len_feature fixed_config.default_value.CopyFrom( tensor_util.make_tensor_proto(fetched[dense_def_start + i])) # Convert the shape from the attributes # into a TensorShapeProto. fixed_config.shape.CopyFrom( tensor_shape.TensorShape(dense_shapes[i]).as_proto()) fixed_config.dtype = dense_types[i] # Get the output tensor name. fixed_config.values_output_tensor_name = parse_example_op.outputs[ dense_values_start + i].name # Sparse features. for i in range(num_sparse): key = fetched[sparse_keys_start + i] feature_config = config.feature_map[key] var_len_feature = feature_config.var_len_feature var_len_feature.dtype = sparse_types[i] var_len_feature.indices_output_tensor_name = parse_example_op.outputs[ sparse_indices_start + i].name var_len_feature.values_output_tensor_name = parse_example_op.outputs[ sparse_values_start + i].name var_len_feature.shapes_output_tensor_name = parse_example_op.outputs[ sparse_shapes_start + i].name return config
def _extract_from_parse_example_v2(parse_example_op, sess): """Extract ExampleParserConfig from ParseExampleV2 op.""" config = example_parser_configuration_pb2.ExampleParserConfiguration() dense_types = parse_example_op.get_attr("Tdense") num_sparse = parse_example_op.get_attr("num_sparse") sparse_types = parse_example_op.get_attr("sparse_types") ragged_value_types = parse_example_op.get_attr("ragged_value_types") ragged_split_types = parse_example_op.get_attr("ragged_split_types") dense_shapes = parse_example_op.get_attr("dense_shapes") num_dense = len(dense_types) num_ragged = len(ragged_value_types) assert len(ragged_value_types) == len(ragged_split_types) assert len(parse_example_op.inputs) == 5 + num_dense # Skip over the serialized input, and the names input. fetched = sess.run(parse_example_op.inputs[2:]) sparse_keys = fetched[0].tolist() dense_keys = fetched[1].tolist() ragged_keys = fetched[2].tolist() dense_defaults = fetched[3:] assert len(sparse_keys) == num_sparse assert len(dense_keys) == num_dense assert len(ragged_keys) == num_ragged # Output tensor indices. sparse_indices_start = 0 sparse_values_start = num_sparse sparse_shapes_start = sparse_values_start + num_sparse dense_values_start = sparse_shapes_start + num_sparse ragged_values_start = dense_values_start + num_dense ragged_row_splits_start = ragged_values_start + num_ragged # Dense features. for i in range(num_dense): key = dense_keys[i] feature_config = config.feature_map[key] # Convert the default value numpy array fetched from the session run # into a TensorProto. fixed_config = feature_config.fixed_len_feature fixed_config.default_value.CopyFrom( tensor_util.make_tensor_proto(dense_defaults[i])) # Convert the shape from the attributes # into a TensorShapeProto. fixed_config.shape.CopyFrom( tensor_shape.TensorShape(dense_shapes[i]).as_proto()) fixed_config.dtype = dense_types[i].as_datatype_enum # Get the output tensor name. fixed_config.values_output_tensor_name = parse_example_op.outputs[ dense_values_start + i].name # Sparse features. for i in range(num_sparse): key = sparse_keys[i] feature_config = config.feature_map[key] var_len_feature = feature_config.var_len_feature var_len_feature.dtype = sparse_types[i].as_datatype_enum var_len_feature.indices_output_tensor_name = parse_example_op.outputs[ sparse_indices_start + i].name var_len_feature.values_output_tensor_name = parse_example_op.outputs[ sparse_values_start + i].name var_len_feature.shapes_output_tensor_name = parse_example_op.outputs[ sparse_shapes_start + i].name if num_ragged != 0: del ragged_values_start # unused del ragged_row_splits_start # unused raise ValueError("Ragged features are not yet supported by " "example_parser_configuration.proto") return config