Exemplo n.º 1
0
 def test_broadcast_and_calculate(self):
     """Broadcasts "foo" into "user" on the big prensor and checks its values.

     Calls broadcast.broadcast_anonymous to broadcast the "foo" field onto the
     "user" field, materializes the resulting field with
     expression_test_util.calculate_value_slowly, and checks its parent_index
     and values. The expected numbers come from the
     prensor_test_util.create_big_prensor fixture.
     """
     expr = create_expression.create_expression_from_prensor(
         prensor_test_util.create_big_prensor())
     new_root, new_path = broadcast.broadcast_anonymous(
         expr, path.Path(["foo"]), "user")
     new_field = new_root.get_descendant_or_error(new_path)
     leaf_node = expression_test_util.calculate_value_slowly(new_field)
     # Expected values are specific to the create_big_prensor fixture data.
     self.assertAllEqual(leaf_node.parent_index, [0, 1, 2, 3])
     self.assertAllEqual(leaf_node.values, [9, 8, 8, 7])
def promote_and_broadcast_anonymous(root, origin, new_parent):
    """Relocates the field at `origin` so that its parent becomes `new_parent`.

    The field is first promoted repeatedly (via promote.promote_anonymous)
    until its parent is the least common ancestor of `origin` and
    `new_parent`. It is then broadcast repeatedly (via
    broadcast.broadcast_anonymous) along the steps of `new_parent` until its
    parent equals `new_parent`.

    Args:
      root: the expression to transform.
      origin: path of the field to relocate.
      new_parent: path that should end up as the relocated field's parent.

    Returns:
      A pair of the resulting expression and the path of the relocated field
      within it.
    """
    # NOTE(review): if `origin` is an ancestor of (or equal to) `new_parent`,
    # the common ancestor is presumably `origin` itself and the promote loop
    # below cannot reach it -- callers are assumed to avoid that case.
    shared_ancestor = origin.get_least_common_ancestor(new_parent)
    current_expr = root
    current_path = origin

    # Phase 1: climb until the field hangs directly off the shared ancestor.
    while current_path.get_parent() != shared_ancestor:
        current_expr, current_path = promote.promote_anonymous(
            current_expr, current_path)

    # Phase 2: walk down new_parent one step at a time. The index used is
    # len(current_path) - 1, presumably the depth of the field's current
    # parent, so it selects the next step of new_parent below that parent.
    while current_path.get_parent() != new_parent:
        depth = len(current_path) - 1
        next_step = new_parent.field_list[depth]
        current_expr, current_path = broadcast.broadcast_anonymous(
            current_expr, current_path, next_step)

    return current_expr, current_path
Exemplo n.º 3
0
  def test_broadcast_anonymous(self):
    """Checks the metadata and sources of an anonymously broadcast field.

    Broadcasts "foo" into "user" on the big prensor, then verifies the new
    field's repeatedness, dtype, leaf status, calculation equality, calculated
    value dtype, known field names, and source expressions.
    """
    big_prensor = prensor_test_util.create_big_prensor()
    expr = create_expression.create_expression_from_prensor(big_prensor)
    new_root, new_path = broadcast.broadcast_anonymous(
        expr, path.Path(["foo"]), "user")
    broadcast_field = new_root.get_descendant_or_error(new_path)

    # Static properties of the broadcast field.
    self.assertFalse(broadcast_field.is_repeated)
    self.assertEqual(broadcast_field.type, tf.int32)
    self.assertTrue(broadcast_field.is_leaf)

    # The field's calculation equals itself but not the original root's.
    self.assertTrue(broadcast_field.calculation_equal(broadcast_field))
    self.assertFalse(broadcast_field.calculation_equal(expr))

    # The calculated values keep the int32 dtype; a leaf has no known children.
    calculated = expression_test_util.calculate_value_slowly(broadcast_field)
    self.assertEqual(calculated.values.dtype, tf.int32)
    self.assertEqual(broadcast_field.known_field_names(), frozenset())

    # The broadcast reads from the original "foo" and "user" children, in
    # that order.
    sources = broadcast_field.get_source_expressions()
    self.assertLen(sources, 2)
    foo_source, user_source = sources[0], sources[1]
    self.assertIs(expr.get_child("foo"), foo_source)
    self.assertIs(expr.get_child("user"), user_source)