Exemplo n.º 1
0
    def test_promote_and_calculate_substructure(self):
        """Tests promoting substructure on a tree with depth of 4.

        Promotes event.doc.nested_child and checks that only the promoted
        node's parent index changes; its `bar` and `keep_me` children keep
        their original parent indices.
        """
        original = create_expression.create_expression_from_prensor(
            prensor_test_util.create_four_layer_prensor())
        promoted_root, promoted_path = promote.promote_anonymous(
            original, path.Path(["event", "doc", "nested_child"]))
        promoted_expr = promoted_root.get_descendant_or_error(promoted_path)
        bar_child = promoted_root.get_descendant_or_error(
            promoted_path.get_child("bar"))
        keep_me_child = promoted_root.get_descendant_or_error(
            promoted_path.get_child("keep_me"))

        # Promotion rewires the promoted node itself to new parents.
        promoted_value = expression_test_util.calculate_value_slowly(
            promoted_expr)
        self.assertAllEqual(promoted_value.parent_index, [0, 1, 1, 1])
        self.assertTrue(promoted_value.is_repeated)

        # The repeated `bar` leaf still points at the same parents.
        bar_value = expression_test_util.calculate_value_slowly(bar_child)
        self.assertAllEqual(bar_value.parent_index, [0, 1, 1, 2])
        self.assertAllEqual(bar_value.values, [b"a", b"b", b"c", b"d"])
        self.assertTrue(bar_value.is_repeated)

        # The non-repeated `keep_me` leaf still points at the same parents.
        keep_me_value = expression_test_util.calculate_value_slowly(
            keep_me_child)
        self.assertAllEqual(keep_me_value.parent_index, [0, 1])
        self.assertAllEqual(keep_me_value.values, [False, True])
        self.assertFalse(keep_me_value.is_repeated)
Exemplo n.º 2
0
    def test_reroot_and_create_proto_index(self):
        """Reroots at `doc`, adds a `proto_index` field, and checks both fields.

        Presence and static schema are asserted before any value is computed,
        so a missing child fails at a clear assertion rather than inside
        calculate_value_slowly.
        """
        expr = create_expression.create_expression_from_prensor(
            prensor_test_util.create_big_prensor()).reroot(
                "doc").create_proto_index("proto_index")
        proto_index = expr.get_child("proto_index")
        new_field = expr.get_child("bar")

        # Static schema of the rerooted `bar` leaf.
        self.assertIsNotNone(new_field)
        self.assertTrue(new_field.is_repeated)
        self.assertEqual(new_field.type, tf.string)
        self.assertTrue(new_field.is_leaf)
        self.assertEqual(new_field.known_field_names(), frozenset())

        # Static schema of the generated proto index leaf.
        self.assertIsNotNone(proto_index)
        self.assertFalse(proto_index.is_repeated)
        self.assertEqual(proto_index.type, tf.int64)
        self.assertTrue(proto_index.is_leaf)
        self.assertEqual(proto_index.known_field_names(), frozenset())

        # Only compute values once both children are known to exist.
        leaf_node = expression_test_util.calculate_value_slowly(new_field)
        proto_index_node = expression_test_util.calculate_value_slowly(
            proto_index)

        self.assertEqual(leaf_node.values.dtype, tf.string)
        self.assertEqual(proto_index_node.values.dtype, tf.int64)

        self.assertAllEqual([b"a", b"b", b"c", b"d"], leaf_node.values)
        self.assertAllEqual([0, 1, 1, 2], leaf_node.parent_index)
        self.assertAllEqual([0, 1, 1], proto_index_node.values)
        self.assertAllEqual([0, 1, 2], proto_index_node.parent_index)
