def test_any_path(self):
  """Promoting a field inside a packed Any reports the full source path."""
  # Path of the int32 field as seen through the Any wrapper; used both as the
  # promotion source and as the expected source path in the summary.
  source_path = path.Path(
      ["my_any", "(struct2tensor.test.AllSimple)", "optional_int32"])

  payload = test_pb2.AllSimple()
  payload.optional_int32 = 17
  wrapper = test_any_pb2.MessageWithAny()
  wrapper.my_any.Pack(payload)

  root = proto.create_expression_from_proto(
      [wrapper.SerializeToString()], test_any_pb2.MessageWithAny.DESCRIPTOR)
  promoted_root = promote.promote(root, source_path, "new_int32")
  promoted_field = promoted_root.get_descendant_or_error(
      path.Path(["my_any", "new_int32"]))

  prensors, summaries = (
      calculate_with_source_paths.calculate_prensors_with_source_paths(
          [promoted_field]))

  # Exactly one prensor and one proto summary are expected for one field.
  self.assertLen(prensors, 1)
  self.assertLen(summaries, 1)

  leaf = prensors[0].node
  self.assertAllEqual(leaf.parent_index, [0])
  self.assertAllEqual(leaf.values, [17])

  self.equal_ignore_order(summaries[0].paths, [source_path])
def test_e2e_proto(self):
  """Integration test for parsing protobufs."""
  # Build a Session proto from text format: one session_info and two events,
  # each event carrying two actions.
  session = text_format.Merge(
      """
      session_info {
        session_duration_sec: 1.0
        session_feature: "foo"
      }
      event {
        query: "Hello"
        action {
          number_of_views: 1
        }
        action {
        }
      }
      event {
        query: "world"
        action {
          number_of_views: 2
        }
        action {
          number_of_views: 3
        }
      }
      """, test_pb2.Session())
  serialized = tf.constant([session.SerializeToString()])

  expr = proto.create_expression_from_proto(
      serialized, test_pb2.Session().DESCRIPTOR)
  (prensor,) = calculate.calculate_prensors([expr])
  print(prensor)

  # Smoke check only: the conversion result is printed, not asserted on.
  structured = prensor_to_structured_tensor.prensor_to_structured_tensor(
      prensor)
  print(structured)
def _get_user_info_with_extension():
  """Returns a proto expression over one UserInfo with an extension set.

  The UserInfo message has its `MyExternalExtension.ext` extension populated
  with `special = "shhh"` before being serialized.
  """
  user_info = test_pb2.UserInfo()
  extension = user_info.Extensions[test_extension_pb2.MyExternalExtension.ext]
  extension.special = "shhh"

  # NOTE(review): the serialized message is a UserInfo, yet the descriptor
  # passed here is MessageWithAny's. This looks like a copy-paste mismatch
  # (test_pb2.UserInfo.DESCRIPTOR may have been intended) -- confirm against
  # the callers before changing it.
  descriptor = test_any_pb2.MessageWithAny.DESCRIPTOR
  return proto.create_expression_from_proto(
      [user_info.SerializeToString()], descriptor)
def _get_expression_with_any():
  """Returns a proto expression over three MessageWithAny protos.

  The packed payloads are, in order: an AllSimple with optional_int32 = 0, an
  empty UserInfo, and an AllSimple with optional_int32 = 20.
  """
  first = test_pb2.AllSimple()
  first.optional_int32 = 0
  second = test_pb2.UserInfo()
  third = test_pb2.AllSimple()
  third.optional_int32 = 20

  serialized = []
  for payload in (first, second, third):
    # Each payload gets its own wrapper message with the payload packed into
    # the `my_any` field.
    wrapper = test_any_pb2.MessageWithAny()
    wrapper.my_any.Pack(payload)
    serialized.append(wrapper.SerializeToString())

  return proto.create_expression_from_proto(
      serialized, test_any_pb2.MessageWithAny.DESCRIPTOR)
def text_to_expression(text_list, example_proto_clz):
  """Builds a proto expression from text-format protos.

  Args:
    text_list: the text-format protos, forwarded to `text_to_tensor`.
    example_proto_clz: the proto class; it is forwarded to `text_to_tensor`
      and an instance of it supplies the descriptor for the expression.

  Returns:
    The result of `proto.create_expression_from_proto` for the converted input.
  """
  serialized = text_to_tensor(text_list, example_proto_clz)
  descriptor = example_proto_clz().DESCRIPTOR
  return proto.create_expression_from_proto(serialized, descriptor)
