Example #1
0
    def _get_rgb(self):
        """Render every environment as an RGB image (same as ``.render()``).

        The background is white (255 on every channel). Body, head and food
        cells are then painted in that order, so later layers overwrite
        earlier ones. Finally a one-pixel border is painted in
        ``self.edge_colour``.

        Returns:
            A short (int16) tensor of shape ``(B, 3, H, W)`` (BCHW), on the
            same device as ``self.envs``.
        """
        num_envs, _, height, width = self.envs.shape
        # Allocate exactly 3 colour channels instead of copying the shape of
        # ``self.envs`` via ``ones_like``, so the image stays RGB regardless
        # of how many state channels the environment tensor has.
        img = torch.full((num_envs, 3, height, width), 255,
                         dtype=torch.short, device=self.envs.device)

        # Convert to BHWC axes so a (B, H, W) boolean mask selects whole
        # RGB pixels
        img = img.permute((0, 2, 3, 1))

        body_locations = (body(self.envs) > EPS).squeeze(1)
        img[body_locations, :] = self.body_colour

        head_locations = (head(self.envs) > EPS).squeeze(1)
        img[head_locations, :] = self.head_colour

        food_locations = (food(self.envs) > EPS).squeeze(1)
        img[food_locations, :] = self.food_colour

        # One-pixel border: top row, left column, bottom row, right column
        img[:, :1, :, :] = self.edge_colour
        img[:, :, :1, :] = self.edge_colour
        img[:, -1:, :, :] = self.edge_colour
        img[:, :, -1:, :] = self.edge_colour

        # Convert back to BCHW axes
        img = img.permute((0, 3, 1, 2))

        return img
Example #2
0
    def _get_rgb(self):
        """Render every environment as an RGB image (same as ``.render()``).

        The background is white (255 on every channel). A cell is painted in
        ``self.body_colour`` / ``self.head_colour`` if any snake's body/head
        occupies it, then food is painted. Finally a one-pixel border is
        painted in ``self.edge_colour``.

        Returns:
            A short (int16) tensor of shape ``(num_envs, 3, size, size)``
            (BCHW) on ``self.device``.
        """
        img = torch.full((self.num_envs, 3, self.size, self.size), 255,
                         dtype=torch.short, device=self.device)

        # Convert to BHWC axes so a (B, H, W) boolean mask selects whole
        # RGB pixels
        img = img.permute((0, 2, 3, 1))

        # ``self._bodies`` / ``self._heads`` hold one channel per snake along
        # dim 1. Reduce over the snakes with ``any(dim=1)``. The previous
        # ``squeeze(1).sum(dim=1)`` summed over H instead whenever there was a
        # single snake, giving a (B, W) mask that coloured the wrong cells.
        # ``any`` also yields a bool mask rather than a deprecated uint8 one.
        body_locations = (self._bodies > EPS).any(dim=1)
        img[body_locations, :] = self.body_colour

        head_locations = (self._heads > EPS).any(dim=1)
        img[head_locations, :] = self.head_colour

        food_locations = (food(self.envs) > EPS).squeeze(1)
        img[food_locations, :] = self.food_colour

        # One-pixel border: top row, left column, bottom row, right column
        img[:, :1, :, :] = self.edge_colour
        img[:, :, :1, :] = self.edge_colour
        img[:, -1:, :, :] = self.edge_colour
        img[:, :, -1:, :] = self.edge_colour

        # Convert back to BCHW axes
        img = img.permute((0, 3, 1, 2))

        return img
Example #3
0
    def test_eat_food(self):
        """Steering the snake onto food increases its max body value.

        Also checks that exactly one food item exists afterwards and that the
        environment passes ``env_consistency``.
        """
        env = SingleSnake(num_envs=1, size=size, manual_setup=True)
        env.envs = get_test_env(size, 'up').to(DEFAULT_DEVICE)
        actions = torch.Tensor([0, 3, 3, 0,
                                0]).unsqueeze(1).long().to(DEFAULT_DEVICE)

        initial_size = body(env.envs).max()

        for a in actions:
            if visualise:
                plot_envs(env.envs)
                plt.show()

            observations, reward, done, info = env.step(a)

            # The given actions shouldn't cause a death. ``self.fail`` is used
            # rather than ``assert False`` so the check survives ``python -O``.
            if torch.any(done):
                self.fail('Snake died while following a non-lethal action '
                          'sequence')

        final_size = body(env.envs).max()
        self.assertGreater(final_size, initial_size)

        # Food is created again after being eaten
        num_food = food(env.envs).sum()
        self.assertEqual(num_food, 1)

        # Check overall consistency
        env_consistency(env.envs)

        if visualise:
            plot_envs(env.envs)
            plt.show()