Exemplo n.º 3
0
    def test_promote_substructure(self):
        """Tests promote.promote(...) of substructure.

        Promotes event.doc to a root-level `new_field` and checks the schema
        and computed dtypes of it and its `bar` / `keep_me` children.
        """
        source = create_expression.create_expression_from_prensor(
            prensor_test_util.create_deep_prensor())
        promoted = promote.promote(source, path.Path(["event", "doc"]),
                                   "new_field")

        field = promoted.get_child_or_error("new_field")
        self.assertIsNotNone(field)
        self.assertTrue(field.is_repeated)
        self.assertEqual(field.known_field_names(),
                         frozenset(["bar", "keep_me"]))

        bar = field.get_child_or_error("bar")
        self.assertIsNotNone(bar)
        self.assertTrue(bar.is_repeated)
        self.assertEqual(bar.type, tf.string)
        self.assertTrue(bar.is_leaf)

        keep_me = field.get_child_or_error("keep_me")
        self.assertIsNotNone(keep_me)
        self.assertFalse(keep_me.is_repeated)
        self.assertEqual(keep_me.type, tf.bool)
        self.assertTrue(keep_me.is_leaf)

        calc = expression_test_util.calculate_value_slowly

        field_value = calc(field)
        self.assertEqual(field_value.size, 3)
        self.assertTrue(field_value.is_repeated)

        bar_value = calc(bar)
        self.assertEqual(bar_value.values.dtype, tf.string)

        keep_me_value = calc(keep_me)
        self.assertEqual(keep_me_value.values.dtype, tf.bool)
Exemplo n.º 4
0
  def test_broadcast_substructure(self):
    """Tests broadcast of a submessage.

    The result of broadcasting `user` into `doc` looks like:
    {
      foo: 9,
      foorepeated: [9],
      doc: [{bar:["a"], keep_me:False, new_user: [{friends:["a"]}]}],
      user: [{friends:["a"]}]
    },
    {
      foo: 8,
      foorepeated: [8, 7],
      doc: [
        {
          bar: ["b","c"],
          keep_me: True,
          new_user: [{friends:["b", "c"]},{friends:["d"]}]
        },
        {
          bar: ["d"],
          new_user: [{friends:["b", "c"]},{friends:["d"]}]
        }
      ],
      user: [{friends:["b", "c"]},{friends:["d"]}],
    },
    {
      foo: 7,
      foorepeated: [6],
      user: [{friends:["e"]}]
    }
    """
    expr = create_expression.create_expression_from_prensor(
        prensor_test_util.create_big_prensor())
    new_root = broadcast.broadcast(expr, path.Path(["user"]), "doc", "new_user")
    new_user = new_root.get_child("doc").get_child("new_user")
    self.assertIsNotNone(new_user)
    self.assertTrue(new_user.is_repeated)
    self.assertIsNone(new_user.type)
    self.assertFalse(new_user.is_leaf)

    new_user_node = expression_test_util.calculate_value_slowly(new_user)
    self.assertAllEqual(new_user_node.parent_index, [0, 1, 1, 2, 2])
    self.assertAllEqual(new_user_node.index_to_value, [0, 1, 2, 1, 2])

    new_friends = new_user.get_child("friends")
    self.assertIsNotNone(new_friends)
    self.assertTrue(new_friends.is_repeated)
    self.assertEqual(new_friends.type, tf.string)
    self.assertTrue(new_friends.is_leaf)

    new_friends_node = expression_test_util.calculate_value_slowly(new_friends)
    self.assertEqual(new_friends_node.values.dtype, tf.string)
    # tf.string values are compared against bytes literals, matching the other
    # string-valued checks in these tests; str vs bytes need not compare equal.
    self.assertAllEqual(new_friends_node.values,
                        [b"a", b"b", b"c", b"d", b"b", b"c", b"d"])
    self.assertAllEqual(new_friends_node.parent_index, [0, 1, 1, 2, 3, 3, 4])
Exemplo n.º 5
0
    def test_map_many_values(self):
        """Adds the `foo` and `bar` leaves element-wise into `new_field`."""
        with self.session(use_gpu=False) as sess:
            nodes_by_path = {
                path.Path([]):
                    prensor_test_util.create_root_node(3),
                path.Path(["foo"]):
                    prensor_test_util.create_optional_leaf_node(
                        [0, 2, 3], [9, 8, 7]),
                path.Path(["bar"]):
                    prensor_test_util.create_optional_leaf_node(
                        [0, 2, 3], [10, 20, 30]),
            }
            source = create_expression.create_expression_from_prensor(
                prensor.create_prensor_from_descendant_nodes(nodes_by_path))

            summed_root, summed_path = map_values.map_many_values(
                source, path.Path([]), ["foo", "bar"], lambda x, y: x + y,
                tf.int64, "new_field")

            summed = expression_test_util.calculate_value_slowly(
                summed_root.get_descendant_or_error(summed_path))
            fetched_parent_index, fetched_values = sess.run(
                [summed.parent_index, summed.values])

            self.assertAllEqual(fetched_parent_index, [0, 2, 3])
            self.assertAllEqual(fetched_values, [19, 28, 37])
