コード例 #1
0
def test_is_leaf():
    """ Check MPTWord.is_leaf against tokens with a known expected verdict """
    word = MPTWord("a b c 1 2 a 4 e 4 5 6")

    # (token, expected verdict of is_leaf)
    expectations = (
        ("3", True),
        ("13", True),
        ("p3", False),
        ("pq", False),
    )

    for token, expected in expectations:
        assert_equals(expected, word.is_leaf(token))
コード例 #2
0
def test_abstract():
    """ Check the abstract form that MPTWord.abstract() produces """
    # the first two words differ only in their parameter names and are
    # expected to share one abstract form
    shared_form = "p0 p1 p2 1 2 p0 4 p3 4 5 p4 6 7"
    cases = (
        ("a b c 1 2 a 4 e 4 5 d 6 7", shared_form),
        ("pq b c 1 2 pq 4 e 4 5 z 6 7", shared_form),
        ("a b c 1 2 a 4 e 4 5 6", "p0 p1 p2 1 2 p0 4 p3 4 5 6"),
    )

    # build every word first, then compare
    words = [MPTWord(raw) for raw, _ in cases]

    for word, (_, expected) in zip(words, cases):
        assert_equals(word.abstract(), expected)
コード例 #3
0
def test_split():
    """ Check that split_pos_neg() returns the positive and negative subtree """

    def leaf(token):
        """ Custom leaf test: every character is an ASCII uppercase letter """
        return all(ch in string.ascii_uppercase for ch in token)

    # (word, expected (positive, negative) pair)
    cases = (
        ("p A B", ('A', 'B')),
        ('a N b N O', ('N', 'b N O')),
        ('a b N O N', ('b N O', 'N')),
    )
    for raw, expected in cases:
        word = MPTWord(raw, leaf_test=leaf)
        assert_equals(word.split_pos_neg(), expected)

    # word taken from a parsed model file
    parsed = context.PARSER.parse(MODEL_DIR + "/test1.model")
    assert_equals(parsed.word.split_pos_neg(),
                  ('bc c 0 1 a 2 e 2 3', 'd 4 5'))
コード例 #4
0
def test_word_to_tree():
    """ Check that word_to_nodes() rebuilds the expected tree from a word """
    # hand-built reference tree, assembled bottom-up
    inner = Node("a", Node("6"), Node("8"))
    left = Node("bef", inner, Node("2"))
    expected_root = Node("a", left, Node("13"))

    built_root = transformations.word_to_nodes(MPTWord("a bef a 6 8 2 13"))

    assert_equals(MPT(expected_root), MPT(built_root))

    # same check against a tree provided by the test context
    reference = context.MPTS["2htms"]

    long_word = MPTWord(
        "y0 y5 y8 Do 0 G1 0 1 Dn 3 G1 2 3 y6 Do 4 G2 4 5 y7 Dn 7 G2 6 7 Do 8 G3 8 9 y1 y4 Dn 11 G3 10 11 Do 12 G4 12 13 y2 Dn 15 G4 14 15 y3 Do 16 G5 16 17 Dn 19 G5 18 19"
    )
    rebuilt = MPT(transformations.word_to_nodes(long_word))
    print(type(reference))
    print(type(rebuilt))
    assert_equals(reference, rebuilt)
コード例 #5
0
    def random_deletion_model(self):
        """ Pick one deletion tree at random and wrap it as an MPTWord.

        The index is drawn with np.random.randint(0, self.no_del_trees) and
        the tree is read via self.deletion.read_number(index).

        Returns
        -------
        MPTWord
            the drawn tree, using the separator and leaf test of
            self.mpt.word

        """
        index = np.random.randint(0, self.no_del_trees)
        raw_tree = self.deletion.read_number(index)

        # reuse the conventions of the model's own word
        own_word = self.mpt.word
        return MPTWord(raw_tree, sep=own_word.sep, leaf_test=own_word.is_leaf)
コード例 #6
0
    def __init__(self, mpt, sep=" ", leaf_test=None):
        """ Constructs the MPT object.

        Parameters
        ----------
        mpt : [str, Node]
            either tree in bmpt or as root object.

        sep : str, optional
            separator handed to MPTWord; only used when ``mpt`` is a string.

        leaf_test : callable, optional
            leaf test handed to MPTWord (None is passed through unchanged).

        """

        self.subtrees = []
        self.word = None
        self.root = None

        # mpt given as word
        if isinstance(mpt, str):
            self.word = MPTWord(mpt, sep=sep, leaf_test=leaf_test)
            self.root = trans.word_to_nodes(self.word)

        # mpt given as root node
        else:
            self.root = mpt
            # Forward leaf_test so that a custom leaf test is honoured for
            # trees given as nodes as well (it used to be dropped here).
            # NOTE(review): sep is deliberately not forwarded - the word is
            # built from str(self), whose separator is not controlled by
            # ``sep``; confirm against __str__.
            self.word = MPTWord(str(self), leaf_test=leaf_test)
