def test_prensor_children_ordered(self):
  """Checks that prensor children are always reported in sorted key order.

  Every node of several fixture prensors is walked recursively. Then a prensor
  whose children are inserted in reverse alphabetical order is built, and its
  children must still come back as a, b, c, d.
  """

  def _assert_sorted_recursively(node):
    # Compare the actual child-key order with the sorted order, then descend
    # into every child.
    children = node.get_children()
    self.assertEqual(list(children.keys()), sorted(children.keys()))
    for child in children.values():
      _assert_sorted_recursively(child)

  fixtures = (
      prensor_test_util.create_nested_prensor(),
      prensor_test_util.create_big_prensor(),
      prensor_test_util.create_deep_prensor(),
  )
  for fixture in fixtures:
    _assert_sorted_recursively(fixture)

  # Children are inserted as d, c, b, a; the reported order must not depend
  # on that insertion order.
  make_leaf = prensor_test_util.create_optional_leaf_node
  reversed_insertion = prensor.create_prensor_from_descendant_nodes({
      path.Path([]): prensor_test_util.create_root_node(1),
      path.Path(["d"]): make_leaf([0], [True]),
      path.Path(["c"]): make_leaf([0], [True]),
      path.Path(["b"]): make_leaf([0], [True]),
      path.Path(["a"]): make_leaf([0], [True]),
  })
  self.assertEqual(["a", "b", "c", "d"],
                   list(reversed_insertion.get_children().keys()))
def _create_nested_prensor_2():
  r"""Creates a prensor representing a list of nested protocol buffers.

  In this variant, keep_me has no value in doc0; it is set only in doc1.

  Tree of the three root messages::

    *
    |-- root0
    |     |-- keep_my_sib: False
    |     `-- doc0
    |           `-- bar: "a"
    |-- root1
    |     |-- keep_my_sib: True
    |     |-- keep_my_sib: False
    |     |-- doc1
    |     |     |-- bar: "b"
    |     |     |-- bar: "c"
    |     |     `-- keep_me: True
    |     `-- doc2
    |           `-- bar: "d"
    `-- root2 (empty)

  Returns:
    a prensor representing:
    {doc:[{bar:["a"]}], keep_my_sib:[False]}
    {doc:[{bar:["b","c"], keep_me:True}, {bar:["d"]}],
     keep_my_sib:[True, False]}
    {}
  """
  return prensor.create_prensor_from_descendant_nodes({
      # Three root messages: root0, root1, root2.
      path.Path([]):
          prensor_test_util.create_root_node(3),
      # doc0 belongs to root0; doc1 and doc2 belong to root1.
      # NOTE(review): the `True` argument presumably marks doc as repeated;
      # confirm against prensor_test_util.create_child_node.
      path.Path(["doc"]):
          prensor_test_util.create_child_node([0, 1, 1], True),
      # root0 gets [False]; root1 gets [True, False]; root2 gets nothing.
      path.Path(["keep_my_sib"]):
          prensor_test_util.create_repeated_leaf_node([0, 1, 1],
                                                      [False, True, False]),
      # doc0 -> ["a"], doc1 -> ["b", "c"], doc2 -> ["d"].
      path.Path(["doc", "bar"]):
          prensor_test_util.create_repeated_leaf_node([0, 1, 1, 2],
                                                      ["a", "b", "c", "d"]),
      # Only doc1 carries keep_me.
      path.Path(["doc", "keep_me"]):
          prensor_test_util.create_optional_leaf_node([1], [True])
  })
def test_map_many_values(self):
  """Tests map_many_values combining two sibling leaves into a new field.

  foo and bar are optional leaves present under roots 0, 2 and 3. The new
  field "new_field" is computed as foo + bar and keeps those parent indices.
  """
  with self.session(use_gpu=False) as sess:
    # The leaves below use parent index 3, so the root must hold at least
    # 4 messages for the fixture to be internally consistent.
    expr = create_expression.create_expression_from_prensor(
        prensor.create_prensor_from_descendant_nodes({
            path.Path([]):
                prensor_test_util.create_root_node(4),
            path.Path(["foo"]):
                prensor_test_util.create_optional_leaf_node(
                    [0, 2, 3], [9, 8, 7]),
            path.Path(["bar"]):
                prensor_test_util.create_optional_leaf_node(
                    [0, 2, 3], [10, 20, 30]),
        }))
    new_root, new_path = map_values.map_many_values(
        expr, path.Path([]), ["foo", "bar"], lambda x, y: x + y, tf.int64,
        "new_field")
    leaf_node = expression_test_util.calculate_value_slowly(
        new_root.get_descendant_or_error(new_path))
    parent_index, values = sess.run(
        [leaf_node.parent_index, leaf_node.values])
    self.assertAllEqual(parent_index, [0, 2, 3])
    # Element-wise sums: 9 + 10, 8 + 20, 7 + 30.
    self.assertAllEqual(values, [19, 28, 37])
def test_children_order(self):
  """Evaluated prensor values expose their children in sorted key order.

  Evaluates a nested fixture and checks every node's children are sorted.
  Then evaluates a prensor whose children were inserted in reverse
  alphabetical order and checks they come back as a, b, c, d.
  """
  # This test evaluates prensors through sess.run, so it only runs in graph
  # mode. Report a skip instead of silently passing under eager execution.
  if tf.executing_eagerly():
    self.skipTest("test_children_order evaluates via sess.run (graph mode).")

  def _check_children(pv):
    # Assert this node's child keys are already sorted, then recurse.
    children = pv.get_children()
    self.assertEqual(sorted(children.keys()), list(children.keys()))
    for child in children.values():
      _check_children(child)

  with self.cached_session(use_gpu=False) as sess:
    p = prensor_test_util.create_nested_prensor()
    _check_children(sess.run(p))

  # Children are inserted as d, c, b, a; evaluation must still yield them
  # in sorted order.
  with self.cached_session(use_gpu=False) as sess:
    p = prensor.create_prensor_from_descendant_nodes({
        path.Path([]):
            prensor_test_util.create_root_node(1),
        path.Path(["d"]):
            prensor_test_util.create_optional_leaf_node([0], [True]),
        path.Path(["c"]):
            prensor_test_util.create_optional_leaf_node([0], [True]),
        path.Path(["b"]):
            prensor_test_util.create_optional_leaf_node([0], [True]),
        path.Path(["a"]):
            prensor_test_util.create_optional_leaf_node([0], [True]),
    })
    pv = sess.run(p)
    self.assertEqual(["a", "b", "c", "d"], list(pv.get_children().keys()))
def _create_one_value_prensor():
  """Builds a prensor of three flat messages where only the middle one is set.

  Only root 1 carries the optional leaf foo, with value 8.

  Returns:
    the prensor produced by prensor.create_prensor_from_descendant_nodes,
    representing:
    {}
    {foo:8}
    {}
  """
  root = prensor_test_util.create_root_node(3)
  # Parent index [1]: foo is present only in the second root message.
  foo_leaf = prensor_test_util.create_optional_leaf_node([1], [8])
  nodes_by_path = {
      path.Path([]): root,
      path.Path(["foo"]): foo_leaf,
  }
  return prensor.create_prensor_from_descendant_nodes(nodes_by_path)