Exemplo n.º 6
0
 def test_create_expression_from_proto_and_calculate_event_id_value(self):
     """Tests calculating the event.event_id leaf of a proto expression.

     Checks both the parent index and the (bytes) values of the leaf.
     """
     expr = proto_test_util._get_expression_from_session_empty_user_info()
     event_id_value = expression_test_util.calculate_value_slowly(
         expr.get_descendant_or_error(path.Path(["event", "event_id"])))
     self.assertAllEqual(event_id_value.parent_index, [0, 1, 2, 4])
     self.assertAllEqual(event_id_value.values, [b"A", b"B", b"C", b"D"])
Exemplo n.º 7
0
 def test_create_expression_from_proto_and_calculate_root_value(self):
     """Tests calculating the root node of a proto expression.

     The computed root size is expected to be 2.
     """
     expr = proto_test_util._get_expression_from_session_empty_user_info()
     root_value = expression_test_util.calculate_value_slowly(expr)
     # NOTE(review): the original author reported this check failing under
     # tf.eager, possibly because the root size is a scalar; the cause was
     # never confirmed.
     self.assertEqual(self.evaluate(root_value.size), 2)
Exemplo n.º 8
0
    def test_create_proto_index_directly_reroot_at_action(self):
        """Creates a proto index directly on an expression rerooted at event.action.

        This reroots with a path of depth > 1, whereas the other proto index
        cases reroot at depth == 1.
        """
        first_session = """
        event {
          action {}
          action {}
        }
        event {}
        event { action {} }
        """
        empty_session = ""
        third_session = """
        event {}
        event {
          action {}
          action {}
        }
        event {  }
        """
        source = proto_test_util.text_to_expression(
            [first_session, empty_session, third_session], test_pb2.Session)
        index_name = "proto_index_directly_reroot_at_action"
        rerooted = source.reroot("event.action")
        index_expr = rerooted.create_proto_index(
            index_name).get_child_or_error(index_name)

        self.assertFalse(index_expr.is_repeated)
        index_value = expression_test_util.calculate_value_slowly(index_expr)
        # Five actions in total; the values appear to be the index of the
        # session each action came from (only sessions 0 and 2 have actions).
        self.assertAllEqual(index_value.parent_index, [0, 1, 2, 3, 4])
        self.assertAllEqual(index_value.values, [0, 0, 0, 2, 2])
Exemplo n.º 9
0
    def test_create_expression_from_proto_with_any(self):
        """Test an any field.

        Resolves the AllSimple payload inside `my_any` and checks its schema,
        its calculation identity, and that `my_any` is its only source.
        """
        root = _get_expression_with_any()
        any_field = root.get_child_or_error("my_any")
        unpacked = root.get_descendant_or_error(
            path.Path([
                "my_any", "(type.googleapis.com/struct2tensor.test.AllSimple)"
            ]))

        self.assertFalse(unpacked.is_repeated)
        self.assertIsNone(unpacked.type)
        self.assertFalse(unpacked.is_leaf)
        self.assertFalse(unpacked.calculation_is_identity())
        self.assertTrue(unpacked.calculation_equal(unpacked))
        self.assertFalse(unpacked.calculation_equal(root))

        unpacked_value = expression_test_util.calculate_value_slowly(unpacked)
        self.assertEqual(unpacked_value.parent_index.dtype, tf.int64)

        expected_fields = frozenset({
            "optional_string", "optional_uint64", "repeated_uint64",
            "repeated_int32", "repeated_string", "optional_int32",
            "optional_float", "repeated_int64", "optional_uint32",
            "repeated_float", "repeated_uint32", "optional_double",
            "optional_int64", "repeated_double"
        })
        self.assertEqual(unpacked.known_field_names(), expected_fields)

        source_exprs = unpacked.get_source_expressions()
        self.assertLen(source_exprs, 1)
        self.assertIs(any_field, source_exprs[0])
Exemplo n.º 10
0
 def test_create_expression_from_proto_and_calculate_root_value(self):
     """Tests calculating the root node of a proto expression in a session.

     The root size, fetched with sess.run, is expected to be 2.
     """
     with self.session(use_gpu=False) as sess:
         expr = proto_test_util._get_expression_from_session_empty_user_info()
         root_value = expression_test_util.calculate_value_slowly(expr)
         size = sess.run(root_value.size)
         self.assertEqual(size, 2)
