Ejemplo n.º 1
0
    def test_basic_movement(self):
        """Step a single snake through a fixed action list and track its head.

        After every step the head location (row, column) is read from the
        head channel of the first environment and compared against the
        expected position. None of the actions is expected to end the episode.
        """
        env = SingleSnake(num_envs=1, size=size, manual_setup=True)
        env.envs = get_test_env(size, 'up').to(DEFAULT_DEVICE)
        actions = torch.Tensor([0, 0, 3, 0, 0,
                                1]).unsqueeze(1).long().to(DEFAULT_DEVICE)
        expected_head_positions = torch.Tensor([[6, 4], [7, 4], [7, 5], [8, 5],
                                                [9, 5], [9, 4]])

        for i, a in enumerate(actions):
            if visualise:
                plot_envs(env.envs)
                plt.show()

            observations, reward, done, info = env.step(a)

            # Find the flat index of the head once, then split it into
            # (row, column) using the grid size.
            flat_index = head(env.envs)[0, 0].flatten().argmax()
            head_position = torch.Tensor(
                [flat_index // size, flat_index % size])

            self.assertTrue(
                torch.equal(head_position, expected_head_positions[i]))

            if torch.any(done):
                # Use self.fail rather than a bare `assert False`: plain
                # asserts are removed when Python runs with -O, which would
                # silently disable this check.
                self.fail("The given actions shouldn't cause a death")

        if visualise:
            plot_envs(env.envs)
            plt.show()
Ejemplo n.º 2
0
    def _get_rgb(self):
        """Build an image tensor (BCHW layout) for every environment.

        A canvas filled with 255 and shaped like ``self.envs`` is painted with
        the body, head and food colours (in that order, so later layers win on
        overlap) and finally a one-pixel border in the edge colour. According
        to the original comment this is the image shown by ``.render()``.
        """
        # assumes self.envs has three channels so the canvas is RGB - TODO confirm
        # Work in BHWC so a (B, H, W) mask can address whole pixels.
        canvas = (torch.ones_like(self.envs).short() * 255).permute(
            (0, 2, 3, 1))

        layers = (
            (body, self.body_colour),
            (head, self.head_colour),
            (food, self.food_colour),
        )
        for select_channel, colour in layers:
            occupied = (select_channel(self.envs) > EPS).squeeze(1)
            canvas[occupied, :] = colour

        # Border: first/last row and first/last column of every image.
        canvas[:, :1, :, :] = self.edge_colour
        canvas[:, :, :1, :] = self.edge_colour
        canvas[:, -1:, :, :] = self.edge_colour
        canvas[:, :, -1:, :] = self.edge_colour

        # Back to BCHW for the caller.
        return canvas.permute((0, 3, 1, 2))
Ejemplo n.º 3
0
    def test_basic_movement(self):
        """Move the gridworld agent through six actions and verify each position.

        The environment is set up by hand with food at (1, 1) and the head at
        (3, 3); after every step the head's (row, column) must equal the
        corresponding expected entry.
        """
        env = SimpleGridworld(num_envs=1,
                              size=size,
                              start_location=(3, 3),
                              manual_setup=True)
        env.envs[0, FOOD_CHANNEL, 1, 1] = 1
        env.envs[0, HEAD_CHANNEL, 3, 3] = 1

        moves = torch.Tensor([0, 1, 2, 3, 2, 1])
        moves = moves.unsqueeze(1).long().to(DEFAULT_DEVICE)
        targets = torch.Tensor([[4, 3], [4, 2], [3, 2], [3, 3], [2, 3],
                                [2, 2]])

        for move, target in zip(moves, targets):
            env.step(move)

            # Flat index of the head cell, split into row and column.
            flat_index = head(env.envs)[0, 0].flatten().argmax()
            located = torch.Tensor([flat_index // size, flat_index % size])

            self.assertTrue(torch.equal(located, target))
Ejemplo n.º 4
0
    def _get_rgb(self):
        """Build a ``(num_envs, 3, size, size)`` short image tensor (BCHW).

        The canvas starts as all zeros, then the head and food cells are
        painted (food last, so it wins on overlap) followed by a one-pixel
        border in the edge colour. According to the original comment this is
        the image shown by ``.render()``.
        """
        # NOTE(review): the original expression multiplied this all-zero
        # tensor by 255, which has no effect; the background is therefore 0.
        # If a 255 background was intended, confirm and change separately.
        canvas = torch.zeros((self.num_envs, 3, self.size, self.size),
                             dtype=torch.short,
                             device=self.device)

        # Work in BHWC so a (B, H, W) mask can address whole pixels.
        canvas = canvas.permute((0, 2, 3, 1))

        for select_channel, colour in ((head, self.head_colour),
                                       (food, self.food_colour)):
            occupied = (select_channel(self.envs) > EPS).squeeze(1)
            canvas[occupied, :] = colour

        # Border: first/last row and first/last column of every image.
        canvas[:, :1, :, :] = self.edge_colour
        canvas[:, :, :1, :] = self.edge_colour
        canvas[:, -1:, :, :] = self.edge_colour
        canvas[:, :, -1:, :] = self.edge_colour

        # Back to BCHW for the caller.
        return canvas.permute((0, 3, 1, 2))
Ejemplo n.º 5
0
    def step(
            self,
            actions: Dict[str, torch.Tensor]) -> Tuple[dict, dict, dict, dict]:
        """Advance all environments by one step for every snake.

        Args:
            actions: Mapping from agent name to an integer tensor holding one
                action per environment. The tensors are modified in place
                when an action equals the snake's current orientation number
                (the "moving backwards" case described below).
                NOTE(review): results are stored under ``'agent_{i}'`` keys
                but read back using the keys of ``actions``, so the keys are
                presumably expected to be ``'agent_0'``, ``'agent_1'``, ...
                in order - confirm against callers.

        Returns:
            A tuple ``(observations, rewards, dones, info)`` where
            ``observations`` is an empty dict, ``rewards`` and ``dones`` map
            agent names to per-environment tensors (``dones`` also carries an
            ``'__all__'`` entry), and ``info`` holds ``'self_collision_{i}'``
            and ``'edge_collision_{i}'`` masks.

        Raises:
            RuntimeError: If the number of action tensors differs from
                ``self.num_snakes`` or a tensor's first dimension differs
                from ``self.num_envs``.
            TypeError: If an action tensor is not short/int/long.
        """
        if len(actions) != self.num_snakes:
            raise RuntimeError('Must have a Tensor of actions for each snake')

        for agent, act in actions.items():
            if act.dtype not in (torch.short, torch.int, torch.long):
                raise TypeError(
                    'actions Tensor must be an integer type i.e. '
                    '{torch.ShortTensor, torch.IntTensor, torch.LongTensor}')

            if act.shape[0] != self.num_envs:
                raise RuntimeError(
                    'Must have the same number of actions as environments.')

        rewards = OrderedDict([(f'agent_{i}', torch.zeros(
            (self.num_envs, )).float().to(self.device).requires_grad_(False))
                               for i in range(self.num_snakes)])
        dones = OrderedDict([(f'agent_{i}', torch.zeros(
            (self.num_envs, )).float().to(
                self.device).byte().requires_grad_(False))
                             for i in range(self.num_snakes)])
        info = dict()

        # Record each snake's size (the maximum value in its body channel)
        # before anything moves; it is used later to write the new head cell.
        snake_sizes = dict()
        for i, (agent, act) in enumerate(actions.items()):
            body_channel = self.body_channels[i]
            snake_sizes[agent] = self.envs[:, body_channel:body_channel +
                                           1, :].view(self.num_envs,
                                                      -1).max(dim=1)[0]

        # Check orientations and move head positions of all snakes
        for i, (agent, act) in enumerate(actions.items()):
            # The sub-environment of just one agent
            head_channel = self.head_channels[i]
            body_channel = self.body_channels[i]
            _env = self.envs[:, [0, head_channel, body_channel], :, :]

            orientations = determine_orientations(_env)

            # Check if this snake is trying to move backwards and change
            # its direction/action to just continue forward
            # The test for this is if their orientation number {0, 1, 2, 3}
            # is the same as their action
            mask = orientations == act
            act.add_((mask * 2).long()).fmod_(4)

            # Create head position deltas
            head_deltas = F.conv2d(head(_env),
                                   ORIENTATION_FILTERS.to(self.device),
                                   padding=1)
            # Select the head position delta corresponding to the correct action
            actions_onehot = torch.Tensor(self.num_envs,
                                          4).float().to(self.device)
            actions_onehot.zero_()
            actions_onehot.scatter_(1, act.unsqueeze(-1), 1)
            head_deltas = torch.einsum(
                'bchw,bc->bhw', [head_deltas, actions_onehot]).unsqueeze(1)

            # Move head position by applying delta
            self.envs[:, head_channel:head_channel +
                      1, :, :].add_(head_deltas).round_()

        # Decay bodies of all snakes that haven't eaten food
        food_consumption = dict()
        for i, (agent, act) in enumerate(actions.items()):
            head_channel = self.head_channels[i]
            body_channel = self.body_channels[i]
            _env = self.envs[:, [0, head_channel, body_channel], :, :]

            head_food_overlap = (head(_env) * food(_env)).view(
                self.num_envs, -1).sum(dim=-1)
            food_consumption[agent] = head_food_overlap

            # Decay the body sizes by 1, hence moving the body, apply ReLu to keep above 0
            # Only do this for environments which haven't just eaten food.
            # The mask is built with a threshold comparison instead of
            # `~overlap.byte()`: since PyTorch 1.2 `~` on a uint8 tensor is a
            # bitwise NOT (0 -> 255, 1 -> 254), which would select every
            # environment, including those that just ate.
            body_decay_env_indices = head_food_overlap < EPS
            self.envs[body_decay_env_indices,
                      body_channel:body_channel + 1, :, :] -= 1
            self.envs[body_decay_env_indices, body_channel:body_channel + 1, :, :] = \
                self.envs[body_decay_env_indices, body_channel:body_channel + 1, :, :].relu()

        for i, (agent, act) in enumerate(actions.items()):
            # Check if any snakes have collided with themselves or any other snakes
            head_channel = self.head_channels[i]
            body_channel = self.body_channels[i]
            _env = self.envs[:, [0, head_channel, body_channel], :, :]

            # Collision with body of any snake
            body_collision = (head(_env) * self._bodies).view(
                self.num_envs, -1).sum(dim=-1) > EPS
            # Collision with head of other snake
            other_snakes = torch.ones(self.num_snakes).byte().to(self.device)
            other_snakes[i] = 0
            other_heads = self._heads[:, other_snakes, :, :]
            head_collision = (head(_env) * other_heads).view(
                self.num_envs, -1).sum(dim=-1) > EPS
            snake_collision = body_collision | head_collision
            info.update({f'self_collision_{i}': snake_collision})
            dones[agent] = dones[agent] | snake_collision

            # Create a new head position in the body channel
            # Make this head +1 greater if the snake has just eaten food
            self.envs[:, body_channel:body_channel + 1, :, :] += \
                head(_env) * (
                    snake_sizes[agent][:, None, None, None].expand((self.num_envs, 1, self.size, self.size)) +
                    food_consumption[agent][:, None, None, None].expand((self.num_envs, 1, self.size, self.size))
                )

        for i, (agent, act) in enumerate(actions.items()):
            # Remove food and give reward
            # `food_removal` is 0 except where a snake head is at the same location as food where it is -1
            head_channel = self.head_channels[i]
            body_channel = self.body_channels[i]
            _env = self.envs[:, [0, head_channel, body_channel], :, :]

            food_removal = head(_env) * food(_env) * -1
            rewards[agent].sub_(
                food_removal.view(self.num_envs, -1).sum(dim=-1).float())
            self.envs[:, FOOD_CHANNEL:FOOD_CHANNEL + 1, :, :] += food_removal

        # Add new food if necessary.
        food_addition_env_indices = (food(self.envs).view(
            self.num_envs, -1).sum(dim=-1) < EPS)
        if food_addition_env_indices.sum().item() > 0:
            add_food_envs = self.envs[food_addition_env_indices, :, :, :]
            food_addition = self._get_food_addition(add_food_envs)
            self.envs[food_addition_env_indices,
                      FOOD_CHANNEL:FOOD_CHANNEL + 1, :, :] += food_addition

        for i, (agent, act) in enumerate(actions.items()):
            # Check for boundary, Done by performing a convolution with no padding
            # If the head is at the edge then it will be cut off and the sum of the head
            # channel will be 0
            head_channel = self.head_channels[i]
            body_channel = self.body_channels[i]
            _env = self.envs[:, [0, head_channel, body_channel], :, :]
            edge_collision = F.conv2d(
                head(_env),
                NO_CHANGE_FILTER.to(self.device),
            ).view(self.num_envs, -1).sum(dim=-1) < EPS
            dones[agent] = dones[agent] | edge_collision
            info.update({f'edge_collision_{i}': edge_collision})

        for i, (agent, act) in enumerate(actions.items()):
            # Remove any snakes that are dead
            # self._bodies (num_envs, num_snakes, size, size)
            # NOTE(review): only the [0, 0] cell of each dead snake is zeroed
            # here, and it is not visible from this method whether writing to
            # self._bodies / self._heads reaches self.envs - confirm intent.
            self._bodies[dones[agent], i, 0, 0] = 0
            self._heads[dones[agent], i, 0, 0] = 0

            # TODO:
            # Keep track of which snakes are already dead not just which have died
            # in the current step

        # Apply rounding to stop numerical errors accumulating
        self.envs.round_()

        # Environment is finished if all snake are dead
        dones['__all__'] = torch.ones((self.num_envs, )).float().to(
            self.device).byte().requires_grad_(False)
        for agent, act in actions.items():
            dones['__all__'] = dones['__all__'] & dones[agent]

        self.done = dones['__all__']

        return dict(), rewards, dones, info
Ejemplo n.º 6
0
    def step(self, actions: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor, dict):
        """Advance every environment by one step.

        Args:
            actions: Integer tensor with one action per environment. It is
                modified in place when an action equals the snake's current
                orientation number (the "moving backwards" case below).

        Returns:
            ``(observation, reward, done, info)`` where ``observation`` comes
            from ``self._observe(self.observation_mode)``, ``reward`` and
            ``done`` have a trailing singleton dimension added, and ``info``
            holds the ``'self_collision'`` and ``'edge_collision'`` masks.

        Raises:
            TypeError: If ``actions`` is not a short/int/long tensor.
            RuntimeError: If ``actions.shape[0]`` differs from
                ``self.num_envs``.
        """
        if actions.dtype not in (torch.short, torch.int, torch.long):
            raise TypeError('actions Tensor must be an integer type i.e. '
                            '{torch.ShortTensor, torch.IntTensor, torch.LongTensor}')

        if actions.shape[0] != self.num_envs:
            raise RuntimeError('Must have the same number of actions as environments.')

        reward = torch.zeros((self.num_envs,)).float().to(self.device).requires_grad_(False)
        done = torch.zeros((self.num_envs,)).byte().to(self.device).requires_grad_(False)
        info = dict()

        t0 = time()
        # Snake size per environment = largest value in the body channel,
        # recorded before anything moves.
        snake_sizes = self.envs[:, BODY_CHANNEL:BODY_CHANNEL + 1, :].view(self.num_envs, -1).max(dim=1)[0]

        orientations = determine_orientations(self.envs)
        if self.verbose > 0:
            print(f'\nOrientations: {time()-t0}s')

        t0 = time()
        # Check if any snakes are trying to move backwards and change
        # their direction/action to just continue forward
        # The test for this is if their orientation number {0, 1, 2, 3}
        # is the same as their action
        mask = orientations == actions
        actions.add_((mask * 2).long()).fmod_(4)

        # Create head position deltas
        head_deltas = F.conv2d(head(self.envs), ORIENTATION_FILTERS.to(self.device), padding=1)
        # Select the head position delta corresponding to the correct action
        actions_onehot = torch.FloatTensor(self.num_envs, 4).to(self.device)
        actions_onehot.zero_()
        actions_onehot.scatter_(1, actions.unsqueeze(-1), 1)
        head_deltas = torch.einsum('bchw,bc->bhw', [head_deltas, actions_onehot]).unsqueeze(1)

        # Move head position by applying delta
        self.envs[:, HEAD_CHANNEL:HEAD_CHANNEL + 1, :, :].add_(head_deltas).round_()
        if self.verbose:
            print(f'Head movement: {time() - t0}s')

        ################
        # Apply update #
        ################

        t0 = time()
        head_food_overlap = (head(self.envs) * food(self.envs)).view(self.num_envs, -1).sum(dim=-1)

        # Decay the body sizes by 1, hence moving the body, apply ReLu to keep above 0
        # Only do this for environments which haven't just eaten food.
        # The mask is built with a threshold comparison instead of
        # `~overlap.byte()`: since PyTorch 1.2 `~` on a uint8 tensor is a
        # bitwise NOT (0 -> 255, 1 -> 254), which would select every
        # environment, including those that just ate.
        body_decay_env_indices = head_food_overlap < EPS
        self.envs[body_decay_env_indices, BODY_CHANNEL:BODY_CHANNEL + 1, :, :] -= 1
        self.envs[body_decay_env_indices, BODY_CHANNEL:BODY_CHANNEL + 1, :, :] = \
            self.envs[body_decay_env_indices, BODY_CHANNEL:BODY_CHANNEL + 1, :, :].relu()

        # Check for hitting self
        self_collision = (head(self.envs) * body(self.envs)).view(self.num_envs, -1).sum(dim=-1) > EPS
        info.update({'self_collision': self_collision})
        done = done | self_collision

        # Create a new head position in the body channel
        # Make this head +1 greater if the snake has just eaten food
        self.envs[:, BODY_CHANNEL:BODY_CHANNEL + 1, :, :] += \
            head(self.envs) * (
                snake_sizes[:, None, None, None].expand((self.num_envs, 1, self.size, self.size)) +
                head_food_overlap[:, None, None, None].expand((self.num_envs, 1, self.size, self.size))
            )

        if self.verbose:
            print(f'Body movement: {time()-t0}')

        t0 = time()
        # Remove food and give reward
        # `food_removal` is 0 except where a snake head is at the same location as food where it is -1
        food_removal = head(self.envs) * food(self.envs) * -1
        reward.sub_(food_removal.view(self.num_envs, -1).sum(dim=-1).float())
        self.envs[:, FOOD_CHANNEL:FOOD_CHANNEL + 1, :, :] += food_removal
        if self.verbose:
            print(f'Food removal: {time() - t0}s')

        # Add new food if necessary.
        if food_removal.sum() < 0:
            t0 = time()
            food_addition_env_indices = (food_removal * -1).view(self.num_envs, -1).sum(dim=-1).byte()
            add_food_envs = self.envs[food_addition_env_indices, :, :, :]
            food_addition = self._get_food_addition(add_food_envs)
            self.envs[food_addition_env_indices, FOOD_CHANNEL:FOOD_CHANNEL+1, :, :] += food_addition
            if self.verbose:
                print(f'Food addition ({food_addition_env_indices.sum().item()} envs): {time() - t0}s')

        t0 = time()
        # Check for boundary, Done by performing a convolution with no padding
        # If the head is at the edge then it will be cut off and the sum of the head
        # channel will be 0
        edge_collision = F.conv2d(
            head(self.envs),
            NO_CHANGE_FILTER.to(self.device),
        ).view(self.num_envs, -1).sum(dim=-1) < EPS
        done = done | edge_collision
        info.update({'edge_collision': edge_collision})
        if self.verbose:
            print(f'Edge collision ({edge_collision.sum().item()} envs): {time() - t0}s')

        # Apply rounding to stop numerical errors accumulating
        self.envs.round_()

        self.done = done

        return self._observe(self.observation_mode), reward.unsqueeze(-1), done.unsqueeze(-1), info
Ejemplo n.º 7
0
    def test_basic_movement(self):
        """Step two snakes through fixed action lists and track both heads.

        After each of the six steps the environment's own consistency check
        is run and each snake's head (row, column) is compared against the
        expected position. None of the actions is expected to kill a snake.
        """
        env = get_test_env()

        # Add food
        env.envs[0, 0, 1, 1] = 1

        all_actions = {
            'agent_0':
            torch.Tensor([1, 2, 1, 1, 0,
                          3]).unsqueeze(1).long().to(DEFAULT_DEVICE),
            'agent_1':
            torch.Tensor([0, 1, 3, 2, 1,
                          0]).unsqueeze(1).long().to(DEFAULT_DEVICE),
        }
        expected_head_positions = [
            torch.Tensor([[5, 4], [4, 4], [4, 3], [4, 2], [5, 2], [5, 3]]),
            torch.Tensor([[9, 7], [9, 6], [9, 5], [8, 5], [8, 4], [9, 4]]),
        ]

        print()
        if print_envs:
            print(env._bodies)

        if render_envs:
            env.render()
            sleep(render_sleep)

        for i in range(6):
            # Pick the i-th action of every agent for this step.
            actions = {
                agent: agent_actions[i]
                for agent, agent_actions in all_actions.items()
            }

            observations, rewards, dones, info = env.step(actions)
            env.check_consistency()

            for i_agent in range(env.num_snakes):
                head_channel = env.head_channels[i_agent]
                body_channel = env.body_channels[i_agent]
                _env = env.envs[:, [0, head_channel, body_channel], :, :]

                # Flat index of this snake's head, split into (row, column).
                flat_index = head(_env)[0, 0].flatten().argmax()
                head_position = torch.Tensor(
                    [flat_index // size, flat_index % size])
                self.assertTrue(
                    torch.equal(head_position,
                                expected_head_positions[i_agent][i]))

            if print_envs:
                print('=' * 10)
                print(env._bodies)
                print('DONES:')
                print(dones)
                print()

            if render_envs:
                env.render()
                sleep(render_sleep)

            if any(done for done in dones.values()):
                # Use self.fail rather than a bare `assert False`: plain
                # asserts are removed when Python runs with -O, which would
                # silently disable this check.
                self.fail("These actions shouldn't cause any deaths")
Ejemplo n.º 8
0
    def step(
        self, actions: torch.Tensor
    ) -> (torch.Tensor, torch.Tensor, torch.Tensor, dict):
        """Advance every gridworld environment by one step.

        Args:
            actions: Integer tensor with one action per environment.

        Returns:
            ``(observation, reward, done, info)`` where ``observation`` comes
            from ``self._observe(self.observation_mode)``, ``reward`` and
            ``done`` have a trailing singleton dimension added, and ``info``
            holds the ``'edge_collision'`` mask.

        Raises:
            TypeError: If ``actions`` is not a short/int/long tensor.
            RuntimeError: If ``actions.shape[0]`` differs from
                ``self.num_envs``.
        """
        integer_dtypes = (torch.short, torch.int, torch.long)
        if actions.dtype not in integer_dtypes:
            raise TypeError(
                'actions Tensor must be an integer type i.e. '
                '{torch.ShortTensor, torch.IntTensor, torch.LongTensor}')

        if actions.shape[0] != self.num_envs:
            raise RuntimeError(
                'Must have the same number of actions as environments.')

        per_env = (self.num_envs, )
        reward = torch.zeros(per_env).float().to(
            self.device).requires_grad_(False)
        done = torch.zeros(per_env).byte().to(
            self.device).requires_grad_(False)
        info = {}

        # ---- Head movement -------------------------------------------------
        started = time()
        # One candidate delta map per action, produced by the orientation
        # filters.
        candidate_deltas = F.conv2d(head(self.envs),
                                    ORIENTATION_FILTERS.to(self.device),
                                    padding=1)
        # One-hot encode the actions so einsum keeps only the chosen delta.
        chosen = torch.FloatTensor(self.num_envs, 4).to(self.device)
        chosen.zero_()
        chosen.scatter_(1, actions.unsqueeze(-1), 1)
        move = torch.einsum('bchw,bc->bhw',
                            [candidate_deltas, chosen]).unsqueeze(1)

        # Apply the delta to the head channel in place.
        head_view = self.envs[:, HEAD_CHANNEL:HEAD_CHANNEL + 1, :, :]
        head_view.add_(move).round_()
        if self.verbose:
            print(f'Head movement: {time() - started}s')

        # ---- Food removal and reward ----------------------------------------
        started = time()
        # -1 wherever a head sits on food, 0 elsewhere.
        eaten = head(self.envs) * food(self.envs) * -1
        reward.sub_(eaten.view(self.num_envs, -1).sum(dim=-1).float())
        self.envs[:, FOOD_CHANNEL:FOOD_CHANNEL + 1, :, :] += eaten
        if self.verbose:
            print(f'Food removal: {time() - started}s')

        # ---- Food replacement -----------------------------------------------
        if eaten.sum() < 0:
            started = time()
            needs_food = (eaten * -1).view(self.num_envs,
                                           -1).sum(dim=-1).byte()
            hungry_envs = self.envs[needs_food, :, :, :]
            new_food = self._get_food_addition(hungry_envs)
            self.envs[needs_food,
                      FOOD_CHANNEL:FOOD_CHANNEL + 1, :, :] += new_food
            if self.verbose:
                print(
                    f'Food addition ({needs_food.sum().item()} envs): {time() - started}s'
                )

        # ---- Edge collision -------------------------------------------------
        started = time()
        # Convolving without padding crops the outer ring, so a head on the
        # boundary disappears and the per-environment sum drops to 0.
        cropped_head = F.conv2d(
            head(self.envs),
            NO_CHANGE_FILTER.to(self.device),
        )
        hit_edge = cropped_head.view(self.num_envs, -1).sum(dim=-1) < EPS
        done = done | hit_edge
        info['edge_collision'] = hit_edge
        if self.verbose:
            print(
                f'Edge collision ({hit_edge.sum().item()} envs): {time() - started}s'
            )

        # Round to stop numerical errors accumulating across steps.
        self.envs.round_()

        self.done = done

        observation = self._observe(self.observation_mode)
        return observation, reward.unsqueeze(-1), done.unsqueeze(-1), info
Ejemplo n.º 9
0
    def test_cant_boost_until_size_4(self):
        """Check the head path of a size-3 snake whose first action is 4.

        Two snakes with three body cells each are placed by hand. Snake 0 is
        given action 4 first (presumably the boost action, given the test
        name - confirm against MultiSnake) and its head must still follow the
        expected one-cell-per-step positions. Snake 1's head is located but
        not asserted on.
        """
        env = MultiSnake(num_envs=1, num_snakes=2, size=size, manual_setup=True, boost=True)
        env.foods[:, 0, 1, 1] = 1
        # First snake: head at (5, 5), body values 3 -> 2 -> 1
        env.heads[0, 0, 5, 5] = 1
        env.bodies[0, 0, 5, 5] = 3
        env.bodies[0, 0, 4, 5] = 2
        env.bodies[0, 0, 4, 4] = 1
        # Second snake: head at (8, 7), body values 3 -> 2 -> 1
        env.heads[1, 0, 8, 7] = 1
        env.bodies[1, 0, 8, 7] = 3
        env.bodies[1, 0, 8, 8] = 2
        env.bodies[1, 0, 8, 9] = 1

        # Orientations are derived by hand from a stacked food/head/body view.
        foods_per_snake = env.foods.repeat_interleave(env.num_snakes, dim=0)
        stacked = torch.cat([foods_per_snake, env.heads, env.bodies], dim=1)
        env.orientations = determine_orientations(stacked)

        expected_head_positions = torch.tensor([
            [6, 5],
            [6, 4],
            [5, 4],
        ])

        all_actions = {
            'agent_0': torch.tensor([4, 1, 2]).unsqueeze(1).long().to(DEFAULT_DEVICE),
            'agent_1': torch.tensor([0, 1, 3]).unsqueeze(1).long().to(DEFAULT_DEVICE),
        }

        print_or_render(env)

        num_steps = all_actions['agent_0'].shape[0]
        for step_index in range(num_steps):
            step_actions = {
                name: sequence[step_index]
                for name, sequence in all_actions.items()
            }

            observations, rewards, dones, info = env.step(step_actions)

            env.reset(dones['__all__'])

            env.check_consistency()

            for snake_index in range(env.num_snakes):
                snake_view = torch.cat([
                    env.foods,
                    env.heads[snake_index].unsqueeze(0),
                    env.bodies[snake_index].unsqueeze(0)
                ], dim=1)

                # Flat index of the head cell, split into (row, column).
                flat_index = head(snake_view)[0, 0].flatten().argmax()
                located = torch.tensor([flat_index // size, flat_index % size])

                # Only the first snake's path is asserted.
                if snake_index == 0:
                    self.assertTrue(
                        torch.equal(expected_head_positions[step_index], located))

            print_or_render(env)