def parse_elwc_with_struct2tensor(
    records: tf.Tensor,
    context_features: List[Feature],
    example_features: List[Feature],
    size_feature_name: Optional[str] = None) -> Dict[str, tf.RaggedTensor]:
  """Parses a batch of ELWC records into RaggedTensors using struct2tensor.

  Args:
    records: The serialized ExampleListWithContext (ELWC) protos. It is passed
      directly to struct2tensor as the source of the proto expression.
    context_features: List of context-level features.
    example_features: List of example-level features.
    size_feature_name: A string, the name of a feature for example list sizes.
      If None, which is default, this feature is not generated. Otherwise the
      feature is added to the returned dict under this name.

  Returns:
    A dict that maps feature name to RaggedTensors. Context features come
    first, then example features, then the optional size feature.
  """
  root_path = path.Path([])
  examples_path = path.Path(['examples'])

  def to_step(feature_name: str) -> str:
    """Turns a feature name into a valid prensor step.

    A prensor step cannot contain dots ("."), but a feature name can.
    """
    return feature_name.replace('.', '_dot_')

  def to_padded_step(feature_name: str) -> str:
    """Step name of the default-filled variant of a feature."""
    return to_step(feature_name) + _DEFAULT_VALUE_SUFFIX

  def source_path(top_field: str, feature: Feature):
    """Path of a feature's value list under `top_field` in the ELWC proto."""
    list_name = _TYPE_LIST_MAP.get(feature.dtype)
    feature_step = 'feature[{}]'.format(feature.name)
    return path.Path([top_field, 'features', feature_step, list_name, 'value'])

  def build_maps(features: List[Feature], top_field: str, destination_prefix):
    """Builds the (promote, project) maps for one group of features.

    The promote map goes from step name to the source path inside the proto.
    The project map goes from the original feature name to the path the
    caller reads back: the default-filled leaf when the feature has a default
    value, otherwise the plain promoted leaf.
    """
    sources = {}
    destinations = {}
    for feature in features:
      step = to_step(feature.name)
      sources[step] = source_path(top_field, feature)
      if feature.default_value is None:
        leaf = step
      else:
        leaf = to_padded_step(feature.name)
      destinations[feature.name] = path.Path(destination_prefix + [leaf])
    return sources, destinations

  def make_padder(feature: Feature):
    """Returns an op that densifies rows to `feature.length` columns."""

    def pad(ragged):
      # Densify to a fixed width, filling missing slots with the default
      # value, then rebuild a uniform-row-length RaggedTensor.
      dense = ragged.to_tensor(
          shape=[None, feature.length], default_value=feature.default_value)
      flat_values = tf.reshape(dense, [-1])
      return tf.RaggedTensor.from_uniform_row_length(
          flat_values, feature.length, validate=False)

    return pad

  context_sources, context_destinations = build_maps(
      context_features, 'context', [])
  example_sources, example_destinations = build_maps(
      example_features, 'examples', ['examples'])

  # Build the struct2tensor query: context features are promoted to the root,
  # example features to directly under 'examples'.
  query = proto_expr.create_expression_from_proto(
      records, input_pb2.ExampleListWithContext.DESCRIPTOR)
  query = query.promote_and_broadcast(context_sources, root_path)
  query = query.promote_and_broadcast(example_sources, examples_path)

  # Pad features that have default_values specified.
  feature_groups = ((root_path, context_features),
                    (examples_path, example_features))
  for parent, group in feature_groups:
    for feature in group:
      if feature.default_value is None:
        continue
      query = query.map_ragged_tensors(
          parent_path=parent,
          source_fields=[to_step(feature.name)],
          operator=make_padder(feature),
          is_repeated=True,
          dtype=feature.dtype,
          new_field_name=to_padded_step(feature.name))

  wanted_paths = (
      list(context_destinations.values()) +
      list(example_destinations.values()))

  size_path = None
  if size_feature_name is not None:
    size_step = to_step(size_feature_name)
    query = query.create_size_field(examples_path, size_step)
    size_path = path.Path([size_step])
    wanted_paths.append(size_path)

  projection = query.project(wanted_paths)
  options = calculate_options.get_options_with_minimal_checks()
  prensor_result = calculate.calculate_prensors([projection], options)[0]

  # A map from path.Path to RaggedTensors.
  ragged_by_path = prensor_util.get_ragged_tensors(prensor_result, options)

  # Context entries go in first; an example feature with the same name
  # overwrites the context one.
  result = {}
  for destinations in (context_destinations, example_destinations):
    for feature_name, destination in destinations.items():
      result[feature_name] = ragged_by_path[destination]

  if size_feature_name is not None:
    result[size_feature_name] = ragged_by_path[size_path]

  return result