Exemplo n.º 11
0
 def test_size_anonymous(self):
     """Computes the size of doc.bar into an anonymous field under doc."""
     source = create_expression.create_expression_from_prensor(
         prensor_test_util.create_big_prensor())
     sized_root, sized_path = size.size_anonymous(
         source, path.Path(["doc", "bar"]))
     sized_value = expression_test_util.calculate_value_slowly(
         sized_root.get_descendant_or_error(sized_path))
     self.assertAllEqual(sized_value.parent_index, [0, 1, 2])
     self.assertAllEqual(sized_value.values, [1, 2, 1])
Exemplo n.º 12
0
 def test_create_has_field(self):
     """Adds doc.result recording whether each doc has `keep_me`."""
     source = create_expression.create_expression_from_prensor(
         prensor_test_util.create_big_prensor())
     result_expr = source.create_has_field(
         "doc.keep_me", "result").get_descendant_or_error(
             path.Path(["doc", "result"]))
     result_value = expression_test_util.calculate_value_slowly(result_expr)
     self.assertAllEqual(result_value.parent_index, [0, 1, 2])
     self.assertAllEqual(result_value.values, [True, True, False])
Exemplo n.º 13
0
 def test_get_positional_index_calculate(self):
   """Computes the positional index of user.friends values into a new field."""
   source = create_expression.create_expression_from_prensor(
       prensor_test_util.create_nested_prensor())
   indexed_root, indexed_path = index.get_positional_index(
       source, path.Path(["user", "friends"]), path.get_anonymous_field())
   indexed_value = expression_test_util.calculate_value_slowly(
       indexed_root.get_descendant_or_error(indexed_path))
   self.assertAllEqual(indexed_value.parent_index, [0, 1, 1, 2, 3])
   self.assertAllEqual(indexed_value.values, [0, 0, 1, 0, 0])
Exemplo n.º 14
0
 def test_create_expression_from_proto_and_calculate_event_value(self):
     """Tests calculating the parent index of the `event` child in a session."""
     with self.session(use_gpu=False) as sess:
         expr = proto_test_util._get_expression_from_session_empty_user_info()
         event_value = expression_test_util.calculate_value_slowly(
             expr.get_child_or_error("event"))
         parent_index = sess.run(event_value.parent_index)
         self.assertAllEqual(parent_index, [0, 0, 0, 1, 1])
Exemplo n.º 15
0
 def test_size_missing_value(self):
     """Computes the size of doc.keep_me, which is 0 where the field is absent."""
     source = create_expression.create_expression_from_prensor(
         prensor_test_util.create_big_prensor())
     result_expr = size.size(
         source, path.Path(["doc", "keep_me"]),
         "result").get_descendant_or_error(path.Path(["doc", "result"]))
     result_value = expression_test_util.calculate_value_slowly(result_expr)
     self.assertAllEqual(result_value.parent_index, [0, 1, 2])
     self.assertAllEqual(result_value.values, [1, 1, 0])
Exemplo n.º 16
0
 def test_promote_and_calculate(self):
     """Tests promoting a leaf on a nested tree.

     Promotes user.friends and checks the promoted leaf's computed values.
     """
     source = create_expression.create_expression_from_prensor(
         prensor_test_util.create_nested_prensor())
     promoted_root, promoted_path = promote.promote_anonymous(
         source, path.Path(["user", "friends"]))
     promoted_value = expression_test_util.calculate_value_slowly(
         promoted_root.get_descendant_or_error(promoted_path))
     self.assertAllEqual(promoted_value.parent_index, [0, 1, 1, 1, 2])
     self.assertAllEqual(promoted_value.values,
                         [b"a", b"b", b"c", b"d", b"e"])
Exemplo n.º 17
0
 def test_create_expression_from_proto_and_calculate_event_value(
     self, use_string_view):
   """Tests calculating the parent index of the `event` child.

   Args:
     use_string_view: forwarded to self._get_calculate_options(); when true,
       self._check_string_view() is also run after the calculation.
   """
   expr = proto_test_util._get_expression_from_session_empty_user_info()
   event_value = expression_test_util.calculate_value_slowly(
       expr.get_child_or_error("event"),
       options=self._get_calculate_options(use_string_view))
   self.assertAllEqual(event_value.parent_index, [0, 0, 0, 1, 1])
   if use_string_view:
     self._check_string_view()
