def test_random_actions(self):
        num_envs = 100
        num_steps = 100

        obs_h = 5
        obs_w = 5
        obs_fn = observations.FirstPersonCrop(obs_h, obs_w)
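        # FirstPersonCrop presumably yields an egocentric obs_h x obs_w window
        # per agent (inferred from the name; the exact channel layout depends
        # on the implementation).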

        # Create some environments and run random actions for N steps, checking for consistency at each step
        env = Slither(num_envs=num_envs, num_agents=2, height=size, width=size, manual_setup=False, verbose=True,
                      render_args={'num_rows': 5, 'num_cols': 5, 'size': 128}, observation_fn=obs_fn
                      )
        env.check_consistency()

        all_actions = {
            'agent_0': torch.randint(4, size=(num_steps, num_envs)).long().to(DEFAULT_DEVICE),
            'agent_1': torch.randint(4, size=(num_steps, num_envs)).long().to(DEFAULT_DEVICE),
        }
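        # One pre-sampled (num_steps, num_envs) action tensor per agent.
        # randint(4) matches the four movement directions available without
        # boost; the boost tests below sample from 8 actions instead.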

        t0 = time()
        for i in range(all_actions['agent_0'].shape[0]):
            actions = {
                agent: agent_actions[i] for agent, agent_actions in all_actions.items()
            }
            obs, reward, done, info = env.step(actions)

            env.reset(done['__all__'])
            env.check_consistency()
            print()  # Blank line between verbose step outputs

        t = time() - t0
        print(f'Ran {num_envs * num_steps} actions in {t}s = {num_envs * num_steps / t} actions/s')

    def test_random_actions_with_boost(self):
        num_envs = 256  # Bump towards 1024 * 8 for a heavier stress test
        num_steps = 200
        num_snakes = 4

        obs_h = 5
        obs_w = 5
        obs_fn = observations.FirstPersonCrop(obs_h, obs_w)

        # Create some environments and run random actions for N steps, checking for consistency at each step
        env = Slither(num_envs=num_envs, num_agents=num_snakes, height=25, width=25, manual_setup=False, boost=True,
                      verbose=True, render_args={'num_rows': 1, 'num_cols': 2, 'size': 256},
                      respawn_mode='any', food_mode='random_rate', boost_cost_prob=0.25,
                      observation_fn=obs_fn, food_on_death_prob=0.33, food_rate=2.5e-4
                      )
        env.check_consistency()

        all_actions = {
            f'agent_{i}': torch.randint(8, size=(num_steps, num_envs)).long().to(DEFAULT_DEVICE)
            for i in range(num_snakes)
        }
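        # With boost enabled the action space is 8 rather than 4; by assumption
        # actions 4-7 are the boosted variants of directions 0-3.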

        t0 = time()
        for i in range(all_actions['agent_0'].shape[0]):
            actions = {
                agent: agent_actions[i] for agent, agent_actions in all_actions.items()
            }
            obs, reward, done, info = env.step(actions)

            env.reset(done['__all__'], return_observations=False)
            env.check_consistency()
            print()  # Blank line between verbose step outputs

        t = time() - t0
        print(f'Ran {num_envs * num_steps} actions in {t}s = {num_envs * num_steps / t} actions/s')

    def test_partial_observations(self):
        num_envs = 256
        num_snakes = 4

        obs_h = 5
        obs_w = 5
        obs_fn = observations.FirstPersonCrop(obs_h, obs_w)
        env = Slither(num_envs=num_envs, num_agents=num_snakes, height=25, width=25, manual_setup=False, boost=True,
                      observation_fn=obs_fn,
                      render_args={'num_rows': 1, 'num_cols': 2, 'size': 256},
                      )
        env.check_consistency()

        render_envs = False  # Flip to True to visualise the first env's observations
        agent_obs = env._observe()
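        # agent_obs maps agent names to observation tensors; with FirstPersonCrop
        # each should be (num_envs, C, obs_h, obs_w) for some channel count C
        # (an assumption based on the crop arguments above).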
        if render_envs:
            fig, axes = plt.subplots(2, 2)
            i = 0
            # Show all the observations of the agent in the first env
            for k, v in agent_obs.items():
                axes[i // 2, i % 2].imshow(v[0].permute(1, 2, 0).cpu().numpy())
                i += 1

            plt.show()

            env.render()
            sleep(5)


def get_test_env(num_envs=1):
    env = Slither(num_envs=num_envs, num_agents=2, height=size, width=size, manual_setup=True)

    for i in range(num_envs):
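        # Body values encode segment order: the head holds the largest value (4)
        # and the tail 1, which determine_orientations relies on (inferred from
        # the layout below).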
        # Snake 1
        env.agents[2 * i, 0, 5, 5] = 1
        env.bodies[2*i, 0, 5, 5] = 4
        env.bodies[2*i, 0, 4, 5] = 3
        env.bodies[2*i, 0, 4, 4] = 2
        env.bodies[2*i, 0, 4, 3] = 1
        # Snake 2
        env.agents[2 * i + 1, 0, 8, 7] = 1
        env.bodies[2*i+1, 0, 8, 7] = 4
        env.bodies[2*i+1, 0, 8, 8] = 3
        env.bodies[2*i+1, 0, 8, 9] = 2
        env.bodies[2*i+1, 0, 9, 9] = 1

    _envs = torch.cat([
        env.foods.repeat_interleave(env.num_agents, dim=0),
        env.agents,
        env.bodies
    ], dim=1)
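    # Per-agent channel stack of [food, own head, own body]; food layers are
    # shared within an env, hence the repeat_interleave over agents.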

    env.orientations = determine_orientations(_envs)
    print(env.orientations)  # Debug output: derived orientations for the fixture

    return env
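
# Example usage (sketch): build the deterministic two-snake fixture and verify it:
#   env = get_test_env(num_envs=4)
#   env.check_consistency()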

    def test_create_envs(self):
        # Create a large number of environments and check consistency
        env = Slither(num_envs=512, num_agents=2, height=size, width=size, manual_setup=False)
        env.check_consistency()

        _envs = torch.cat([
            env.foods.repeat_interleave(env.num_agents, dim=0),
            env.agents,
            env.bodies
        ], dim=1)

        orientations = determine_orientations(_envs)
        self.assertTrue(torch.equal(env.orientations, orientations))

    def test_many_snakes(self):
        num_envs = 50
        num_steps = 10
        num_snakes = 4
        env = Slither(num_envs=num_envs, num_agents=num_snakes, height=size, width=size, manual_setup=False, boost=True)
        env.check_consistency()

        all_actions = {
            f'agent_{i}': torch.randint(8, size=(num_steps, num_envs)).long().to(DEFAULT_DEVICE) for i in range(num_snakes)
        }

        for i in range(all_actions['agent_0'].shape[0]):
            # env.render()
            actions = {
                agent: agent_actions[i] for agent, agent_actions in all_actions.items()
            }
            obs, reward, done, info = env.step(actions)
            env.reset(done['__all__'])
            env.check_consistency()


def get_env(args: argparse.Namespace, observation_fn: ObservationFunction, device: str):
    if args.env is None:
        raise ValueError('args.env is None.')
    render_args = {
        'size': args.render_window_size,
        'num_rows': args.render_rows,
        'num_cols': args.render_cols,
    }
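    # Render window size in pixels and the rows x cols grid of environments to
    # display (matching the usage in the tests above).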
    # Pre-process map names (kind of a hack to help identify which "small2" map
    # we're referring to in the arguments), e.g. 'small2' -> 'laser-small2'
    args.env_map = [f'{args.env}-{m}' for m in args.env_map]
    if args.env == 'snake':
        env = Slither(num_envs=args.n_envs, num_agents=args.n_agents, food_on_death_prob=args.food_on_death,
                      height=args.height, width=args.width, device=device, render_args=render_args,
                      boost=args.boost,
                      boost_cost_prob=args.boost_cost, food_rate=args.food_rate,
                      respawn_mode=args.respawn_mode, food_mode=args.food_mode, observation_fn=observation_fn,
                      reward_on_death=args.reward_on_death, agent_colours=args.colour_mode)
    elif args.env == 'laser':
        if len(args.env_map) == 1:
            if args.env_map[0] == 'random':
                # Generate n_maps random mazes to select from at random during each map reset
                map_generator = Random(args.n_respawns, args.height, args.width, args.maze_complexity,
                                       args.maze_density, args.device)
            elif args.env_map[0] == 'from_file':
                maps = maps_from_file(args.pathing_file, args.respawn_file, args.device, args.n_maps)
                map_generator = MapPool(maps)
            else:
                # Single fixed map
                map_generator = FixedMapGenerator(parse_mapstring(args.env_map[0]), device)
        else:
            fixed_maps = [parse_mapstring(m) for m in args.env_map]
            map_generator = MapPool(fixed_maps)

        env = LaserTag(num_envs=args.n_envs, num_agents=args.n_agents, height=args.height, width=args.width,
                       observation_fn=observation_fn, colour_mode=args.colour_mode,
                       map_generator=map_generator, device=device, render_args=render_args, strict=args.strict)
    elif args.env == 'harvest':
        if len(args.env_map) == 1:
            if args.env_map[0] == 'random':
                # Generate n_maps random mazes to select from at random during each map reset
                map_generator = Random(args.n_respawns, args.height, args.width, args.maze_complexity,
                                       args.maze_density, args.device)
            elif args.env_map[0] == 'from_file':
                maps = maps_from_file(args.pathing_file, args.respawn_file, args.device, args.n_maps,
                                      other_tensors={'harvest': args.harvest_file})
                map_generator = MapPool(maps)
            else:
                # Single fixed map
                map_generator = FixedMapGenerator(parse_mapstring(args.env_map[0]), device)
        else:
            fixed_maps = [parse_mapstring(m) for m in args.env_map]
            map_generator = MapPool(fixed_maps)

        env = Harvest(num_envs=args.n_envs, num_agents=args.n_agents, height=args.height, width=args.width,
                      observation_fn=observation_fn, colour_mode=args.colour_mode, refresh_rate=args.harvest_refresh,
                      map_generator=map_generator, device=device, render_args=render_args, strict=args.strict)
    elif args.env == 'asymmetric':
        raise NotImplementedError
    else:
        raise ValueError('Unrecognised environment')

    return env
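
# Example usage (sketch; attribute values mirror the smoke tests above, and the
# reward_on_death / colour_mode values are placeholders):
#   args = argparse.Namespace(
#       env='snake', n_envs=256, n_agents=4, height=25, width=25,
#       render_window_size=256, render_rows=1, render_cols=2, env_map=[],
#       food_on_death=0.33, boost=True, boost_cost=0.25, food_rate=2.5e-4,
#       respawn_mode='any', food_mode='random_rate', reward_on_death=-1.0,
#       colour_mode='random')
#   env = get_env(args, observations.FirstPersonCrop(5, 5), device='cuda')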

    def test_cant_boost_until_size_4(self):
        # Create two size-3 snakes and try boosting with them
        env = Slither(num_envs=1, num_agents=2, height=size, width=size, manual_setup=True, boost=True)
        env.foods[:, 0, 1, 1] = 1
        # Snake 1
        env.agents[0, 0, 5, 5] = 1
        env.bodies[0, 0, 5, 5] = 3
        env.bodies[0, 0, 4, 5] = 2
        env.bodies[0, 0, 4, 4] = 1
        # Snake 2
        env.agents[1, 0, 8, 7] = 1
        env.bodies[1, 0, 8, 7] = 3
        env.bodies[1, 0, 8, 8] = 2
        env.bodies[1, 0, 8, 9] = 1
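        # Both snakes start at length 3, i.e. below the boost threshold of 4
        # named in the test.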

        # Get orientations manually
        _envs = torch.cat([
            env.foods.repeat_interleave(env.num_agents, dim=0),
            env.agents,
            env.bodies
        ], dim=1)

        env.orientations = determine_orientations(_envs)

        expected_head_positions = torch.tensor([
            [6, 5],
            [6, 4],
            [5, 4],
        ])
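        # Expected (row, col) head coordinates for agent_0 after each of the
        # three actions below.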

        all_actions = {
            'agent_0': torch.tensor([4, 1, 2]).unsqueeze(1).long().to(DEFAULT_DEVICE),
            'agent_1': torch.tensor([0, 1, 3]).unsqueeze(1).long().to(DEFAULT_DEVICE),
        }
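        # agent_0's first action (4) attempts a boost; while the snake is shorter
        # than 4 it should behave like a plain move (assumption: actions 4-7 are
        # boosted variants of 0-3).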

        print_or_render(env)

        for i in range(all_actions['agent_0'].shape[0]):
            actions = {
                agent: agent_actions[i] for agent, agent_actions in all_actions.items()
            }

            obs, rewards, dones, info = env.step(actions)

            env.reset(dones['__all__'])

            env.check_consistency()

            for i_agent in range(env.num_agents):
                _env = torch.cat([
                    env.foods,
                    env.agents[i_agent].unsqueeze(0),
                    env.bodies[i_agent].unsqueeze(0)
                ], dim=1)

                head_idx = head(_env)[0, 0].flatten().argmax().item()
                head_position = torch.tensor([head_idx // size, head_idx % size])

                if i_agent == 0:
                    self.assertTrue(torch.equal(expected_head_positions[i], head_position))

            print_or_render(env)