コード例 #1
0
    def build_feed_dict_batch_test(self, root):
        """Build an unbatched feed dict for a single tree.

        Only ``root[0]`` is read; any further elements of ``root`` are ignored.
        Every per-node list begins with a dummy entry at position 0, so the
        node at index ``k`` of the traversal order sits at position ``k + 1``.

        Args:
            root: An indexable whose first element is the tree root.

        Returns:
            A dict mapping ``self.root_array``, ``self.is_leaf_array``,
            ``self.word_index_array``, ``self.left_child_array``,
            ``self.right_child_array`` and ``self.label_array`` to plain
            Python lists (no batch dimension).
        """
        tree = root[0]

        # Nodes in the order tree_util.depth_first_traverse visits them.
        nodes = []
        tree_util.depth_first_traverse(
            tree, nodes, lambda node, node_list: node_list.append(node))
        position = helper.reverse_dict(nodes)

        def shifted_children(attr):
            # Child index per node (-1 when the child is absent), passed
            # through helper.add_one, then prefixed with the dummy slot's 0.
            raw = []
            for n in nodes:
                child = getattr(n, attr)
                raw.append(-1 if child is None else position[child])
            return [0] + helper.add_one(raw)

        # Same value as the historical ``size - 1 + 1``: presumably the root's
        # list index (assuming the traversal visits the root last - TODO
        # confirm) shifted by one for the dummy slot.
        root_slot = tree_util.size_of_tree(tree)
        leaf_flags = [False] + [n.is_leaf for n in nodes]
        word_ids = [0] + [
            self.data.word_embed_util.get_idx(n.value) for n in nodes
        ]
        left_ids = shifted_children("left_child")
        right_ids = shifted_children("right_child")
        labels = [[0, 0]] + [n.label for n in nodes]

        return {
            self.root_array: [root_slot],
            self.is_leaf_array: leaf_flags,
            self.word_index_array: word_ids,
            self.left_child_array: left_ids,
            self.right_child_array: right_ids,
            self.label_array: labels,
        }
コード例 #2
0
    def build_feed_dict_batch(self, roots):
        """Build a padded feed dict for a batch of trees (one row per tree).

        Each row begins with a dummy entry at position 0, so the node at index
        ``k`` of a tree's traversal order sits at position ``k + 1``. Rows are
        padded to a common length with ``helper.lists_pad``.

        Args:
            roots: Sequence of tree roots, one per batch row.

        Returns:
            A dict mapping ``self.root_array``, ``self.is_leaf_array``,
            ``self.word_index_array``, ``self.left_child_array``,
            ``self.right_child_array`` and ``self.label_array`` to padded
            Python lists.
        """
        import logging

        # Debug output goes through logging instead of unconditional print().
        logger = logging.getLogger(__name__)
        logger.debug("Batch size: %d", len(roots))

        node_list_list = []
        node_to_index_list = []
        for root in roots:
            node_list = []
            tree_util.depth_first_traverse(
                root, node_list,
                lambda node, node_list: node_list.append(node))
            node_list_list.append(node_list)
            node_to_index_list.append(helper.reverse_dict(node_list))

        def child_array(attr):
            # One row per tree: dummy 0, then each node's child index (-1 when
            # the child is absent) passed through helper.add_one; rows are
            # padded with 0.
            rows = []
            for node_list, node_to_index in zip(node_list_list,
                                                node_to_index_list):
                raw = []
                for node in node_list:
                    child = getattr(node, attr)
                    raw.append(-1 if child is None else node_to_index[child])
                rows.append([0] + helper.add_one(raw))
            return helper.lists_pad(rows, 0)

        feed_dict = {
            self.root_array: [tree_util.size_of_tree(root) for root in roots],
            self.is_leaf_array:
            helper.lists_pad([[False] + [node.is_leaf for node in node_list]
                              for node_list in node_list_list], False),
            self.word_index_array:
            helper.lists_pad([[0] + [
                self.data.word_embed_util.get_idx(node.value)
                for node in node_list
            ] for node_list in node_list_list], 0),
            self.left_child_array: child_array("left_child"),
            self.right_child_array: child_array("right_child"),
            self.label_array:
            helper.lists_pad([[[0, 0]] + [node.label for node in node_list]
                              for node_list in node_list_list], [0, 0])
        }

        logger.debug("right_child_array: %s",
                     feed_dict[self.right_child_array])

        return feed_dict
