Beispiel #1
0
    def build_feed_dict_batch_test(self, root):
        """Build the placeholder feed dict for the tree stored at ``root[0]``.

        Every per-node list gets one filler entry in front (``False``, ``0``
        or ``[0, 0]``) and the child indices are passed through
        ``helper.add_one`` -- presumably so that real nodes are addressed
        from position 1 and a missing child (-1) ends up as 0; confirm
        against ``helper.add_one``.

        Args:
            root: Indexable container; only its first element (the tree
                root) is read.

        Returns:
            dict: Maps the model's input placeholders (``self.*_array``) to
            plain Python lists.
        """
        tree = root[0]

        nodes = []
        tree_util.depth_first_traverse(
            tree, nodes, lambda node, acc: acc.append(node))
        index_of = helper.reverse_dict(nodes)

        def shifted_children(attr):
            # -1 marks "no child" before the shift applied by add_one.
            raw = []
            for node in nodes:
                child = getattr(node, attr)
                raw.append(index_of[child] if child is not None else -1)
            return [0] + helper.add_one(raw)

        # (size - 1) would be the root's position without the filler entry
        # (assuming the traversal emits the root last -- TODO confirm) and
        # +1 accounts for the filler, so the two adjustments cancel.
        root_position = tree_util.size_of_tree(tree)

        leaf_flags = [False]
        leaf_flags.extend(node.is_leaf for node in nodes)

        word_ids = [0]
        word_ids.extend(
            self.data.word_embed_util.get_idx(node.value) for node in nodes)

        left_children = shifted_children("left_child")
        right_children = shifted_children("right_child")

        labels = [[0, 0]]
        labels.extend(node.label for node in nodes)

        return {
            self.root_array: [root_position],
            self.is_leaf_array: leaf_flags,
            self.word_index_array: word_ids,
            self.left_child_array: left_children,
            self.right_child_array: right_children,
            self.label_array: labels,
        }
Beispiel #2
0
    def build_feed_dict_batch(self, batch):
        """Build one flattened feed dict covering every tree in ``batch``.

        The nodes of all trees are concatenated, tree after tree, into flat
        per-node lists. ``root_array`` holds, for each tree, the cumulative
        node count minus one, i.e. a position in the flattened lists.

        Child indices are expressed in that same flattened index space: the
        tree-local index of a child is shifted by the number of nodes that
        belong to the preceding trees. Without that shift the children of
        every tree after the first would point into the first tree's nodes
        while the root positions are batch-global. A missing child is
        encoded as -1 and is never shifted.

        Args:
            batch: Iterable of tree roots.

        Returns:
            dict: Maps the model's input placeholders (``self.*_array``) to
            flat Python lists.
        """
        node_list_list = []
        node_to_index_list = []
        offsets = []  # flattened position of each tree's first node
        root_array = []
        next_offset = 0
        last_root = -1
        for root in batch:
            node_list = []
            tree_util.depth_first_traverse(root, node_list, lambda node, acc: acc.append(node))
            node_list_list.append(node_list)
            # Assumes helper.reverse_dict maps each node to its position in
            # node_list -- TODO confirm against helper.
            node_to_index_list.append(helper.reverse_dict(node_list))
            offsets.append(next_offset)
            next_offset += len(node_list)
            last_root += tree_util.size_of_tree(root)
            root_array.append(last_root)

        def flat_child_indices(attr):
            # Tree-local child index -> index into the flattened batch.
            flat = []
            for node_list, node_to_index, offset in zip(node_list_list, node_to_index_list, offsets):
                for node in node_list:
                    child = getattr(node, attr)
                    flat.append(offset + node_to_index[child] if child is not None else -1)
            return flat

        feed_dict = {
            self.root_array: root_array,
            self.is_leaf_array: helper.flatten([[node.is_leaf for node in node_list] for node_list in node_list_list]),
            self.word_index_array: helper.flatten([[self.data.word_embed_util.get_idx(node.value) for node in node_list]
                                                   for node_list in node_list_list]),
            # TODO: maybe wrap this in a different way
            self.left_child_array: flat_child_indices("left_child"),
            self.right_child_array: flat_child_indices("right_child"),
            self.label_array: helper.flatten([[node.label for node in node_list]
                                              for node_list in node_list_list])
        }

        return feed_dict