Exemplo n.º 18
0
 def test_create_expression_from_proto_with_any(self):
     """Test an any field.

     Checks the parent index of the AllSimple payload resolved from `my_any`.
     """
     expr = _get_expression_with_any()
     simple_expr = expr.get_descendant_or_error(
         path.Path([
             "my_any", "(type.googleapis.com/struct2tensor.test.AllSimple)"
         ]))
     # This is the node's parent_index tensor, not the node itself.
     parent_index = expression_test_util.calculate_value_slowly(
         simple_expr).parent_index
     self.assertAllEqual(parent_index, [0, 2])
Exemplo n.º 19
0
 def test_broadcast_and_calculate(self):
     """Tests broadcasting the `foo` leaf into `user` and calculating it.

     Each `user` receives the `foo` value of its enclosing root element.
     """
     expr = create_expression.create_expression_from_prensor(
         prensor_test_util.create_big_prensor())
     new_root, new_path = broadcast.broadcast_anonymous(
         expr, path.Path(["foo"]), "user")
     new_field = new_root.get_descendant_or_error(new_path)
     leaf_node = expression_test_util.calculate_value_slowly(new_field)
     self.assertAllEqual(leaf_node.parent_index, [0, 1, 2, 3])
     self.assertAllEqual(leaf_node.values, [9, 8, 8, 7])
Exemplo n.º 20
0
 def test_create_size_field(self):
   """Adds doc.result holding the size of doc.bar and fetches it in a session."""
   with self.session(use_gpu=False) as sess:
     source = create_expression.create_expression_from_prensor(
         prensor_test_util.create_big_prensor())
     sized_root = source.create_size_field("doc.bar", "result")
     sized_expr = sized_root.get_descendant_or_error(
         path.Path(["doc", "result"]))
     sized_value = expression_test_util.calculate_value_slowly(sized_expr)
     fetched_parent_index, fetched_values = sess.run(
         [sized_value.parent_index, sized_value.values])
     self.assertAllEqual(fetched_parent_index, [0, 1, 2])
     self.assertAllEqual(fetched_values, [1, 2, 1])
Exemplo n.º 21
0
 def test_create_expression_from_proto_with_any_missing_message(self):
     """Test an any field with a message that is absent.

     No `my_any` holds a SpecialUserInfo, so the resolved node is empty.
     """
     expr = _get_expression_with_any()
     simple_expr = expr.get_descendant_or_error(
         path.Path([
             "my_any",
             "(type.googleapis.com/struct2tensor.test.SpecialUserInfo)"
         ]))
     # This is the node's parent_index tensor, not the node itself.
     parent_index = expression_test_util.calculate_value_slowly(
         simple_expr).parent_index
     self.assertAllEqual(parent_index, [])
Exemplo n.º 22
0
 def _test_runner(options):
     """Doubles `foo` with map_sparse_tensor and evaluates the result.

     Uses `self` from the enclosing scope. It only checks that evaluation
     succeeds with the given calculate `options`; no values are asserted.
     """
     source = create_expression.create_expression_from_prensor(
         prensor_test_util.create_simple_prensor())
     doubled_root = map_prensor.map_sparse_tensor(
         source, path.Path([]), [path.Path(["foo"])], lambda x: x * 2, True,
         tf.int32, "foo_doubled")
     doubled_expr = doubled_root.get_descendant_or_error(
         path.Path(["foo_doubled"]))
     doubled = expression_test_util.calculate_value_slowly(
         doubled_expr, options=options)
     for tensor in (doubled.parent_index, doubled.values):
         self.evaluate(tensor)
Exemplo n.º 23
0
    def test_map_field_values_test(self):
        """Doubles `foo` with map_field_values into a new `foo_doubled` field."""
        source = create_expression.create_expression_from_prensor(
            prensor_test_util.create_simple_prensor())
        doubled_root = source.map_field_values(
            "foo", lambda x: x * 2, tf.int64, "foo_doubled")
        doubled_expr = doubled_root.get_descendant_or_error(
            path.Path(["foo_doubled"]))

        doubled_value = expression_test_util.calculate_value_slowly(
            doubled_expr)
        self.assertAllEqual(doubled_value.parent_index, [0, 1, 2])
        self.assertAllEqual(doubled_value.values, [18, 16, 14])