コード例 #7
0
def test_list():
    """ Check that iterating an MPTWord yields its tokens in order """
    tokens = list(MPTWord("a 1 b 1 2"))
    assert_equals(tokens, ['a', '1', 'b', '1', '2'])
コード例 #8
0
def test_get_parameters():
    """ Check the ``parameters`` attribute of an MPTWord """
    # repeated parameters are expected to be listed once per occurrence
    expected = ["pq", "b", "c", "pq", "e", "z"]
    word = MPTWord("pq b c 1 2 pq 4 e 4 5 z 6 7")
    assert_equals(word.parameters, expected)
コード例 #9
0
def test_get_answers():
    """ Check the ``answers`` attribute (leaf contents) of an MPTWord """
    # repeated leaves are expected to be listed once per occurrence
    expected = ["1", "2", "4", "4", "5", "6", "7"]
    word = MPTWord("pq b c 1 2 pq 4 e 4 5 z 6 7")
    assert_equals(word.answers, expected)
コード例 #10
0
class MPT(object):
    """ Multinomial Processing Tree (MPT) data structure.

    Holds the tree both as a word (``self.word``) and as a node structure
    (``self.root``).

    """
    def __init__(self, mpt, sep=" ", leaf_test=None):
        """ Constructs the MPT object.

        Parameters
        ----------
        mpt : [str, Node]
            either tree in bmpt or as root object.

        sep : str, optional
            separator handed to MPTWord; only used when ``mpt`` is a string.

        leaf_test : callable, optional
            leaf test handed to MPTWord (None is passed through unchanged).

        """

        self.subtrees = []
        self.word = None
        self.root = None

        # mpt given as word
        if isinstance(mpt, str):
            self.word = MPTWord(mpt, sep=sep, leaf_test=leaf_test)
            self.root = trans.word_to_nodes(self.word)

        # mpt given as root node
        else:
            self.root = mpt
            # self.word is still None here, so str(self) renders the tree
            # from self.root with a single space as separator. That is why
            # sep is not forwarded, while leaf_test is (it used to be
            # dropped, so custom leaf tests were ignored for node input).
            self.word = MPTWord(str(self), leaf_test=leaf_test)

    @property
    def params(self):
        """ The parameters of the underlying word """
        return self.word.parameters

    @property
    def categories(self):
        """ The answers (leaf contents) of the underlying word """
        return self.word.answers

    def formulae(self):
        """ Calculate the branch formulae for the categories in the tree

        Returns
        -------
        dict
            {category : formulae}

        """

        return trans.get_formulae(self)

    def max_parameters(self):
        """ The maximal number of free parameters in the model

        Returns
        -------
        int
            max number of free parameters, i.e. the sum of
            ``len(subtree) - 1`` over ``self.subtrees``

        """

        return sum(len(subtree) - 1 for subtree in self.subtrees)

    def get_levels(self, node, level=0):
        """ Generate a dict with all nodes and their respective level
        0 is the root

        Parameters
        ----------
        node : Node
            starting node

        level : int, optional
            level from which to start counting

        Returns
        -------
        dict
            {level : [nodes]}

        """

        levels = {level: [node]}

        if not node.leaf:
            left_dict = self.get_levels(node.pos, level=level + 1)
            right_dict = self.get_levels(node.neg, level=level + 1)

            # presumably merge_dicts combines the node lists per level;
            # verify against misc.merge_dicts
            temp = misc.merge_dicts(left_dict, right_dict)
            levels.update(temp)

        return levels

    def save(self, path, form="easy"):
        """ Saves the tree to a file

        Parameters
        ----------
        path : str
            where to save the tree

        form : str, optional
            "easy" writes the result of trans.to_easy, any other value
            writes the raw word string

        """

        to_print = trans.to_easy(self) if form == "easy" else self.word.str_
        misc.write_iterable_to_file(path, to_print, newline=False)

    def draw(self):
        """ Draw MPT to the command line

        """

        cmd_draw(self)

    def __eq__(self, other):
        """ Two trees are equal if the abstract forms of their words match.

        Objects without a ``word.abstract`` attribute are not comparable;
        NotImplemented is returned for them so that ``==`` evaluates to
        False instead of raising an AttributeError.

        """
        # keep the try body minimal: only the attribute lookup on ``other``
        try:
            other_abstract = other.word.abstract
        except AttributeError:
            return NotImplemented

        return self.word.abstract() == other_abstract()

    def __ne__(self, other):
        """ Negation of __eq__, passing NotImplemented through """
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __str__(self):
        """ The tree as a word: the stored word string if there is one,
        otherwise a rendering of the nodes starting at ``self.root``.

        """
        if self.word:
            return self.word.str_

        sep = " "

        def dfs(node):
            """ depth first search: content, positive child, negative child

            """

            if node.leaf:
                return str(node.content)

            pos = dfs(node.pos)
            neg = dfs(node.neg)
            # str() so that non-string inner contents do not break the join
            return sep.join((str(node.content), pos, neg))

        return dfs(self.root)