Example #4
0
    def test_hit_self(self):
        """Driving the snake back into its own body must produce ``done``.

        After the episode ends, exactly one food item should remain on the
        board.
        """
        env = SingleSnake(num_envs=1, size=size, manual_setup=True)
        env.envs = get_test_env(size, 'up').to(DEFAULT_DEVICE)
        move_sequence = torch.Tensor([0, 3, 3, 2, 1, 0, 0, 0])
        move_sequence = move_sequence.unsqueeze(1).long().to(DEFAULT_DEVICE)

        def maybe_show():
            # Only draw when the module-level ``visualise`` switch is on
            if visualise:
                plot_envs(env.envs)
                plt.show()

        hit_self = False
        for move in move_sequence:
            maybe_show()
            _, _, done, _ = env.step(move)
            if torch.any(done):
                hit_self = True
                break

        self.assertTrue(hit_self)

        # Exactly one piece of food should be on the board afterwards
        self.assertEqual(food(env.envs).sum(), 1)

        maybe_show()
Example #5
0
    def _get_rgb(self):
        """Render every environment as an RGB image (same as ``.render()``).

        Starts from an all-zero (black) canvas. Head cells are painted, then
        food cells (food wins on overlap). Finally a one-pixel border is
        painted in ``self.edge_colour``.

        Returns:
            A short (int16) tensor of shape ``(num_envs, 3, size, size)``
            (BCHW) on ``self.device``.
        """
        canvas = torch.zeros((self.num_envs, 3, self.size, self.size),
                             dtype=torch.short, device=self.device)

        # BHWC view of the canvas, so a (B, H, W) mask addresses whole pixels
        pixels = canvas.permute((0, 2, 3, 1))

        for layer, colour in ((head, self.head_colour),
                              (food, self.food_colour)):
            occupied = (layer(self.envs) > EPS).squeeze(1)
            pixels[occupied, :] = colour

        edge = self.edge_colour
        pixels[:, :1] = edge
        pixels[:, :, :1] = edge
        pixels[:, -1:] = edge
        pixels[:, :, -1:] = edge

        # Back to BCHW
        return pixels.permute((0, 3, 1, 2))
