Exemple #1
0
    def build_feed_dict(self, roots, sort=True, train=False):
        """Pack trees into bins and build the feed dict for the leaf-sequence model.

        Every bin produced by ``helper.greedy_bin_packing`` becomes one row
        of each per-position array. Slot 0 of a row is a padding entry; the
        leaves of the bin's trees follow, tree after tree, in the order
        ``tree_util.depth_first_traverse`` visits them.

        Args:
            roots: Root nodes of the trees to feed. Must be non-empty,
                because ``np.max`` is applied to the list of tree sizes.
            sort: When true, reorder ``roots`` with ``helper.sort_by`` using
                the tree sizes as keys before bin packing.
            train: When true, feed ``FLAGS.dropout_prob`` as the dropout
                rate; otherwise feed 0.

        Returns:
            A ``(feed_dict, permutation)`` tuple, where ``permutation`` is
            the second value returned by ``helper.greedy_bin_packing``
            (presumably the reordering applied to ``roots`` -- confirm
            against the helper).
        """
        def _collect(node, node_list):
            node_list.append(node)

        if sort:
            roots = helper.sort_by(
                roots, [tree_util.size_of_tree(root) for root in roots])
        # Sizes are (re)computed here so they line up with the final order.
        roots_size = [tree_util.size_of_tree(root) for root in roots]
        # The largest tree size is passed as the third argument
        # (presumably the bin capacity -- confirm against the helper).
        roots_list, permutation = helper.greedy_bin_packing(
            roots, roots_size, np.max(roots_size))

        node_list_list = []
        root_indices = []
        internal_nodes_array = []
        lstm_prev_list = []
        for bin_index, bin_roots in enumerate(roots_list):
            node_list = []
            lstm_prev = [0]  # slot 0 is the padding position
            leaves_placed = 0  # leaves already laid out in this bin
            for root in bin_roots:
                tree_util.depth_first_traverse(root, node_list, _collect)
                leaf_count = tree_util.leafs_in_tree(root)

                first_slot = leaves_placed + 1
                leaves_placed += leaf_count
                # The "root" entry is the slot of the tree's last leaf.
                root_indices.append([bin_index, leaves_placed])
                for slot in range(first_slot, leaves_placed + 1):
                    # Despite its name, this list holds one entry per leaf.
                    internal_nodes_array.append([bin_index, slot])
                    # The first leaf of a tree points back at padding slot 0,
                    # every other leaf at the slot just before it.
                    lstm_prev.append(0 if slot == first_slot else slot - 1)

            node_list_list.append(node_list)
            lstm_prev_list.append(lstm_prev)

        feed_dict = {
            self.dropout_rate: FLAGS.dropout_prob if train else 0,
            self.lstm_prev_array: helper.lists_pad(lstm_prev_list, 0),
            self.leaf_word_array: helper.lists_pad(
                [[0] + [self.word_embed.get_idx(node.value)
                        for node in node_list if node.is_leaf]
                 for node_list in node_list_list],
                0),
            self.loss_array:
                root_indices if self.use_root_loss else internal_nodes_array,
            self.root_array: root_indices,
            # Only leaf labels are fed; each row starts with a [0, 0] pad.
            self.label_array: helper.lists_pad(
                [[[0, 0]] + [node.label for node in node_list if node.is_leaf]
                 for node_list in node_list_list],
                [0, 0]),
        }

        return feed_dict, permutation