Beispiel #3
0
    def build_feed_dict(self, roots, sort=True, train=False):
        """Pack trees into bins and build the leaf-sequence feed dict.

        The trees are (optionally) ordered with ``helper.sort_by`` using
        their sizes, then grouped by ``helper.greedy_bin_packing`` with the
        largest tree size as the bin capacity. Each bin becomes one row of
        the padded arrays; every row starts with a filler entry, so leaves
        are counted from position 1 within a bin.

        Args:
            roots: Sequence of tree roots.
            sort: When true, reorder the trees by size before packing.
            train: When true, feed ``FLAGS.dropout_prob`` as the dropout
                rate; otherwise feed 0.

        Returns:
            tuple: ``(feed_dict, permutation)`` where ``permutation`` is the
            second value returned by ``helper.greedy_bin_packing``.

        Raises:
            ValueError: From ``np.max`` when ``roots`` is empty.
        """
        roots_size = [tree_util.size_of_tree(root) for root in roots]
        if sort:
            roots = helper.sort_by(roots, roots_size)
            # Sizes must follow the new order of the trees.
            roots_size = [tree_util.size_of_tree(root) for root in roots]
        roots_list, permutation = helper.greedy_bin_packing(roots, roots_size, np.max(roots_size))

        node_list_list = []
        root_indices = []
        internal_nodes_array = []
        lstm_prev_list = []
        for bin_idx, bin_roots in enumerate(roots_list):
            node_list = []
            root_index = 0
            leaf_index = 0
            lstm_prev = [0]  # filler entry for position 0
            lstm_prev_count = 0
            for root in bin_roots:
                tree_util.depth_first_traverse(root, node_list, lambda node, acc: acc.append(node))
                leaf_count = tree_util.leafs_in_tree(root)

                # [bin, position of this tree's last leaf within the bin]
                root_index += leaf_count
                root_indices.append([bin_idx, root_index])

                for pos in range(leaf_count):
                    leaf_index += 1
                    internal_nodes_array.append([bin_idx, leaf_index])
                    # The first leaf of every tree points at the filler
                    # entry (0); every other leaf gets the running count of
                    # leaves already placed in this bin.
                    lstm_prev.append(0 if pos == 0 else lstm_prev_count)
                    lstm_prev_count += 1

            node_list_list.append(node_list)
            lstm_prev_list.append(lstm_prev)

        feed_dict = {
            self.dropout_rate: FLAGS.dropout_prob if train else 0,
            self.lstm_prev_array: helper.lists_pad(lstm_prev_list, 0),
            self.leaf_word_array: helper.lists_pad(
                [[0] + [self.word_embed.get_idx(node.value) for node in node_list if node.is_leaf]
                 for node_list in node_list_list],
                0),
            self.loss_array: root_indices if self.use_root_loss else internal_nodes_array,
            self.root_array: root_indices,
            self.label_array: helper.lists_pad([
                [[0, 0]] + [node.label for node in node_list if node.is_leaf]
                for node_list in node_list_list], [0, 0])
        }

        return feed_dict, permutation
Beispiel #4
0
    def build_feed_dict_batch(self, roots):
        """Build a padded feed dict with one row per tree in ``roots``.

        Each row starts with a filler entry (``False``, ``0`` or ``[0, 0]``)
        and rows are equalised with ``helper.lists_pad``. Child indices are
        local to their own tree and run through ``helper.add_one`` after a
        missing child has been encoded as -1.

        The batch size and the final right-child array are printed to
        stdout.

        Args:
            roots: Sequence of tree roots.

        Returns:
            dict: Maps the model's input placeholders (``self.*_array``) to
            the values returned by ``helper.lists_pad`` (and a plain list
            for ``root_array``).
        """
        print("Batch size:", len(roots))

        trees_nodes = []
        trees_index = []
        for tree in roots:
            collected = []
            tree_util.depth_first_traverse(
                tree, collected, lambda node, acc: acc.append(node))
            trees_nodes.append(collected)
            trees_index.append(helper.reverse_dict(collected))

        def padded_children(attr):
            rows = []
            for nodes, index_of in zip(trees_nodes, trees_index):
                raw = []
                for node in nodes:
                    child = getattr(node, attr)
                    raw.append(index_of[child] if child is not None else -1)
                rows.append([0] + helper.add_one(raw))
            return helper.lists_pad(rows, 0)

        root_positions = [tree_util.size_of_tree(tree) for tree in roots]

        leaf_rows = [[False] + [node.is_leaf for node in nodes]
                     for nodes in trees_nodes]
        leaf_flags = helper.lists_pad(leaf_rows, False)

        word_rows = []
        for nodes in trees_nodes:
            row = [0]
            for node in nodes:
                row.append(self.data.word_embed_util.get_idx(node.value))
            word_rows.append(row)
        word_ids = helper.lists_pad(word_rows, 0)

        left_children = padded_children("left_child")
        right_children = padded_children("right_child")

        label_rows = [[[0, 0]] + [node.label for node in nodes]
                      for nodes in trees_nodes]
        labels = helper.lists_pad(label_rows, [0, 0])

        feed_dict = {
            self.root_array: root_positions,
            self.is_leaf_array: leaf_flags,
            self.word_index_array: word_ids,
            self.left_child_array: left_children,
            self.right_child_array: right_children,
            self.label_array: labels,
        }

        print(feed_dict[self.right_child_array])

        return feed_dict