Example #6
0
    def step(
            self,
            actions: Dict[str, torch.Tensor]) -> Tuple[dict, dict, dict, dict]:
        """Advance every environment by one tick, moving all snakes.

        Args:
            actions: mapping from agent name to an integer Tensor holding one
                action per environment. The i-th item (in iteration order)
                drives the snake stored in ``self.head_channels[i]`` /
                ``self.body_channels[i]``. The returned ``rewards``/``dones``
                are pre-keyed ``'agent_{i}'``, so this presumably expects the
                same keys (TODO confirm with callers). Each Tensor is modified
                in place: an action equal to the snake's current orientation
                number becomes ``(action + 2) % 4``, so the snake keeps going
                forward instead of reversing.

        Returns:
            ``(observations, rewards, dones, info)``.
            - ``observations``: currently always an empty dict.
            - ``rewards`` / ``dones``: ``OrderedDict``s of per-environment
              tensors (float rewards, uint8 dones). ``dones`` additionally
              has ``'__all__'``, set where every snake's done flag is set.
            - ``info``: the per-snake ``self_collision_{i}`` and
              ``edge_collision_{i}`` masks.

        Raises:
            RuntimeError: if the number of action Tensors differs from
                ``self.num_snakes``, or an action Tensor's first dimension
                differs from ``self.num_envs``.
            TypeError: if an action Tensor is not short/int/long.
        """
        if len(actions) != self.num_snakes:
            raise RuntimeError('Must have a Tensor of actions for each snake')

        for agent, act in actions.items():
            if act.dtype not in (torch.short, torch.int, torch.long):
                raise TypeError(
                    'actions Tensor must be an integer type i.e. '
                    '{torch.ShortTensor, torch.IntTensor, torch.LongTensor}')

            if act.shape[0] != self.num_envs:
                raise RuntimeError(
                    'Must have the same number of actions as environments.')

        rewards = OrderedDict([(f'agent_{i}', torch.zeros(
            (self.num_envs, )).float().to(self.device).requires_grad_(False))
                               for i in range(self.num_snakes)])
        dones = OrderedDict([(f'agent_{i}', torch.zeros(
            (self.num_envs, )).float().to(
                self.device).byte().requires_grad_(False))
                             for i in range(self.num_snakes)])
        info = dict()

        # Record each snake's size (largest value in its body channel) before
        # anything moves; the new head cell is numbered from it below
        snake_sizes = dict()
        for i, (agent, act) in enumerate(actions.items()):
            body_channel = self.body_channels[i]
            snake_sizes[agent] = self.envs[:, body_channel:body_channel +
                                           1, :].view(self.num_envs,
                                                      -1).max(dim=1)[0]

        # Check orientations and move head positions of all snakes
        for i, (agent, act) in enumerate(actions.items()):
            # The sub-environment of just one agent: channel 0 (food), this
            # snake's head channel and this snake's body channel
            head_channel = self.head_channels[i]
            body_channel = self.body_channels[i]
            _env = self.envs[:, [0, head_channel, body_channel], :, :]

            orientations = determine_orientations(_env)

            # Check if this snake is trying to move backwards and change
            # its direction/action to just continue forward.
            # The test for this is if their orientation number {0, 1, 2, 3}
            # is the same as their action. Note: mutates the caller's Tensor.
            mask = orientations == act
            act.add_((mask * 2).long()).fmod_(4)

            # Create head position deltas
            head_deltas = F.conv2d(head(_env),
                                   ORIENTATION_FILTERS.to(self.device),
                                   padding=1)
            # Select the head position delta corresponding to the correct action
            actions_onehot = torch.zeros(self.num_envs, 4,
                                         dtype=torch.float,
                                         device=self.device)
            actions_onehot.scatter_(1, act.unsqueeze(-1), 1)
            head_deltas = torch.einsum(
                'bchw,bc->bhw', [head_deltas, actions_onehot]).unsqueeze(1)

            # Move head position by applying delta
            self.envs[:, head_channel:head_channel +
                      1, :, :].add_(head_deltas).round_()

        # Decay bodies of all snakes that haven't eaten food
        food_consumption = dict()
        for i, (agent, act) in enumerate(actions.items()):
            head_channel = self.head_channels[i]
            body_channel = self.body_channels[i]
            _env = self.envs[:, [0, head_channel, body_channel], :, :]

            head_food_overlap = (head(_env) * food(_env)).view(
                self.num_envs, -1).sum(dim=-1)
            food_consumption[agent] = head_food_overlap

            # Decay the body sizes by 1, hence moving the body, apply ReLu to
            # keep above 0. Only do this for environments which haven't just
            # eaten food.
            # Build the mask with a comparison rather than ``~x.byte()``: in
            # PyTorch >= 1.2 ``~`` on uint8 is a bitwise NOT (0 -> 255,
            # 1 -> 254), so both values are truthy and every environment
            # would decay.
            body_decay_env_indices = head_food_overlap < EPS
            self.envs[body_decay_env_indices,
                      body_channel:body_channel + 1, :, :] -= 1
            self.envs[body_decay_env_indices, body_channel:body_channel + 1, :, :] = \
                self.envs[body_decay_env_indices, body_channel:body_channel + 1, :, :].relu()

        for i, (agent, act) in enumerate(actions.items()):
            # Check if any snakes have collided with themselves or any other snakes
            head_channel = self.head_channels[i]
            body_channel = self.body_channels[i]
            _env = self.envs[:, [0, head_channel, body_channel], :, :]

            # Collision with body of any snake
            body_collision = (head(_env) * self._bodies).view(
                self.num_envs, -1).sum(dim=-1) > EPS
            # Collision with head of other snake. The "every snake except
            # this one" selector is built by comparison so it is a bool mask
            # (uint8 mask indexing is deprecated).
            other_snakes = torch.arange(self.num_snakes,
                                        device=self.device) != i
            other_heads = self._heads[:, other_snakes, :, :]
            head_collision = (head(_env) * other_heads).view(
                self.num_envs, -1).sum(dim=-1) > EPS
            snake_collision = body_collision | head_collision
            info.update({f'self_collision_{i}': snake_collision})
            dones[agent] = dones[agent] | snake_collision

            # Create a new head position in the body channel
            # Make this head +1 greater if the snake has just eaten food
            self.envs[:, body_channel:body_channel + 1, :, :] += \
                head(_env) * (
                    snake_sizes[agent][:, None, None, None].expand((self.num_envs, 1, self.size, self.size)) +
                    food_consumption[agent][:, None, None, None].expand((self.num_envs, 1, self.size, self.size))
                )

        for i, (agent, act) in enumerate(actions.items()):
            # Remove food and give reward
            # `food_removal` is 0 except where a snake head is at the same location as food where it is -1
            head_channel = self.head_channels[i]
            body_channel = self.body_channels[i]
            _env = self.envs[:, [0, head_channel, body_channel], :, :]

            food_removal = head(_env) * food(_env) * -1
            rewards[agent].sub_(
                food_removal.view(self.num_envs, -1).sum(dim=-1).float())
            self.envs[:, FOOD_CHANNEL:FOOD_CHANNEL + 1, :, :] += food_removal

        # Add new food if necessary.
        food_addition_env_indices = (food(self.envs).view(
            self.num_envs, -1).sum(dim=-1) < EPS)
        if food_addition_env_indices.sum().item() > 0:
            add_food_envs = self.envs[food_addition_env_indices, :, :, :]
            food_addition = self._get_food_addition(add_food_envs)
            self.envs[food_addition_env_indices,
                      FOOD_CHANNEL:FOOD_CHANNEL + 1, :, :] += food_addition

        for i, (agent, act) in enumerate(actions.items()):
            # Check for boundary, Done by performing a convolution with no padding
            # If the head is at the edge then it will be cut off and the sum of the head
            # channel will be 0
            head_channel = self.head_channels[i]
            body_channel = self.body_channels[i]
            _env = self.envs[:, [0, head_channel, body_channel], :, :]
            edge_collision = F.conv2d(
                head(_env),
                NO_CHANGE_FILTER.to(self.device),
            ).view(self.num_envs, -1).sum(dim=-1) < EPS
            dones[agent] = dones[agent] | edge_collision
            info.update({f'edge_collision_{i}': edge_collision})

        for i, (agent, act) in enumerate(actions.items()):
            # Remove any snakes that are dead
            # self._bodies (num_envs, num_snakes, size, size)
            # NOTE(review): only pixel (0, 0) of the dead snake's channels is
            # cleared here; if the intent is to erase the whole snake this
            # probably wants ``:, :`` -- confirm against how ``_bodies`` /
            # ``_heads`` are defined.
            dead = dones[agent] > 0  # bool mask from the uint8 done flags
            self._bodies[dead, i, 0, 0] = 0
            self._heads[dead, i, 0, 0] = 0

            # TODO:
            # Keep track of which snakes are already dead not just which have died
            # in the current step

        # Apply rounding to stop numerical errors accumulating
        self.envs.round_()

        # Environment is finished if all snake are dead
        dones['__all__'] = torch.ones((self.num_envs, )).float().to(
            self.device).byte().requires_grad_(False)
        for agent, act in actions.items():
            dones['__all__'] = dones['__all__'] & dones[agent]

        self.done = dones['__all__']

        return dict(), rewards, dones, info