Exemple #2
0
    def build_feed_dict_batch(self, roots):
        """Build a feed dict with one padded row per tree in ``roots``.

        Each tree is flattened with ``tree_util.depth_first_traverse``. Every
        per-node row is prefixed with a filler entry (``False``, ``0`` or
        ``[0, 0]``), so real nodes start at position 1. Child references are
        looked up in the tree's node-to-index map, with ``-1`` for a missing
        child, and then passed through ``helper.add_one`` (presumably so that
        "no child" lands on filler position 0 -- confirm against the helper).

        Prints the batch size and the right-child array to stdout.

        Args:
            roots: Sequence of tree root nodes; one output row per root.

        Returns:
            The feed dict keyed by this object's input arrays.
        """
        print("Batch size:", len(roots))

        def _collect(node, node_list):
            node_list.append(node)

        flat_trees = []
        index_maps = []
        for tree in roots:
            nodes = []
            tree_util.depth_first_traverse(tree, nodes, _collect)
            flat_trees.append(nodes)
            index_maps.append(helper.reverse_dict(nodes))

        def _child_rows(pick_child):
            # One row per tree: filler 0, then the shifted child indices.
            rows = []
            for nodes, index_of in zip(flat_trees, index_maps):
                raw = []
                for node in nodes:
                    child = pick_child(node)
                    raw.append(-1 if child is None else index_of[child])
                rows.append([0] + helper.add_one(raw))
            return rows

        # Values are computed in the same order the original dict literal
        # evaluated them.
        tree_sizes = [tree_util.size_of_tree(tree) for tree in roots]

        leaf_flags = helper.lists_pad(
            [[False] + [node.is_leaf for node in nodes]
             for nodes in flat_trees],
            False)

        embed = self.data.word_embed_util
        word_ids = helper.lists_pad(
            [[0] + [embed.get_idx(node.value) for node in nodes]
             for nodes in flat_trees],
            0)

        left_children = helper.lists_pad(
            _child_rows(lambda node: node.left_child), 0)
        right_children = helper.lists_pad(
            _child_rows(lambda node: node.right_child), 0)

        labels = helper.lists_pad(
            [[[0, 0]] + [node.label for node in nodes]
             for nodes in flat_trees],
            [0, 0])

        feed_dict = {
            self.root_array: tree_sizes,
            self.is_leaf_array: leaf_flags,
            self.word_index_array: word_ids,
            self.left_child_array: left_children,
            self.right_child_array: right_children,
            self.label_array: labels,
        }

        print(feed_dict[self.right_child_array])

        return feed_dict
    def build_feed_dict(self, roots, sort=True, train=False):
        """Group trees into bins and assemble the feed dict for the model.

        Trees are optionally sorted by size, then grouped by
        ``helper.greedy_bin_packing``. Each bin becomes one row of every
        per-position array. Position 0 of a row is a filler entry and the
        nodes of the bin's trees follow in traversal order.

        Args:
            roots: Root nodes of the trees to feed. Must be non-empty,
                because ``np.max`` is applied to the list of tree sizes.
            sort: When true, reorder ``roots`` with ``helper.sort_by`` using
                the tree sizes as keys before bin packing.
            train: When true, feed ``FLAGS.dropout_prob`` as the dropout
                rate; otherwise feed 0.

        Returns:
            A ``(feed_dict, permutation)`` tuple, where ``permutation`` is
            the second value returned by ``helper.greedy_bin_packing``.
        """
        def _collect(node, node_list):
            node_list.append(node)

        if sort:
            roots = helper.sort_by(
                roots, [tree_util.size_of_tree(tree) for tree in roots])
        sizes = [tree_util.size_of_tree(tree) for tree in roots]
        bins, permutation = helper.greedy_bin_packing(
            roots, sizes, np.max(sizes))

        bin_nodes = []
        bin_index_maps = []
        root_indices = []
        lstm_rows = []
        internal_nodes_array = []
        for bin_no, trees in enumerate(bins):
            nodes = []
            lstm_row = [0]  # filler entry for position 0
            size_so_far = 0
            cursor = 0
            for tree in trees:
                tree_util.depth_first_traverse(tree, nodes, _collect)

                # lstm_row is handed to the helper, which presumably appends
                # to it in place -- confirm against tree_util.
                _unused, cursor = tree_util.get_preceding_lstm_index(
                    tree, cursor, cursor, lstm_row)

                # A tree's root entry is the running node total after it.
                size_so_far += tree_util.size_of_tree(tree)
                root_indices.append([bin_no, size_so_far])

            index_of = helper.reverse_dict(nodes)
            bin_nodes.append(nodes)
            bin_index_maps.append(index_of)
            lstm_rows.append(lstm_row)
            # +1 skips the filler entry at position 0 of each row.
            internal_nodes_array.extend(
                [bin_no, index_of[node] + 1]
                for node in nodes if not node.is_leaf)

        # Fall back to a single [0, 0] entry when no internal node exists.
        if not internal_nodes_array:
            internal_nodes_array = [[0, 0]]

        def _child_rows(pick_child):
            # One row per bin: filler 0, then child indices after add_one,
            # with -1 standing in for a missing child before the shift.
            rows = []
            for nodes, index_of in zip(bin_nodes, bin_index_maps):
                raw = []
                for node in nodes:
                    child = pick_child(node)
                    raw.append(-1 if child is None else index_of[child])
                rows.append([0] + helper.add_one(raw))
            return rows

        # Values are computed in the same order the original dict literal
        # evaluated them.
        dropout = FLAGS.dropout_prob if train else 0

        leaf_words = helper.lists_pad(
            [[0] + [self.word_embed.get_idx(node.value)
                    for node in nodes if node.is_leaf]
             for nodes in bin_nodes],
            0)

        lstm_index = helper.lists_pad(lstm_rows, 0)

        loss_indices = (
            root_indices if self.use_root_loss else internal_nodes_array)

        leaf_flags = helper.lists_pad(
            [[0] + helper.to_int([node.is_leaf for node in nodes])
             for nodes in bin_nodes],
            0)

        word_ids = helper.lists_pad(
            [[0] + [self.word_embed.get_idx(node.value) for node in nodes]
             for nodes in bin_nodes],
            self.word_embed.get_idx("ZERO"))

        left_children = helper.lists_pad(
            _child_rows(lambda node: node.left_child), 0)
        right_children = helper.lists_pad(
            _child_rows(lambda node: node.right_child), 0)

        labels = helper.lists_pad(
            [[[0, 0]] + [node.label for node in nodes]
             for nodes in bin_nodes],
            [0, 0])

        feed_dict = {
            self.dropout_rate: dropout,
            self.leaf_word_array: leaf_words,
            self.lstm_index_array: lstm_index,
            self.loss_array: loss_indices,
            self.root_array: root_indices,
            self.is_leaf_array: leaf_flags,
            self.word_index_array: word_ids,
            self.left_child_array: left_children,
            self.right_child_array: right_children,
            self.label_array: labels,
        }

        return feed_dict, permutation