Beispiel #5
0
    def build_feed_dict(self, root):
        """Build the placeholder feed dict for a single tree.

        Nodes are collected with ``tree_util.depth_first_traverse`` and
        every per-node list follows that order. A missing child is encoded
        as -1. The root position is fed as ``size_of_tree(root) - 1``.

        Args:
            root: Root node of the tree.

        Returns:
            dict: Maps the model's input placeholders (``self.*_array``) to
            plain Python lists.
        """
        nodes = []
        tree_util.depth_first_traverse(root, nodes, lambda node, acc: acc.append(node))
        index_of = helper.reverse_dict(nodes)
        # TODO: there is probably something wrong here

        def child_indices(attr):
            found = []
            for node in nodes:
                child = getattr(node, attr)
                found.append(index_of[child] if child is not None else -1)
            return found

        root_position = tree_util.size_of_tree(root) - 1
        leaf_flags = [node.is_leaf for node in nodes]
        word_ids = [self.data.word_embed_util.get_idx(node.value) for node in nodes]
        # TODO: maybe wrap this in a different way
        left_children = child_indices("left_child")
        right_children = child_indices("right_child")
        labels = [node.label for node in nodes]

        return {
            self.root_array: [root_position],
            self.is_leaf_array: leaf_flags,
            self.word_index_array: word_ids,
            self.left_child_array: left_children,
            self.right_child_array: right_children,
            self.label_array: labels,
        }
Beispiel #6
0
# Scratch script: parse two sample trees, collect their LSTM predecessor
# indices and leaf flags, then evaluate a few small TensorFlow constants.
test_lstm_tree = [
    "(0 (0 (0 a) (0 b)) (0 (0 c) (0 d)))",
    "(1 (1 a) (1 (1 b) (1 (1 c) (1 d))))",
]

roots = list(map(tree_util.parse_tree, test_lstm_tree))


def _append_leaf_flag(node, acc):
    # Record 1 for a leaf and 0 for any other node.
    acc.append(int(node.is_leaf))


prev_idx = []   # filled in place by get_preceding_lstm_index
node_list = []  # leaf flags of all trees, in traversal order
start = 0
for root in roots:
    # The second return value is carried over as the start for the next tree.
    _, start = tree_util.get_preceding_lstm_index(root, start, start, prev_idx)
    tree_util.depth_first_traverse(root, node_list, _append_leaf_flag)

import tensorflow as tf

sess = tf.Session()
embed = tf.constant([[1, 2], [10, 20], [100, 200]])

# Drop the outer dimension of the nested [[1, 0]] literal.
is_leaf = tf.squeeze(tf.constant([[1, 0]]))

# Diagonal matrix with is_leaf on the diagonal. Alternative tried earlier:
# tf.transpose(tf.reshape(tf.tile(tf.squeeze(is_leaf),[2]),[2,2]))
get_node = tf.linalg.tensor_diag(is_leaf)

