def test_eat_food(self):
    """Check that the snake grows after eating and that food is replaced.

    A single manually set-up environment is loaded from ``get_test_env``
    and driven through a fixed action sequence. The test then asserts that
    the largest value in the body channel has increased, that exactly one
    food item is present afterwards, and that the environment still passes
    ``env_consistency``.
    """
    env = SingleSnake(num_envs=1, size=size, manual_setup=True)
    env.envs = get_test_env(size, 'up').to(DEFAULT_DEVICE)
    # unsqueeze(1) makes every iterated element a shape-(1,) tensor, i.e.
    # one action for the single environment.
    actions = torch.Tensor([0, 3, 3, 0, 0]).unsqueeze(1).long().to(DEFAULT_DEVICE)

    initial_size = body(env.envs).max()

    for a in actions:
        if visualise:
            plot_envs(env.envs)
            plt.show()

        _, _, done, _ = env.step(a)

        if torch.any(done):
            # Use fail() rather than a bare assert so the check survives
            # `python -O` and reports a useful message.
            self.fail("The given actions shouldn't cause a death")

    final_size = body(env.envs).max()
    self.assertGreater(final_size, initial_size)

    # Food is created again after being eaten
    num_food = food(env.envs).sum()
    self.assertEqual(num_food, 1)

    # Check overall consistency
    env_consistency(env.envs)

    if visualise:
        plot_envs(env.envs)
        plt.show()
def _get_rgb(self):
    """Build the image tensor for all environments (same as `.render()`).

    Starts from a canvas of 255s shaped like ``self.envs``, paints the
    body, head and food cells (in that order, so later layers overwrite
    earlier ones) and finally draws a one-cell border in the edge colour.

    Returns:
        A short-typed tensor with the same (batch, channel, height, width)
        axis layout as ``self.envs``.
    """
    # Work in BHWC layout so a (B, H, W) mask selects whole pixels
    canvas = (torch.ones_like(self.envs).short() * 255).permute((0, 2, 3, 1))

    # Painted in order: food over head over body
    layers = (
        (body, self.body_colour),
        (head, self.head_colour),
        (food, self.food_colour),
    )
    for channel_of, colour in layers:
        occupied = (channel_of(self.envs) > EPS).squeeze(1)
        canvas[occupied, :] = colour

    # One-cell border on all four sides
    border = self.edge_colour
    canvas[:, :1, :, :] = border
    canvas[:, :, :1, :] = border
    canvas[:, -1:, :, :] = border
    canvas[:, :, -1:, :] = border

    # Back to BCHW layout
    return canvas.permute((0, 3, 1, 2))
def test_setup(self):
    """Check that freshly created environments are consistent.

    Creates 97 environments, runs ``env_consistency`` on them and verifies
    that the body channel of every environment sums to L * (L + 1) / 2,
    where L is ``env.initial_snake_length`` (i.e. 1 + 2 + ... + L).
    """
    n = 97
    env = SingleSnake(num_envs=n, size=size)
    env_consistency(env.envs)

    length = env.initial_snake_length
    expected_body_sum = length * (length + 1) / 2

    # Sum of the body channel per environment
    per_env_body_sum = body(env.envs).view(n, -1).sum(dim=-1)
    all_match = torch.all(per_env_body_sum == expected_body_sum)
    self.assertTrue(all_match)
