def test_broken_ragged_tensors_no_check(self):
  """Checks that a broken prensor does not crash under minimal checks.

  The evaluated result is undefined, so nothing is asserted about it; the
  test passes as long as evaluation completes.
  """
  broken_prensor = prensor_test_util.create_broken_prensor()
  minimal_checks = calculate_options.get_options_with_minimal_checks()
  tensors_by_path = prensor._get_ragged_tensors(broken_prensor, minimal_checks)

  # Re-key the mapping by str(path) before handing it to self.evaluate.
  tensors_by_name = {}
  for step_path, ragged in tensors_by_path.items():
    tensors_by_name[str(step_path)] = ragged

  self.evaluate(tensors_by_name)
def parse_elwc_with_struct2tensor(
    records: tf.Tensor,
    context_features: List[Feature],
    example_features: List[Feature],
    size_feature_name: Optional[str] = None) -> Dict[str, tf.RaggedTensor]:
  """Parses a batch of ELWC records into RaggedTensors using struct2tensor.

  A struct2tensor expression is built over the ExampleListWithContext proto,
  every requested feature is promoted next to its parent (the root for
  context features, `examples` for example features), features that declare
  a default value are padded to a uniform row length, and the projected
  result is materialized as RaggedTensors.

  Args:
    records: The serialized ELWC input. It is handed unchanged to
      `proto_expr.create_expression_from_proto` together with the
      ExampleListWithContext descriptor.
    context_features: List of context-level features.
    example_features: List of example-level features.
    size_feature_name: A string, the name of a feature for example list
      sizes. If None (the default) this feature is not generated; otherwise
      it is added to the returned dict under this name.

  Returns:
    A dict that maps feature name to RaggedTensors. Example features are
    written after context features, so on a name clash the example feature
    wins.
  """
  context_root = path.Path([])
  examples_root = path.Path(['examples'])

  def _step(feature_name: str):
    """Turns a feature name into a valid prensor step (no '.' allowed)."""
    return feature_name.replace('.', '_dot_')

  def _padded_step(feature_name: str):
    """Step name used for the default-filled variant of a feature."""
    return _step(feature_name) + _DEFAULT_VALUE_SUFFIX

  def _source_path(feature: Feature, root: str):
    """Path of the feature's value list under `root` in the ELWC proto."""
    return path.Path([
        root, 'features', 'feature[{}]'.format(feature.name),
        _TYPE_LIST_MAP.get(feature.dtype), 'value'
    ])

  def _build_maps(features: List[Feature], is_context: bool):
    """Returns (step -> source path, feature name -> projected path)."""
    source_root = 'context' if is_context else 'examples'
    destination_prefix = [] if is_context else ['examples']
    promote, project = {}, {}
    for feature in features:
      step = _step(feature.name)
      promote[step] = _source_path(feature, source_root)
      # Features with a default value are read from their padded field.
      if feature.default_value is None:
        leaf = step
      else:
        leaf = _padded_step(feature.name)
      project[feature.name] = path.Path(destination_prefix + [leaf])
    return promote, project

  def _make_padder(feature: Feature):
    """Builds the ragged-tensor operator that pads `feature` to its length."""

    def _pad(ragged):
      row_width = feature.length
      dense = ragged.to_tensor(
          default_value=feature.default_value, shape=[None, row_width])
      return tf.RaggedTensor.from_uniform_row_length(
          tf.reshape(dense, [-1]), row_width, validate=False)

    return _pad

  context_promote, context_projection = _build_maps(
      context_features, is_context=True)
  examples_promote, examples_projection = _build_maps(
      example_features, is_context=False)

  # Build the struct2tensor query.
  expr = proto_expr.create_expression_from_proto(
      records, input_pb2.ExampleListWithContext.DESCRIPTOR)
  expr = expr.promote_and_broadcast(context_promote, context_root)
  expr = expr.promote_and_broadcast(examples_promote, examples_root)

  # Pad features that have default_values specified.
  for parent, features in ((context_root, context_features),
                           (examples_root, example_features)):
    for feature in features:
      if feature.default_value is None:
        continue
      expr = expr.map_ragged_tensors(
          parent_path=parent,
          source_fields=[_step(feature.name)],
          operator=_make_padder(feature),
          is_repeated=True,
          dtype=feature.dtype,
          new_field_name=_padded_step(feature.name))

  to_project = [*context_projection.values(), *examples_projection.values()]

  size_path = None
  if size_feature_name is not None:
    size_step = _step(size_feature_name)
    expr = expr.create_size_field(examples_root, size_step)
    size_path = path.Path([size_step])
    to_project.append(size_path)

  options = calculate_options.get_options_with_minimal_checks()
  projected = calculate.calculate_prensors([expr.project(to_project)],
                                           options)[0]
  # A map from path.Path to RaggedTensors.
  ragged_by_path = prensor_util.get_ragged_tensors(projected, options)

  # Context entries first, then example entries (later writes overwrite).
  result = {}
  for names_to_paths in (context_projection, examples_projection):
    for name, promoted_path in names_to_paths.items():
      result[name] = ragged_by_path[promoted_path]

  if size_path is not None:
    result[size_feature_name] = ragged_by_path[size_path]
  return result
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for struct2tensor.prensor."""

from struct2tensor import calculate_options
from struct2tensor import path
from struct2tensor import prensor
from struct2tensor.test import prensor_test_util
import tensorflow as tf

from tensorflow.python.framework import test_util  # pylint: disable=g-direct-tensorflow-import

_OPTIONS_TO_TEST = [
    calculate_options.get_default_options(),
    calculate_options.get_options_with_minimal_checks()
]


@test_util.run_all_in_graph_and_eager_modes
class PrensorTest(tf.test.TestCase):

  def _assert_prensor_equals(self, lhs, rhs):
    """Asserts that the nodes of two prensors match.

    Root nodes must share the same `size` object; child nodes must share the
    same `parent_index` object and agree on `is_repeated`. For any other
    node, the checks visible here only require `rhs.node` to be a
    LeafNodeTensor.

    Args:
      lhs: prensor providing the reference node.
      rhs: prensor whose node is compared against `lhs`.
    """
    expected_node = lhs.node
    actual_node = rhs.node
    if isinstance(expected_node, prensor.RootNodeTensor):
      self.assertIsInstance(actual_node, prensor.RootNodeTensor)
      # Identity, not value equality: both must hold the same size object.
      self.assertIs(expected_node.size, actual_node.size)
    elif isinstance(expected_node, prensor.ChildNodeTensor):
      self.assertIsInstance(actual_node, prensor.ChildNodeTensor)
      self.assertIs(expected_node.parent_index, actual_node.parent_index)
      self.assertEqual(expected_node.is_repeated, actual_node.is_repeated)
    else:
      self.assertIsInstance(actual_node, prensor.LeafNodeTensor)
def _test_assert_raises(self, test_runner):
  """Runs `test_runner` under the default and the minimal-check options.

  With the default options the call is required to raise
  `tf.errors.InvalidArgumentError`. The runner is then invoked again with
  the minimal-check options.

  Args:
    test_runner: callable that accepts a single options argument.
  """
  with self.assertRaises(tf.errors.InvalidArgumentError):
    default_options = calculate_options.get_default_options()
    test_runner(default_options)

  # NOTE(review): this second run is placed after the assertRaises block;
  # inside it, it would be unreachable once the first call raises. Confirm
  # against the original layout.
  minimal_options = calculate_options.get_options_with_minimal_checks()
  test_runner(minimal_options)