Example #7
0
    def step(self, actions: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor, dict):
        """Advance every environment by one tick.

        Args:
            actions: integer Tensor with one action per environment; its first
                dimension must equal ``self.num_envs``. It is modified in
                place: an action equal to the snake's current orientation
                number becomes ``(action + 2) % 4``, so the snake continues
                forward instead of reversing.

        Returns:
            ``(observation, reward, done, info)``.
            - ``observation``: comes from ``self._observe(self.observation_mode)``.
            - ``reward``: float, shape ``(num_envs, 1)``; positive where the
              head landed on food.
            - ``done``: uint8, shape ``(num_envs, 1)``; set on self or edge
              collision.
            - ``info``: holds the ``self_collision`` and ``edge_collision``
              masks.

        Raises:
            TypeError: if ``actions`` is not short/int/long.
            RuntimeError: if ``actions.shape[0] != self.num_envs``.
        """
        if actions.dtype not in (torch.short, torch.int, torch.long):
            raise TypeError('actions Tensor must be an integer type i.e. '
                            '{torch.ShortTensor, torch.IntTensor, torch.LongTensor}')

        if actions.shape[0] != self.num_envs:
            raise RuntimeError('Must have the same number of actions as environments.')

        reward = torch.zeros((self.num_envs,)).float().to(self.device).requires_grad_(False)
        done = torch.zeros((self.num_envs,)).byte().to(self.device).requires_grad_(False)
        info = dict()

        t0 = time()
        # Snake size = largest value in the body channel, taken before moving
        snake_sizes = self.envs[:, BODY_CHANNEL:BODY_CHANNEL + 1, :].view(self.num_envs, -1).max(dim=1)[0]

        orientations = determine_orientations(self.envs)
        if self.verbose > 0:
            print(f'\nOrientations: {time()-t0}s')

        t0 = time()
        # Check if any snakes are trying to move backwards and change
        # their direction/action to just continue forward
        # The test for this is if their orientation number {0, 1, 2, 3}
        # is the same as their action
        mask = orientations == actions
        actions.add_((mask * 2).long()).fmod_(4)

        # Create head position deltas
        head_deltas = F.conv2d(head(self.envs), ORIENTATION_FILTERS.to(self.device), padding=1)
        # Select the head position delta corresponding to the correct action
        actions_onehot = torch.zeros(self.num_envs, 4, dtype=torch.float, device=self.device)
        actions_onehot.scatter_(1, actions.unsqueeze(-1), 1)
        head_deltas = torch.einsum('bchw,bc->bhw', [head_deltas, actions_onehot]).unsqueeze(1)

        # Move head position by applying delta
        self.envs[:, HEAD_CHANNEL:HEAD_CHANNEL + 1, :, :].add_(head_deltas).round_()
        if self.verbose:
            print(f'Head movement: {time() - t0}s')

        ################
        # Apply update #
        ################

        t0 = time()
        head_food_overlap = (head(self.envs) * food(self.envs)).view(self.num_envs, -1).sum(dim=-1)

        # Decay the body sizes by 1, hence moving the body, apply ReLu to keep above 0
        # Only do this for environments which haven't just eaten food.
        # Use a comparison rather than ``~x.byte()``: in PyTorch >= 1.2 ``~`` on
        # uint8 is a bitwise NOT (0 -> 255, 1 -> 254), both truthy, so every
        # environment would decay even after eating.
        body_decay_env_indices = head_food_overlap < EPS
        self.envs[body_decay_env_indices, BODY_CHANNEL:BODY_CHANNEL + 1, :, :] -= 1
        self.envs[body_decay_env_indices, BODY_CHANNEL:BODY_CHANNEL + 1, :, :] = \
            self.envs[body_decay_env_indices, BODY_CHANNEL:BODY_CHANNEL + 1, :, :].relu()

        # Check for hitting self
        self_collision = (head(self.envs) * body(self.envs)).view(self.num_envs, -1).sum(dim=-1) > EPS
        info.update({'self_collision': self_collision})
        done = done | self_collision

        # Create a new head position in the body channel
        # Make this head +1 greater if the snake has just eaten food
        self.envs[:, BODY_CHANNEL:BODY_CHANNEL + 1, :, :] += \
            head(self.envs) * (
                snake_sizes[:, None, None, None].expand((self.num_envs, 1, self.size, self.size)) +
                head_food_overlap[:, None, None, None].expand((self.num_envs, 1, self.size, self.size))
            )

        if self.verbose:
            print(f'Body movement: {time()-t0}')

        t0 = time()
        # Remove food and give reward
        # `food_removal` is 0 except where a snake head is at the same location as food where it is -1
        food_removal = head(self.envs) * food(self.envs) * -1
        reward.sub_(food_removal.view(self.num_envs, -1).sum(dim=-1).float())
        self.envs[:, FOOD_CHANNEL:FOOD_CHANNEL + 1, :, :] += food_removal
        if self.verbose:
            print(f'Food removal: {time() - t0}s')

        # Add new food if necessary.
        if food_removal.sum() < 0:
            t0 = time()
            # Comparison gives a bool mask (uint8 mask indexing is deprecated)
            food_addition_env_indices = (food_removal * -1).view(self.num_envs, -1).sum(dim=-1) > EPS
            add_food_envs = self.envs[food_addition_env_indices, :, :, :]
            food_addition = self._get_food_addition(add_food_envs)
            self.envs[food_addition_env_indices, FOOD_CHANNEL:FOOD_CHANNEL+1, :, :] += food_addition
            if self.verbose:
                print(f'Food addition ({food_addition_env_indices.sum().item()} envs): {time() - t0}s')

        t0 = time()
        # Check for boundary, Done by performing a convolution with no padding
        # If the head is at the edge then it will be cut off and the sum of the head
        # channel will be 0
        edge_collision = F.conv2d(
            head(self.envs),
            NO_CHANGE_FILTER.to(self.device),
        ).view(self.num_envs, -1).sum(dim=-1) < EPS
        done = done | edge_collision
        info.update({'edge_collision': edge_collision})
        if self.verbose:
            print(f'Edge collision ({edge_collision.sum().item()} envs): {time() - t0}s')

        # Apply rounding to stop numerical errors accumulating
        self.envs.round_()

        self.done = done

        return self._observe(self.observation_mode), reward.unsqueeze(-1), done.unsqueeze(-1), info