def step(self, actions: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor, dict):
    """Advance every environment by one step.

    Args:
        actions: 1-D integer tensor (short, int or long) holding one action
            in {0, 1, 2, 3} per environment. The tensor is not modified;
            actions that would reverse a snake are corrected on a copy.

    Returns:
        A tuple ``(observation, reward, done, info)``:
        the result of ``self._observe(self.observation_mode)``, the reward
        with shape (num_envs, 1), the done flags with shape (num_envs, 1)
        and a dict holding the per-environment ``'self_collision'`` and
        ``'edge_collision'`` masks.

    Raises:
        TypeError: if ``actions`` is not an integer tensor.
        RuntimeError: if the number of actions differs from ``num_envs``.
    """
    if actions.dtype not in (torch.short, torch.int, torch.long):
        raise TypeError('actions Tensor must be an integer type i.e. '
                        '{torch.ShortTensor, torch.IntTensor, torch.LongTensor}')

    if actions.shape[0] != self.num_envs:
        raise RuntimeError('Must have the same number of actions as environments.')

    reward = torch.zeros((self.num_envs,)).float().to(self.device).requires_grad_(False)
    done = torch.zeros((self.num_envs,)).byte().to(self.device).requires_grad_(False)
    info = dict()

    t0 = time()
    # Largest body value per environment, taken before anything moves
    snake_sizes = self.envs[:, BODY_CHANNEL:BODY_CHANNEL + 1, :].view(self.num_envs, -1).max(dim=1)[0]
    orientations = determine_orientations(self.envs)
    if self.verbose > 0:
        print(f'\nOrientations: {time()-t0}s')

    t0 = time()
    # Check if any snakes are trying to move backwards and change
    # their direction/action to just continue forward
    # The test for this is if their orientation number {0, 1, 2, 3}
    # is the same as their action
    mask = orientations == actions
    # Build a new tensor instead of using in-place ops so the caller's
    # actions are left untouched; .long() also guarantees the int64 index
    # that scatter_ needs below.
    actions = (actions.long() + mask.long() * 2).fmod(4)

    # Create head position deltas
    head_deltas = F.conv2d(head(self.envs), ORIENTATION_FILTERS.to(self.device), padding=1)
    # Select the head position delta corresponding to the correct action
    actions_onehot = torch.zeros(self.num_envs, 4, dtype=torch.float, device=self.device)
    actions_onehot.scatter_(1, actions.unsqueeze(-1), 1)
    head_deltas = torch.einsum('bchw,bc->bhw', [head_deltas, actions_onehot]).unsqueeze(1)

    # Move head position by applying delta
    self.envs[:, HEAD_CHANNEL:HEAD_CHANNEL + 1, :, :].add_(head_deltas).round_()
    if self.verbose:
        print(f'Head movement: {time() - t0}s')

    ################
    # Apply update #
    ################

    t0 = time()
    # Per environment: > 0 where the new head position coincides with food
    head_food_overlap = (head(self.envs) * food(self.envs)).view(self.num_envs, -1).sum(dim=-1)

    # Decay the body sizes by 1, hence moving the body, apply ReLu to keep above 0
    # Only do this for environments which haven't just eaten food
    # NOTE: a comparison is used rather than `~overlap.byte()`; on recent
    # PyTorch `~` is a bitwise NOT on uint8 (1 -> 254), which is truthy for
    # every environment and would decay snakes that have just eaten.
    body_decay_env_indices = head_food_overlap < EPS
    self.envs[body_decay_env_indices, BODY_CHANNEL:BODY_CHANNEL + 1, :, :] -= 1
    self.envs[body_decay_env_indices, BODY_CHANNEL:BODY_CHANNEL + 1, :, :] = \
        self.envs[body_decay_env_indices, BODY_CHANNEL:BODY_CHANNEL + 1, :, :].relu()

    # Check for hitting self
    self_collision = (head(self.envs) * body(self.envs)).view(self.num_envs, -1).sum(dim=-1) > EPS
    info.update({'self_collision': self_collision})
    done = done | self_collision

    # Create a new head position in the body channel
    # Make this head +1 greater if the snake has just eaten food
    self.envs[:, BODY_CHANNEL:BODY_CHANNEL + 1, :, :] += \
        head(self.envs) * (
            snake_sizes[:, None, None, None].expand((self.num_envs, 1, self.size, self.size)) +
            head_food_overlap[:, None, None, None].expand((self.num_envs, 1, self.size, self.size))
        )
    if self.verbose:
        print(f'Body movement: {time()-t0}')

    t0 = time()
    # Remove food and give reward
    # `food_removal` is 0 except where a snake head is at the same location as food where it is -1
    food_removal = head(self.envs) * food(self.envs) * -1
    reward.sub_(food_removal.view(self.num_envs, -1).sum(dim=-1).float())
    self.envs[:, FOOD_CHANNEL:FOOD_CHANNEL + 1, :, :] += food_removal
    if self.verbose:
        print(f'Food removal: {time() - t0}s')

    # Add new food if necessary.
    if food_removal.sum() < 0:
        t0 = time()
        # Comparison mask of environments whose food was just eaten
        # (uint8 tensors are deprecated as indices).
        food_addition_env_indices = (food_removal * -1).view(self.num_envs, -1).sum(dim=-1) > EPS
        add_food_envs = self.envs[food_addition_env_indices, :, :, :]
        food_addition = self._get_food_addition(add_food_envs)
        self.envs[food_addition_env_indices, FOOD_CHANNEL:FOOD_CHANNEL+1, :, :] += food_addition
        if self.verbose:
            print(f'Food addition ({food_addition_env_indices.sum().item()} envs): {time() - t0}s')

    t0 = time()
    # Check for boundary, Done by performing a convolution with no padding
    # If the head is at the edge then it will be cut off and the sum of the head
    # channel will be 0
    edge_collision = F.conv2d(
        head(self.envs),
        NO_CHANGE_FILTER.to(self.device),
    ).view(self.num_envs, -1).sum(dim=-1) < EPS
    done = done | edge_collision
    info.update({'edge_collision': edge_collision})
    if self.verbose:
        print(f'Edge collision ({edge_collision.sum().item()} envs): {time() - t0}s')

    # Apply rounding to stop numerical errors accumulating
    self.envs.round_()

    self.done = done

    return self._observe(self.observation_mode), reward.unsqueeze(-1), done.unsqueeze(-1), info