Exemplo n.º 24
0
 def test_promote(self):
     """Promotes the user.friends leaf to a root-level `new_field`."""
     source = create_expression.create_expression_from_prensor(
         prensor_test_util.create_nested_prensor())
     promoted = source.promote("user.friends", "new_field")
     field = promoted.get_child_or_error("new_field")
     self.assertIsNotNone(field)
     self.assertTrue(field.is_repeated)
     self.assertEqual(field.type, tf.string)
     self.assertTrue(field.is_leaf)
     field_value = expression_test_util.calculate_value_slowly(field)
     self.assertEqual(field_value.values.dtype, tf.string)
     self.assertEqual(field.known_field_names(), frozenset())
Exemplo n.º 25
0
    def test_map_values_anonymous(self):
        """Doubles `foo` with map_values_anonymous into an anonymous field."""
        source = create_expression.create_expression_from_prensor(
            prensor_test_util.create_simple_prensor())
        mapped_root, mapped_path = map_values.map_values_anonymous(
            source, path.Path(["foo"]), lambda x: x * 2, tf.int64)

        mapped_value = expression_test_util.calculate_value_slowly(
            mapped_root.get_descendant_or_error(mapped_path))
        self.assertAllEqual(mapped_value.parent_index, [0, 1, 2])
        self.assertAllEqual(mapped_value.values, [18, 16, 14])
Exemplo n.º 26
0
    def test_promote_and_calculate_leaf_then_substructure(self):
        """Tests promoting of leaf and then a substructure.

        First promotes event.doc.nested_child.bar, then promotes event.doc,
        and checks which parent indices are affected by each step.
        """
        base = create_expression.create_expression_from_prensor(
            prensor_test_util.create_four_layer_prensor())
        root, promoted_bar_path = promote.promote_anonymous(
            base, path.Path(["event", "doc", "nested_child", "bar"]))
        root, doc_path = promote.promote_anonymous(
            root, path.Path(["event", "doc"]))

        doc_expr = root.get_descendant_or_error(doc_path)
        # Re-anchor the first promotion's output under the promoted doc;
        # presumably suffix(2) drops the leading event.doc steps.
        promoted_bar_expr = root.get_descendant_or_error(
            doc_path.concat(promoted_bar_path.suffix(2)))
        nested_bar_expr = root.get_descendant_or_error(
            doc_path.concat(path.Path(["nested_child", "bar"])))
        nested_keep_me_expr = root.get_descendant_or_error(
            doc_path.concat(path.Path(["nested_child", "keep_me"])))

        calc = expression_test_util.calculate_value_slowly

        doc_value = calc(doc_expr)
        self.assertAllEqual(doc_value.parent_index, [0, 1, 1])
        self.assertTrue(doc_value.is_repeated)

        # The first promotion rewired the promoted bar; the second promotion
        # must leave that rewiring intact.
        promoted_bar_value = calc(promoted_bar_expr)
        self.assertAllEqual(promoted_bar_value.parent_index, [0, 1, 1, 1])
        self.assertAllEqual(promoted_bar_value.values,
                            [b"a", b"b", b"c", b"d"])
        self.assertTrue(promoted_bar_value.is_repeated)

        # The original nested bar keeps its parent index.
        nested_bar_value = calc(nested_bar_expr)
        self.assertAllEqual(nested_bar_value.parent_index, [0, 1, 1, 2])
        self.assertAllEqual(nested_bar_value.values, [b"a", b"b", b"c", b"d"])
        self.assertTrue(nested_bar_value.is_repeated)

        # keep_me keeps its parent index.
        nested_keep_me_value = calc(nested_keep_me_expr)
        self.assertAllEqual(nested_keep_me_value.parent_index, [0, 1])
        self.assertAllEqual(nested_keep_me_value.values, [False, True])
        self.assertFalse(nested_keep_me_value.is_repeated)
