def test_eat_food(self):
    env = SingleSnake(num_envs=1, size=size, manual_setup=True)
    env.envs = get_test_env(size, 'up').to(DEFAULT_DEVICE)

    actions = torch.Tensor([0, 3, 3, 0, 0]).unsqueeze(1).long().to(DEFAULT_DEVICE)

    initial_size = body(env.envs).max()

    for i, a in enumerate(actions):
        if visualise:
            plot_envs(env.envs)
            plt.show()

        observations, reward, done, info = env.step(a)

        if torch.any(done):
            # The given actions shouldn't cause a death
            assert False

    final_size = body(env.envs).max()
    self.assertGreater(final_size, initial_size)

    # Food is created again after being eaten
    num_food = food(env.envs).sum()
    print(num_food, 1)
    self.assertEqual(num_food, 1)

    # Check overall consistency
    env_consistency(env.envs)

    if visualise:
        plot_envs(env.envs)
        plt.show()
def check_consistency(self):
    """Runs multiple checks for environment consistency and raises an exception if any fail."""
    n = self.num_envs
    for i in range(self.num_snakes):
        head_channel = self.head_channels[i]
        body_channel = self.body_channels[i]
        # Build the single-snake view (channel 0 plus this snake's head and
        # body channels) expected by env_consistency
        _env = self.envs[:, [0, head_channel, body_channel], :, :]
        env_consistency(_env)
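# For reference, a minimal sketch of the kind of invariants a consistency check
# like env_consistency might enforce. This is NOT the repo's actual
# implementation: the (num_envs, 3, size, size) layout with channels ordered
# [food, head, body], and the name sketch_env_consistency, are assumptions made
# for illustration only.
import torch


def sketch_env_consistency(envs: torch.Tensor):
    """Raise AssertionError if any environment violates a basic invariant."""
    n = envs.shape[0]
    food_ch = envs[:, 0].reshape(n, -1)
    head_ch = envs[:, 1].reshape(n, -1)
    body_ch = envs[:, 2].reshape(n, -1)
    # Exactly one food pixel and exactly one head pixel per environment
    assert torch.all((food_ch > 0).sum(dim=-1) == 1), 'each env should contain exactly one food'
    assert torch.all((head_ch > 0).sum(dim=-1) == 1), 'each env should contain exactly one head'
    # The head should sit on a body cell (assumes the body channel includes the head position)
    assert torch.all(((head_ch > 0) & (body_ch > 0)).any(dim=-1)), 'head must lie on the body'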
def test_setup(self):
    n = 97
    env = SingleSnake(num_envs=n, size=size)
    env_consistency(env.envs)

    # The body cells are numbered 1..initial_snake_length, so each env's body
    # channel should sum to the triangular number L * (L + 1) / 2
    expected_body_sum = env.initial_snake_length * (env.initial_snake_length + 1) / 2
    self.assertTrue(
        torch.all(body(env.envs).view(n, -1).sum(dim=-1) == expected_body_sum))
def test_multiple_envs(self):
    num_envs = 100
    num_steps = 100
    env = SingleSnake(num_envs=num_envs, size=size)

    actions = torch.randint(4, size=(num_steps, num_envs)).long().to(DEFAULT_DEVICE)

    t0 = time()
    for i, a in enumerate(actions):
        if visualise:
            plot_envs(env.envs)
            plt.show()

        observations, reward, done, info = env.step(a)
        env.reset(done)
        env_consistency(env.envs)

    t = time() - t0
    print(f'Ran {num_envs*num_steps} actions in {t}s = {num_envs*num_steps/t} actions/s')
if args.save_video:
    recorder.capture_frame()

#############################
# Interact with environment #
#############################
probs, state_value = model(state)
action_distribution = Categorical(probs)
entropy = action_distribution.entropy().mean()
action = action_distribution.sample().clone().long()
state, reward, done, info = env.step(action)
if args.env == 'snake':
    env_consistency(env.envs[~done.squeeze(-1)])

if args.agent != 'random' and args.train:
    trajectories.append(
        action=action,
        log_prob=action_distribution.log_prob(action),
        value=state_value,
        reward=reward,
        done=done,
        entropy=entropy
    )

env.reset(done)

##########################
# Advantage actor-critic #
##########################
if args.agent != 'random' and args.train and i_step % args.update_steps == 0:
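# The body of the "Advantage actor-critic" branch above is not shown in this
# excerpt. As a rough, self-contained sketch (not the repo's actual update),
# the quantities stored in `trajectories` are typically combined as below:
# discounted returns are bootstrapped from the last state value, the advantage
# is return minus value, and the policy gradient is weighted by it. The helper
# name a2c_losses and the gamma default are assumptions for illustration.
import torch
import torch.nn.functional as F


def a2c_losses(log_probs, values, rewards, dones, bootstrap_value, gamma=0.99):
    """Compute policy and value losses for a rollout of T steps over num_envs envs."""
    returns = []
    R = bootstrap_value
    for r, d in zip(reversed(rewards), reversed(dones)):
        # Zero the bootstrapped return wherever an episode terminated
        R = r + gamma * R * (1.0 - d.float())
        returns.append(R)
    returns = torch.stack(list(reversed(returns)))

    values = torch.stack(list(values))
    advantages = returns - values

    policy_loss = -(torch.stack(list(log_probs)) * advantages.detach()).mean()
    value_loss = F.mse_loss(values, returns.detach())
    return policy_loss, value_loss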
def test_reset(self):
    env = SingleSnake(num_envs=1, size=size)
    env_consistency(env.envs)
    env.reset(torch.Tensor([1]).to(DEFAULT_DEVICE))
    env_consistency(env.envs)
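# Usage note: assuming the tests above are methods of a unittest.TestCase
# subclass in a module such as tests/test_single_snake.py (both names are
# assumptions for illustration), they can be run with the standard runner:
#
#   python -m unittest tests.test_single_snake -v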