Example #8
0
    def test_eat_food(self):
        """agent_1 eats the food placed at (9, 7) on the first step.

        Checks the reward on that step and that no snake dies. Afterwards it
        checks both snakes' sizes, that the eaten food was removed, and that
        exactly one new food item exists.
        """
        env = get_test_env()

        # Add food (channel 0 of the environment tensor)
        env.envs[0, 0, 9, 7] = 1

        all_actions = {
            'agent_0':
            torch.Tensor([1, 2, 1, 1, 0,
                          3]).unsqueeze(1).long().to(DEFAULT_DEVICE),
            'agent_1':
            torch.Tensor([0, 1, 3, 2, 1,
                          0]).unsqueeze(1).long().to(DEFAULT_DEVICE),
        }
        num_steps = len(all_actions['agent_0'])

        if print_envs:
            print()
            print(env._bodies)

        if render_envs:
            env.render()
            sleep(render_sleep)

        for i in range(num_steps):
            actions = {
                agent: agent_actions[i]
                for agent, agent_actions in all_actions.items()
            }

            observations, rewards, dones, info = env.step(actions)
            env.check_consistency()

            if print_envs:
                print('=' * 10)
                print(env._bodies)
                print('DONES:')
                print(dones)
                print()

            # Check reward given when expected
            if i == 0:
                self.assertEqual(rewards['agent_1'].item(), 1)

            if render_envs:
                env.render()
                sleep(render_sleep)

            # These actions shouldn't cause any deaths. ``torch.any`` keeps the
            # check well-defined for multi-env tensors, and ``self.fail``
            # (unlike ``assert False``) survives ``python -O``.
            if any(torch.any(done) for done in dones.values()):
                self.fail(f'Unexpected death at step {i}: {dones}')

        # Check snake sizes. Expect agent_0: 4, agent_1: 5
        snake_sizes = env._bodies.view(1, 2, -1).max(dim=2)[0]
        self.assertTrue(
            torch.equal(snake_sizes,
                        torch.Tensor([[4, 5]]).to(DEFAULT_DEVICE)))

        # Check food has been removed
        self.assertEqual(env.envs[0, 0, 9, 7].item(), 0)

        # Check new food has been created
        self.assertEqual(food(env.envs).sum().item(), 1)