# Evaluate each tensor once; the results are discarded.
sess.run(is_leaf)
sess.run(embed)
sess.run(get_node)
    def build_feed_dict(self, roots, sort=True, train=False):
        """Pack trees into bins and build the padded tree feed dict.

        The trees are (optionally) ordered with ``helper.sort_by`` using
        their sizes, then grouped by ``helper.greedy_bin_packing`` with the
        largest tree size as the bin capacity. All trees of one bin share a
        single node list and a single node-to-index map, and each bin
        becomes one row of the padded arrays. Every row starts with a
        filler entry, so node positions within a bin start at 1.

        Args:
            roots: Sequence of tree roots.
            sort: When true, reorder the trees by size before packing.
            train: When true, feed ``FLAGS.dropout_prob`` as the dropout
                rate; otherwise feed 0.

        Returns:
            tuple: ``(feed_dict, permutation)`` where ``permutation`` is the
            second value returned by ``helper.greedy_bin_packing``.
        """
        if sort:
            unsorted_sizes = [tree_util.size_of_tree(tree) for tree in roots]
            roots = helper.sort_by(roots, unsorted_sizes)
        sizes = [tree_util.size_of_tree(tree) for tree in roots]
        bins, permutation = helper.greedy_bin_packing(
            roots, sizes, np.max(sizes))

        bins_nodes = []
        bins_index = []
        root_indices = []
        bins_lstm_idx = []
        internal_nodes = []
        for bin_no, bin_trees in enumerate(bins):
            nodes = []
            lstm_idx = [0]  # filler entry for position 0
            root_pos = 0
            start = 0
            for tree in bin_trees:
                tree_util.depth_first_traverse(
                    tree, nodes, lambda node, acc: acc.append(node))

                # Appends to lstm_idx in place; the second return value is
                # the start for the next tree in this bin.
                _, start = tree_util.get_preceding_lstm_index(
                    tree, start, start, lstm_idx)

                # Running node count of the bin after adding this tree.
                root_pos += tree_util.size_of_tree(tree)
                root_indices.append([bin_no, root_pos])

            bins_nodes.append(nodes)
            index_of = helper.reverse_dict(nodes)
            bins_index.append(index_of)
            bins_lstm_idx.append(lstm_idx)
            # +1 because of the filler entry at the front of every row.
            internal_nodes.extend(
                [bin_no, index_of[node] + 1]
                for node in nodes if not node.is_leaf)

        # Fall back to a single [0, 0] entry when there are no internal nodes.
        if not internal_nodes:
            internal_nodes = [[0, 0]]

        def padded_children(attr):
            rows = []
            for nodes, index_of in zip(bins_nodes, bins_index):
                raw = []
                for node in nodes:
                    child = getattr(node, attr)
                    raw.append(index_of[child] if child is not None else -1)
                rows.append([0] + helper.add_one(raw))
            return helper.lists_pad(rows, 0)

        dropout = FLAGS.dropout_prob if train else 0

        leaf_word_rows = [
            [0] + [self.word_embed.get_idx(node.value)
                   for node in nodes if node.is_leaf]
            for nodes in bins_nodes
        ]
        leaf_words = helper.lists_pad(leaf_word_rows, 0)

        lstm_index = helper.lists_pad(bins_lstm_idx, 0)

        loss_positions = root_indices if self.use_root_loss else internal_nodes

        leaf_flag_rows = [
            [0] + helper.to_int([node.is_leaf for node in nodes])
            for nodes in bins_nodes
        ]
        leaf_flags = helper.lists_pad(leaf_flag_rows, 0)

        word_rows = [
            [0] + [self.word_embed.get_idx(node.value) for node in nodes]
            for nodes in bins_nodes
        ]
        word_ids = helper.lists_pad(word_rows, self.word_embed.get_idx("ZERO"))

        left_children = padded_children("left_child")
        right_children = padded_children("right_child")

        label_rows = [[[0, 0]] + [node.label for node in nodes]
                      for nodes in bins_nodes]
        labels = helper.lists_pad(label_rows, [0, 0])

        feed_dict = {
            self.dropout_rate: dropout,
            self.leaf_word_array: leaf_words,
            self.lstm_index_array: lstm_index,
            self.loss_array: loss_positions,
            self.root_array: root_indices,
            self.is_leaf_array: leaf_flags,
            self.word_index_array: word_ids,
            self.left_child_array: left_children,
            self.right_child_array: right_children,
            self.label_array: labels,
        }

        return feed_dict, permutation