예제 #1
0
class TestIntegration(unittest.TestCase):
    """Integration tests for a ``VariableTypeTree`` built from the testing fixture."""

    def setUp(self):
        # unittest calls this before every test, so each test gets a fresh tree
        # parsed from the shared fixture text.
        vtypes = parse_variable_types(testing.VARIABLES_TXT_CONTENT)
        self.types = VariableTypeTree(vtypes)

    def test_parse(self):
        """The fixture yields 12 types: 2 constants and 10 variables."""
        assert len(self.types) == 12
        assert len(self.types.constants) == 2
        assert len(self.types.variables) == 10

    def test_variable_types(self):
        """Spot-check constant detection and descendant relations of the fixture."""
        assert self.types.is_constant("P")
        assert self.types.is_constant("I")
        assert set(self.types.descendants('o')) == set(['f', 'k'])
        assert set(self.types.descendants('c')) == set(['oven'])
        assert self.types.is_descendant_of('c', 't')
        assert self.types.is_descendant_of('s', 't')

    def test_parents(self):
        """Every ancestor reported by ``get_ancestors`` agrees with ``is_descendant_of``."""
        for vtype in self.types:
            for parent in self.types.get_ancestors(vtype):
                assert self.types.is_descendant_of(vtype, parent)

    def test_sample(self):
        """``sample`` returns the type itself or one of its descendants."""
        rng = np.random.RandomState(1234)
        vtype = self.types.sample("f", rng, include_parent=True)
        assert vtype == "f"

        # Excluding the parent must raise for "f"; presumably because "f" has no
        # descendants to sample from (not asserted here).
        assert_raises(ValueError,
                      self.types.sample,
                      "f",
                      rng,
                      include_parent=False)

        for t in self.types:
            for _ in range(30):
                vtype = self.types.sample(t, rng, include_parent=True)
                assert self.types.is_descendant_of(vtype, t)

    def test_count(self):
        """``count`` reports, per variable type, how many variables a state holds."""
        rng = np.random.RandomState(1234)
        types_counts = {t: rng.randint(2, 10) for t in self.types.variables}

        state = State()
        for t in self.types.variables:
            # NOTE(review): assumes get_new creates a fresh name for type t and
            # updates types_counts accordingly -- confirm against its definition.
            v = Variable(get_new(t, types_counts), t)
            state.add_fact(Proposition("dummy", [v]))

        counts = self.types.count(state)
        for t in self.types.variables:
            assert counts[t] == types_counts[t], (counts[t], types_counts[t])

    def test_serialization_deserialization(self):
        """A serialize/deserialize round trip must preserve the variable types."""
        data = self.types.serialize()
        types2 = VariableTypeTree.deserialize(data)
        # Compare the restored tree with the original one; comparing ``types2``
        # with itself is a tautology that can never fail.
        assert types2.variables_types == self.types.variables_types
예제 #2
0
def _to_type_tree(types):
    """Build a ``VariableTypeTree`` from an iterable of type objects.

    Each input object must expose ``name`` and ``parents``. The objects are
    processed in sorted order. Each becomes a ``VariableType`` whose first two
    constructor arguments are both the object's ``name``.

    Args:
        types: Iterable of orderable type objects with ``name`` and ``parents``.

    Returns:
        A ``VariableTypeTree`` wrapping the converted types.
    """
    def _first_parent(type_obj):
        # Only the first parent is kept; any additional parents are ignored.
        return type_obj.parents[0] if type_obj.parents else None

    converted = [
        VariableType(type_obj.name, type_obj.name, _first_parent(type_obj))
        for type_obj in sorted(types)
    ]
    return VariableTypeTree(converted)
예제 #3
0
파일: game.py 프로젝트: zp312/TextWorld
    def deserialize(cls, data: Mapping) -> "Game":
        """ Rebuild a `Game` from its serialized representation.

        Args:
            data: Serialized game. The keys "world", "quests", "infos",
                  "state", "rules" and "types" are required. "grammar" and
                  "metadata" are optional.

        Returns:
            The reconstructed `Game` object.
        """
        world = World.deserialize(data["world"])
        grammar = Grammar(data["grammar"]) if "grammar" in data else None
        quests = list(map(Quest.deserialize, data["quests"]))
        game = cls(world, grammar, quests)

        # Each entry of "infos" / "rules" is unpacked as a (key, serialized value) pair.
        game._infos = dict((key, EntityInfo.deserialize(value))
                           for key, value in data["infos"])
        game.state = State.deserialize(data["state"])
        game._rules = dict((key, Rule.deserialize(value))
                           for key, value in data["rules"])
        game._types = VariableTypeTree.deserialize(data["types"])
        # Missing metadata falls back to an empty dict.
        game.metadata = data.get("metadata", {})

        return game
예제 #4
0
 def test_serialization_deserialization(self):
     """A serialize/deserialize round trip must preserve the variable types."""
     data = self.types.serialize()
     types2 = VariableTypeTree.deserialize(data)
     # Compare the restored tree with the original one; comparing ``types2``
     # with itself is a tautology that can never fail.
     assert types2.variables_types == self.types.variables_types
예제 #5
0
 def setUp(self):
     """Build ``self.types`` from the shared ``testing.VARIABLES_TXT_CONTENT`` fixture."""
     self.types = VariableTypeTree(
         parse_variable_types(testing.VARIABLES_TXT_CONTENT))