Example #9
0
    def step(
        self, actions: torch.Tensor
    ) -> (torch.Tensor, torch.Tensor, torch.Tensor, dict):
        """Advance every environment by one tick.

        Each action is applied as given. It selects a head-position delta
        from ``ORIENTATION_FILTERS``; no reversal check is done here.

        Args:
            actions: integer Tensor with one action per environment; its first
                dimension must equal ``self.num_envs``.

        Returns:
            ``(observation, reward, done, info)``.
            - ``observation``: comes from ``self._observe(self.observation_mode)``.
            - ``reward``: float, shape ``(num_envs, 1)``; positive where the
              head landed on food.
            - ``done``: uint8, shape ``(num_envs, 1)``; set on edge collision.
            - ``info``: holds the ``edge_collision`` mask.

        Raises:
            TypeError: if ``actions`` is not short/int/long.
            RuntimeError: if ``actions.shape[0] != self.num_envs``.
        """
        if actions.dtype not in (torch.short, torch.int, torch.long):
            raise TypeError(
                'actions Tensor must be an integer type i.e. '
                '{torch.ShortTensor, torch.IntTensor, torch.LongTensor}')

        if actions.shape[0] != self.num_envs:
            raise RuntimeError(
                'Must have the same number of actions as environments.')

        reward = torch.zeros(
            (self.num_envs, )).float().to(self.device).requires_grad_(False)
        done = torch.zeros(
            (self.num_envs, )).byte().to(self.device).requires_grad_(False)
        info = dict()

        t0 = time()
        # Create head position deltas
        head_deltas = F.conv2d(head(self.envs),
                               ORIENTATION_FILTERS.to(self.device),
                               padding=1)
        # Select the head position delta corresponding to the correct action
        actions_onehot = torch.zeros(self.num_envs, 4,
                                     dtype=torch.float,
                                     device=self.device)
        actions_onehot.scatter_(1, actions.unsqueeze(-1), 1)
        head_deltas = torch.einsum('bchw,bc->bhw',
                                   [head_deltas, actions_onehot]).unsqueeze(1)

        # Move head position by applying delta
        self.envs[:, HEAD_CHANNEL:HEAD_CHANNEL +
                  1, :, :].add_(head_deltas).round_()
        if self.verbose:
            print(f'Head movement: {time() - t0}s')

        ################
        # Apply update #
        ################

        t0 = time()
        # Remove food and give reward
        # `food_removal` is 0 except where a snake head is at the same location as food where it is -1
        food_removal = head(self.envs) * food(self.envs) * -1
        reward.sub_(food_removal.view(self.num_envs, -1).sum(dim=-1).float())
        self.envs[:, FOOD_CHANNEL:FOOD_CHANNEL + 1, :, :] += food_removal
        if self.verbose:
            print(f'Food removal: {time() - t0}s')

        # Add new food if necessary.
        if food_removal.sum() < 0:
            t0 = time()
            # Comparison gives a bool mask; indexing with a uint8 (``.byte()``)
            # mask is deprecated
            food_addition_env_indices = (food_removal * -1).view(
                self.num_envs, -1).sum(dim=-1) > EPS
            add_food_envs = self.envs[food_addition_env_indices, :, :, :]
            food_addition = self._get_food_addition(add_food_envs)
            self.envs[food_addition_env_indices,
                      FOOD_CHANNEL:FOOD_CHANNEL + 1, :, :] += food_addition
            if self.verbose:
                print(
                    f'Food addition ({food_addition_env_indices.sum().item()} envs): {time() - t0}s'
                )

        t0 = time()
        # Check for boundary, Done by performing a convolution with no padding
        # If the head is at the edge then it will be cut off and the sum of the head
        # channel will be 0
        edge_collision = F.conv2d(
            head(self.envs),
            NO_CHANGE_FILTER.to(self.device),
        ).view(self.num_envs, -1).sum(dim=-1) < EPS
        done = done | edge_collision
        info.update({'edge_collision': edge_collision})
        if self.verbose:
            print(
                f'Edge collision ({edge_collision.sum().item()} envs): {time() - t0}s'
            )

        # Apply rounding to stop numerical errors accumulating
        self.envs.round_()

        self.done = done

        return self._observe(self.observation_mode), reward.unsqueeze(
            -1), done.unsqueeze(-1), info