Exemplo n.º 27
0
 def test_broadcast(self):
     """Tests broadcast.broadcast(...), and indirectly tests set_path.

     Broadcasts `foo` into `user` as `new_field` and checks its schema and
     computed dtype.
     """
     source = create_expression.create_expression_from_prensor(
         prensor_test_util.create_big_prensor())
     broadcasted = source.broadcast("foo", "user", "new_field")
     field = broadcasted.get_child("user").get_child("new_field")
     self.assertIsNotNone(field)
     self.assertFalse(field.is_repeated)
     self.assertEqual(field.type, tf.int32)
     self.assertTrue(field.is_leaf)
     field_value = expression_test_util.calculate_value_slowly(field)
     self.assertEqual(field_value.values.dtype, tf.int32)
     self.assertEqual(field.known_field_names(), frozenset())
Exemplo n.º 28
0
 def test_user_info_with_extension(self):
     """Checks schema and calculation of the MyExternalExtension.ext child."""
     root = _get_user_info_with_extension()
     extension = root.get_child_or_error(
         "(struct2tensor.test.MyExternalExtension.ext)")
     self.assertFalse(extension.is_repeated)
     self.assertIsNone(extension.type)
     self.assertFalse(extension.is_leaf)
     self.assertFalse(extension.calculation_is_identity())
     self.assertTrue(extension.calculation_equal(extension))
     self.assertFalse(extension.calculation_equal(root))
     extension_value = expression_test_util.calculate_value_slowly(extension)
     self.assertEqual(extension_value.parent_index.dtype, tf.int64)
     self.assertEqual(extension.known_field_names(), frozenset({"special"}))
Exemplo n.º 29
0
    def test_promote_substructure_then_leaf(self):
        """Tests expr.promote(...) of substructure and then a leaf.

        Promotes event.doc to `new_field`, then promotes new_field.bar to a
        root-level `new_bar`, and checks the schema and computed dtypes of
        new_bar, new_field.bar and new_field.keep_me.
        """
        expr = create_expression.create_expression_from_prensor(
            prensor_test_util.create_deep_prensor())
        new_root = (expr.promote(path.Path(["event", "doc"]),
                                 "new_field").promote(
                                     path.Path(["new_field", "bar"]),
                                     "new_bar"))

        new_bar = new_root.get_child_or_error("new_bar")
        self.assertIsNotNone(new_bar)
        self.assertTrue(new_bar.is_repeated)
        self.assertEqual(new_bar.type, tf.string)
        self.assertTrue(new_bar.is_leaf)

        # Check new_field.bar itself; previously these assertions re-checked
        # new_bar by mistake, leaving new_field_bar unverified.
        new_field_bar = new_root.get_descendant_or_error(
            path.Path(["new_field", "bar"]))
        self.assertIsNotNone(new_field_bar)
        self.assertTrue(new_field_bar.is_repeated)
        self.assertEqual(new_field_bar.type, tf.string)
        self.assertTrue(new_field_bar.is_leaf)

        new_field_keep_me = new_root.get_descendant_or_error(
            path.Path(["new_field", "keep_me"]))
        self.assertIsNotNone(new_field_keep_me)
        self.assertFalse(new_field_keep_me.is_repeated)
        self.assertEqual(new_field_keep_me.type, tf.bool)
        self.assertTrue(new_field_keep_me.is_leaf)

        bar_node = expression_test_util.calculate_value_slowly(new_bar)
        self.assertEqual(bar_node.values.dtype, tf.string)

        new_field_bar_node = expression_test_util.calculate_value_slowly(
            new_field_bar)
        self.assertEqual(new_field_bar_node.values.dtype, tf.string)

        new_field_keep_me_node = expression_test_util.calculate_value_slowly(
            new_field_keep_me)
        self.assertEqual(new_field_keep_me_node.values.dtype, tf.bool)
Exemplo n.º 30
0
    def test_map_sparse_tensor(self):
        """Doubles `foo` with map_sparse_tensor into a root-level `foo_doubled`."""
        source = create_expression.create_expression_from_prensor(
            prensor_test_util.create_simple_prensor())
        doubled_root = map_prensor.map_sparse_tensor(
            source, path.Path([]), [path.Path(["foo"])], lambda x: x * 2,
            False, tf.int32, "foo_doubled")
        doubled_expr = doubled_root.get_descendant_or_error(
            path.Path(["foo_doubled"]))

        doubled_value = expression_test_util.calculate_value_slowly(
            doubled_expr)
        self.assertAllEqual(doubled_value.parent_index, [0, 1, 2])
        self.assertAllEqual(doubled_value.values, [18, 16, 14])