コード例 #1
0
    def get_observation(self):
        """Build a key-value observation covering every notable sprite.

        For each key in ``self.notable_sprites`` and each sprite the game
        returns for that key, four ``(name, values)`` pairs are emitted,
        named ``<sprite id>.position``, ``<sprite id>.orientation``,
        ``<sprite id>.class`` and ``<sprite id>.resources``.

        Returns:
            A ``KeyValueObservation`` constructed from the list of pairs.
        """
        notable_keys = list(self.notable_sprites)
        class_count = len(notable_keys)
        resource_names = self.game.domain.notable_resources

        entries = []
        for class_index, sprite_key in enumerate(notable_keys):
            # One-hot encoding of this key's index among the notable keys;
            # the same list is reused for every sprite of the key.
            one_hot = [
                1.0 if slot == class_index else 0.0
                for slot in range(class_count)
            ]

            # NOTE: the original author flagged that the sprite lookup does
            # not guarantee the same order on every call (Python < 3.6), so
            # the ordering of emitted values may vary between observations.
            for sprite in self.game.get_sprites(sprite_key):
                sprite_pos = self._rect_to_pos(sprite.rect)

                # Sprites lacking an orientation attribute get a zero vector.
                heading = (
                    [float(component) for component in sprite.orientation]
                    if hasattr(sprite, 'orientation')
                    else [0.0, 0.0]
                )

                carried = [
                    float(sprite.resources[name]) for name in resource_names
                ]

                prefix = sprite.id
                entries.extend([
                    (prefix + '.position', sprite_pos),
                    (prefix + '.orientation', heading),
                    (prefix + '.class', one_hot),
                    (prefix + '.resources', carried),
                ])

        return KeyValueObservation(entries)
コード例 #2
0
ファイル: state.py プロジェクト: lburger98/py-vgdl
    def get_observation(self):
        """Describe the first avatar and its closest sprite of each key.

        Returns:
            A ``KeyValueObservation`` with ``position`` (the ``topleft`` of
            the avatar's rect), ``speed``, ``resources`` (one value per
            notable resource of the game domain) and ``distances`` (one
            value per key in the game's sprite registry).

        Raises:
            AssertionError: if the game reports no avatars (only while
                assertions are enabled).
        """
        all_avatars = self.game.get_avatars()
        assert all_avatars
        player = all_avatars[0]

        top_left = player.rect.topleft
        carried = [
            player.resources[name]
            for name in self.game.domain.notable_resources
        ]

        nearest_per_key = []
        for sprite_key in self.game.sprite_registry.sprite_keys:
            # 100 is reported when the key has no sprites or none is closer;
            # each measured distance is divided by the game's block size.
            nearest = 100
            for sprite in self.game.get_sprites(sprite_key):
                in_blocks = (
                    self._get_distance(player, sprite) / self.game.block_size
                )
                nearest = min(in_blocks, nearest)
            nearest_per_key.append(nearest)

        return KeyValueObservation(
            position=top_left,
            speed=player.speed,
            resources=carried,
            distances=nearest_per_key,
        )
コード例 #3
0
ファイル: state.py プロジェクト: lburger98/py-vgdl
    def get_observation(self):
        """Report the positions of the 'angry' and 'avatar' sprites.

        Each item of ``self.notable_sprites`` is indexed as a pair: element 0
        is the sprite type name and element 1 is an iterable of sprites of
        that type. Types named 'floor', 'wall' and 'A' are skipped.

        Returns:
            A ``KeyValueObservation`` built from the 'angry' entry followed
            by the 'avatar' entry. A type with no sprites keeps its
            placeholder position of ``(-1, -1)``.
        """
        # Placeholder entries, used when a type currently has no sprites.
        state = {
            'avatar': [('avatar.1.position', (-1, -1))],
            'angry': [('angry.1.position', (-1, -1))],
        }

        for entry in self.notable_sprites:
            name = entry[0]
            if name in ('floor', 'wall', 'A'):
                continue

            # Each sprite replaces the entry stored for its type, so only
            # the last sprite iterated for a given type is kept.
            for sprite in entry[1]:
                position = self._rect_to_pos(sprite.rect)
                state[name] = [
                    (sprite.id + '.position', position),
                ]

        # Only these two types are reported; entries stored for any other
        # type are not part of the observation.
        return KeyValueObservation(state['angry'] + state['avatar'])
コード例 #4
0
ファイル: gapworld.py プロジェクト: sahil02235/symbolic-rl
 def get_observation(self):
     """Return an observation holding only the first avatar's x value.

     The avatar's rect is converted with ``self._rect_to_pos`` and element 0
     of the result is passed to ``KeyValueObservation`` as ``x``.
     """
     first_avatar = self._game.getAvatars()[0]
     x_value = self._rect_to_pos(first_avatar.rect)[0]
     return KeyValueObservation(x=x_value)