def test_broadcast_and_calculate(self):
  """Tests calculating the values of a field broadcast onto a sibling.

  Broadcasts the root-level "foo" field of the big test prensor onto its
  sibling "user", then calculates the resulting field and checks its parent
  indices and values.
  """
  expr = create_expression.create_expression_from_prensor(
      prensor_test_util.create_big_prensor())
  new_root, new_path = broadcast.broadcast_anonymous(
      expr, path.Path(["foo"]), "user")
  new_field = new_root.get_descendant_or_error(new_path)
  leaf_node = expression_test_util.calculate_value_slowly(new_field)
  # NOTE(review): these expected values are tied to the fixture built by
  # prensor_test_util.create_big_prensor(); update them together.
  self.assertAllEqual(leaf_node.parent_index, [0, 1, 2, 3])
  self.assertAllEqual(leaf_node.values, [9, 8, 8, 7])
def promote_and_broadcast_anonymous(root, origin, new_parent):
  """Promotes then broadcasts `origin` until its parent is `new_parent`.

  Works in two phases. First, the field at `origin` is promoted repeatedly
  until its parent is the least common ancestor of `origin` and
  `new_parent`. Then it is broadcast repeatedly, each time onto the next
  field name along `new_parent`'s path, until its parent is `new_parent`.

  Args:
    root: the expression to start from.
    origin: path of the field to move.
    new_parent: path that should end up as the parent of the moved field.

  Returns:
    A tuple `(expression, path)`: the final expression and the path of the
    moved (anonymous) field inside it.
  """
  lca = origin.get_least_common_ancestor(new_parent)
  current_root = root
  current_path = origin

  # Phase 1: climb until the field hangs directly off the common ancestor.
  while current_path.get_parent() != lca:
    current_root, current_path = promote.promote_anonymous(
        current_root, current_path)

  # Phase 2: descend toward new_parent. The field name at index
  # len(current_path) - 1 of new_parent is the next step down; this assumes
  # each broadcast_anonymous call deepens the path by one level
  # (NOTE(review): confirm against broadcast.broadcast_anonymous).
  while current_path.get_parent() != new_parent:
    step_name = new_parent.field_list[len(current_path) - 1]
    current_root, current_path = broadcast.broadcast_anonymous(
        current_root, current_path, step_name)

  return current_root, current_path
def test_broadcast_anonymous(self):
  """Checks schema, calculation and sources of an anonymous broadcast."""
  expr = create_expression.create_expression_from_prensor(
      prensor_test_util.create_big_prensor())
  foo_path = path.Path(["foo"])
  new_root, broadcast_path = broadcast.broadcast_anonymous(
      expr, foo_path, "user")
  broadcast_field = new_root.get_descendant_or_error(broadcast_path)

  # Static schema of the broadcast field.
  self.assertFalse(broadcast_field.is_repeated)
  self.assertEqual(broadcast_field.type, tf.int32)
  self.assertTrue(broadcast_field.is_leaf)

  # calculation_equal holds for the field itself but not for the input root.
  self.assertTrue(broadcast_field.calculation_equal(broadcast_field))
  self.assertFalse(broadcast_field.calculation_equal(expr))

  # The calculated values keep the declared int32 dtype.
  calculated = expression_test_util.calculate_value_slowly(broadcast_field)
  self.assertEqual(calculated.values.dtype, tf.int32)

  # The leaf exposes no child field names.
  self.assertEqual(broadcast_field.known_field_names(), frozenset())

  # Sources, in order: the broadcast field "foo", then its sibling "user".
  source_exprs = broadcast_field.get_source_expressions()
  self.assertLen(source_exprs, 2)
  self.assertIs(expr.get_child("foo"), source_exprs[0])
  self.assertIs(expr.get_child("user"), source_exprs[1])