コード例 #3
0
    def build_feed_dict(self, roots, sort=True, train=False):
        """Pack trees into bins and build the padded feed dict for them.

        Trees are optionally sorted by size, then grouped by
        ``helper.greedy_bin_packing`` with the largest tree size as the
        capacity argument. Each bin becomes one row of every padded array;
        within a row, position 0 is a dummy slot and the bin's nodes follow in
        traversal order (so list index ``k`` sits at position ``k + 1``).

        Args:
            roots: Tree roots to encode.
            sort: If true, reorder ``roots`` with ``helper.sort_by`` keyed on
                tree size before bin packing.
            train: If true, feed ``FLAGS.dropout_prob`` as the dropout rate;
                otherwise feed 0.

        Returns:
            ``(feed_dict, permutation)``, where ``permutation`` is the second
            value returned by ``helper.greedy_bin_packing``.
        """
        if sort:
            roots_size = [tree_util.size_of_tree(root) for root in roots]
            roots = helper.sort_by(roots, roots_size)
        # Recomputed because ``roots`` may have been reordered above.
        roots_size = [tree_util.size_of_tree(root) for root in roots]
        roots_list, permutation = helper.greedy_bin_packing(
            roots, roots_size, np.max(roots_size))

        node_list_list = []
        node_to_index_list = []
        root_indices = []
        lstm_idx_list = []
        internal_nodes_array = []
        # Loop variable is ``bin_roots`` so the ``roots`` parameter is not
        # shadowed/rebound.
        for i, bin_roots in enumerate(roots_list):
            node_list = []
            lstm_idx = [0]
            root_index = 0
            start = 0
            for root in bin_roots:
                tree_util.depth_first_traverse(
                    root, node_list,
                    lambda node, node_list: node_list.append(node))

                _, start = tree_util.get_preceding_lstm_index(
                    root, start, start, lstm_idx)

                # Running total of tree sizes within this bin, recorded as a
                # (bin, position) pair for this root.
                root_index += tree_util.size_of_tree(root)
                root_indices.append([i, root_index])
            node_list_list.append(node_list)
            node_to_index = helper.reverse_dict(node_list)
            node_to_index_list.append(node_to_index)
            lstm_idx_list.append(lstm_idx)
            for node in node_list:
                if not node.is_leaf:
                    # +1 skips the dummy slot at position 0.
                    internal_nodes_array.append([i, node_to_index[node] + 1])

        # NOTE(review): presumably a placeholder so the index list is never
        # empty when no bin contains an internal node - confirm.
        if not internal_nodes_array:
            internal_nodes_array = [[0, 0]]

        def child_array(attr):
            # One row per bin: dummy 0, then each node's child index (-1 when
            # the child is absent) passed through helper.add_one; rows are
            # padded with 0.
            rows = []
            for nodes, index_of in zip(node_list_list, node_to_index_list):
                raw = []
                for node in nodes:
                    child = getattr(node, attr)
                    raw.append(-1 if child is None else index_of[child])
                rows.append([0] + helper.add_one(raw))
            return helper.lists_pad(rows, 0)

        feed_dict = {
            self.dropout_rate:
            FLAGS.dropout_prob if train else 0,
            self.leaf_word_array:
            helper.lists_pad([[0] + [
                self.word_embed.get_idx(node.value)
                for node in node_list if node.is_leaf
            ] for node_list in node_list_list], 0),
            self.lstm_index_array:
            helper.lists_pad(lstm_idx_list, 0),
            self.loss_array:
            root_indices if self.use_root_loss else internal_nodes_array,
            self.root_array:
            root_indices,
            self.is_leaf_array:
            helper.lists_pad(
                [[0] + helper.to_int([node.is_leaf for node in node_list])
                 for node_list in node_list_list], 0),
            self.word_index_array:
            helper.lists_pad(
                [[0] +
                 [self.word_embed.get_idx(node.value) for node in node_list]
                 for node_list in node_list_list],
                self.word_embed.get_idx("ZERO")),
            self.left_child_array:
            child_array("left_child"),
            self.right_child_array:
            child_array("right_child"),
            self.label_array:
            helper.lists_pad([[[0, 0]] + [node.label for node in node_list]
                              for node_list in node_list_list], [0, 0])
        }

        return feed_dict, permutation