Esempio n. 1
0
    def forward_tail(self, core_output, with_action_distribution=False):
        """Compute actor and critic heads from the concatenated core outputs.

        ``core_output`` is split along dim 1 into ``len(self.cores)`` chunks; the
        first chunk feeds the actor (action distribution), the second the critic.

        Args:
            core_output: concatenated output of all cores.
            with_action_distribution: if True, the distribution object is attached
                to the result as ``action_distribution``.

        Returns:
            AttrDict with ``actions``, ``action_logits``, ``log_prob_actions`` and
            ``values``.
        """
        per_core = core_output.chunk(len(self.cores), dim=1)
        actor_features = per_core[0]
        critic_features = per_core[1]

        # actor head: distribution parameters plus the distribution object
        dist_params, distribution = self.action_parameterization(actor_features)
        # sampling and log-prob evaluation in one call is faster for non-trivial action spaces
        sampled, sampled_log_probs = sample_actions_log_probs(distribution)

        # critic head
        state_values = self.critic_linear(critic_features)

        outputs = AttrDict(dict(
            actions=sampled,
            action_logits=dist_params,
            log_prob_actions=sampled_log_probs,
            values=state_values,
        ))
        if with_action_distribution:
            outputs.action_distribution = distribution
        return outputs
Esempio n. 2
0
def get_obs_shape(obs_space):
    """Map observation names to their shapes.

    Spaces exposing a ``spaces`` mapping (dict-like spaces) produce one entry per
    sub-space; any other space is stored under the single key ``obs``.
    """
    shapes = AttrDict()
    if not hasattr(obs_space, 'spaces'):
        shapes.obs = obs_space.shape
        return shapes

    for name, sub_space in obs_space.spaces.items():
        shapes[name] = sub_space.shape
    return shapes
Esempio n. 3
0
    def _handle_policy_steps(self, timing):
        """Run one batched forward pass for every pending request in ``self.requests``.

        Observations and RNN states of all requested transitions are gathered from
        ``self.shared_buffers.tensors_individual_transitions``, stacked, moved to the
        model's device, passed through ``self.actor_critic``, and the outputs named in
        ``self.shared_buffers.policy_outputs`` are concatenated into a single float
        tensor that is handed to ``self._enqueue_policy_outputs``. ``self.requests``
        is cleared afterwards.

        Args:
            timing: timing helper; each stage is recorded with ``timing.add_time``.
        """
        with torch.no_grad():
            with timing.add_time('deserialize'):
                observations = AttrDict()
                rnn_states = []

                traj_tensors = self.shared_buffers.tensors_individual_transitions
                for request in self.requests:
                    actor_idx, split_idx, request_data = request

                    for env_idx, agent_idx, traj_buffer_idx, rollout_step in request_data:
                        # full index of one transition inside the shared buffers
                        index = actor_idx, split_idx, env_idx, agent_idx, traj_buffer_idx, rollout_step
                        dict_of_lists_append(observations, traj_tensors['obs'], index)
                        rnn_states.append(traj_tensors['rnn_states'][index])
                        self.total_num_samples += 1

            with timing.add_time('stack'):
                # NOTE(review): torch.stack raises on an empty list, so this assumes at
                # least one transition was requested — confirm callers guarantee it
                for key, x in observations.items():
                    observations[key] = torch.stack(x)
                rnn_states = torch.stack(rnn_states)
                num_samples = rnn_states.shape[0]

            with timing.add_time('obs_to_device'):
                # each observation key may need its own device/dtype
                for key, x in observations.items():
                    device, dtype = self.actor_critic.device_and_type_for_input_tensor(key)
                    observations[key] = x.to(device).type(dtype)
                rnn_states = rnn_states.to(self.device).float()

            with timing.add_time('forward'):
                policy_outputs = self.actor_critic(observations, rnn_states)

            with timing.add_time('to_cpu'):
                for key, output_value in policy_outputs.items():
                    policy_outputs[key] = output_value.cpu()

            with timing.add_time('format_outputs'):
                # one entry per sample, filled with the version of the policy that produced it
                policy_outputs.policy_version = torch.empty([num_samples]).fill_(self.latest_policy_version)

                # concat all tensors into a single tensor for performance
                output_tensors = []
                for policy_output in self.shared_buffers.policy_outputs:
                    tensor_name = policy_output.name
                    output_value = policy_outputs[tensor_name].float()
                    # 1-D outputs become (num_samples, 1) so they can be concatenated on dim 1.
                    # NOTE(review): in-place; if the tensor is already float32, .float() returns
                    # the same tensor, so the entry in policy_outputs is reshaped too
                    if len(output_value.shape) == 1:
                        output_value.unsqueeze_(dim=1)
                    output_tensors.append(output_value)

                output_tensors = torch.cat(output_tensors, dim=1)

            with timing.add_time('postprocess'):
                self._enqueue_policy_outputs(self.requests, output_tensors)

        # NOTE(review): not reached if any stage above raises, so the same requests
        # would be processed again on the next call
        self.requests = []
Esempio n. 4
0
    def print_stats(self, fps, sample_throughput, total_env_steps):
        """Log throughput, policy lag and (when available) average episode rewards.

        Args:
            fps: FPS values, paired element-wise with ``self.avg_stats_intervals``.
            sample_throughput: mapping of policy id to samples/sec.
            total_env_steps: total number of environment frames collected so far.
        """
        interval_parts = [
            f'{int(interval * self.report_interval)} sec: {value:.1f}'
            for interval, value in zip(self.avg_stats_intervals, fps)
        ]
        fps_str = '(' + ', '.join(interval_parts) + ')'

        samples_per_policy = ', '.join(
            f'{p}: {s:.1f}' for p, s in sample_throughput.items())

        # only policy #0 lag is reported; missing stats fall back to -1
        lag_stats = self.policy_lag[0]
        lag_min, lag_avg, lag_max = (
            lag_stats.get(f'version_diff_{stat}', -1) for stat in ('min', 'avg', 'max'))
        policy_lag_str = f'min: {lag_min:.1f}, avg: {lag_avg:.1f}, max: {lag_max:.1f}'

        log.debug(
            'Fps is %s. Total num frames: %d. Throughput: %s. Samples: %d. Policy #0 lag: (%s)',
            fps_str,
            total_env_steps,
            samples_per_policy,
            sum(self.samples_collected),
            policy_lag_str,
        )

        if 'reward' not in self.policy_avg_stats:
            return

        reward_history = self.policy_avg_stats['reward']
        policy_reward_stats = [
            (policy_id, f'{np.mean(reward_history[policy_id]):.3f}')
            for policy_id in range(self.cfg.num_policies)
            if len(reward_history[policy_id]) > 0
        ]
        log.debug('Avg episode reward: %r', policy_reward_stats)
Esempio n. 5
0
def load_from_checkpoint(cfg):
    """Load the saved experiment config and merge the current arguments into it.

    Values explicitly passed on the command line (``cfg.cli_args``) override the
    saved ones; arguments present in ``cfg`` but absent from the saved file are
    added.

    Args:
        cfg: current config namespace; must provide ``experiment`` and ``cli_args``.

    Returns:
        AttrDict with the merged configuration.

    Raises:
        Exception: if no saved config file exists for this experiment.
    """
    filename = cfg_file(cfg)
    if not os.path.isfile(filename):
        raise Exception(
            f'Could not load saved parameters for experiment {cfg.experiment}')

    with open(filename, 'r') as json_file:
        saved_params = json.load(json_file)
    log.warning('Loading existing experiment configuration from %s', filename)
    loaded_cfg = AttrDict(saved_params)

    # explicit command-line values take precedence over the saved file
    for arg_name, cli_value in cfg.cli_args.items():
        if arg_name in loaded_cfg and loaded_cfg[arg_name] != cli_value:
            log.debug('Overriding arg %r with value %r passed from command line', arg_name, cli_value)
            loaded_cfg[arg_name] = cli_value

    # arguments introduced after the config was saved come from the current cfg
    for arg_name, current_value in vars(cfg).items():
        if arg_name in loaded_cfg:
            continue
        log.debug('Adding new argument %r=%r that is not in the saved config file!', arg_name, current_value)
        loaded_cfg[arg_name] = current_value

    return loaded_cfg
Esempio n. 6
0
def maybe_load_from_checkpoint(cfg):
    """Return the saved experiment config if it exists, else a fresh copy of ``cfg``.

    When a saved config file is found, loading and merging is delegated to
    ``load_from_checkpoint``; otherwise warnings are logged and ``vars(cfg)`` is
    wrapped in an AttrDict.
    """
    if os.path.isfile(cfg_file(cfg)):
        return load_from_checkpoint(cfg)

    log.warning('Saved parameter configuration for experiment %s not found!', cfg.experiment)
    log.warning('Starting experiment from scratch!')
    return AttrDict(vars(cfg))
Esempio n. 7
0
    def forward_tail(self, core_output, with_action_distribution=False):
        """Compute the value estimate and sample actions from a shared core output.

        Both the critic and the action parameterization read the same
        ``core_output``.

        Args:
            core_output: output of the shared core.
            with_action_distribution: if True, the distribution object is attached
                to the result as ``action_distribution``.

        Returns:
            AttrDict with ``actions``, ``action_logits``, ``log_prob_actions`` and
            ``values``.
        """
        state_values = self.critic_linear(core_output)

        dist_params, distribution = self.action_parameterization(core_output)
        # sampling and log-prob evaluation in one call is faster for non-trivial action spaces
        sampled, sampled_log_probs = sample_actions_log_probs(distribution)

        outputs = AttrDict(dict(
            actions=sampled,
            # named "logits" historically; holds distribution parameters for continuous actions too
            action_logits=dist_params,
            log_prob_actions=sampled_log_probs,
            values=state_values,
        ))
        if with_action_distribution:
            outputs.action_distribution = distribution
        return outputs
Esempio n. 8
0
    def init(self):
        """
        Actually instantiate the env instances.
        Also creates ActorState objects that hold the state of individual actors in (potentially) multi-agent envs.

        For each local env index an env is built with ``make_env_func``, seeded with
        its system-wide ``env_id`` and appended to ``self.envs``. One ``ActorState``
        per agent is appended (as a per-env list) to ``self.actor_states``, and a
        matching list of 0.0 reward accumulators to ``self.episode_rewards``.
        """
        for local_env_idx in range(self.num_envs):
            vector_idx = self.split_idx * self.num_envs + local_env_idx
            # global env id within the entire system
            env_id = self.worker_idx * self.cfg.num_envs_per_worker + vector_idx

            env_config = AttrDict(worker_index=self.worker_idx, vector_index=vector_idx, env_id=env_id)
            env = make_env_func(self.cfg, env_config=env_config)
            env.seed(env_id)
            self.envs.append(env)

            self.actor_states.append([
                ActorState(
                    self.cfg,
                    env,
                    self.worker_idx,
                    self.split_idx,
                    local_env_idx,
                    agent_idx,
                    self.traj_tensors.index((local_env_idx, agent_idx)),
                    self.num_traj_buffers,
                    self.policy_outputs,
                    self.policy_output_tensors[local_env_idx, agent_idx],
                    self.pbt_reward_shaping,
                    self.policy_mgr,
                )
                for agent_idx in range(self.num_agents)
            ])
            self.episode_rewards.append([0.0] * self.num_agents)
Esempio n. 9
0
def test_env_performance(make_env, env_type, verbose=False):
    """Benchmark an environment by stepping it with random actions.

    The env is reset whenever an episode ends and stepped until at least 10000
    frames (as counted by ``num_env_steps``) have been collected. Reset and step
    times are accumulated on the ``Timing`` object and logged with the throughput.
    The env is always closed, even if stepping raises.

    Args:
        make_env: callable taking an env config (``worker_index``/``vector_index``)
            and returning an env.
        env_type: label used in the log output.
        verbose: render every step (throttled to ~40 FPS) and log each reward.
    """
    t = Timing()
    with t.timeit('init'):
        env = make_env(AttrDict({'worker_index': 0, 'vector_index': 0}))
        total_num_frames, frames = 10000, 0

    try:
        with t.timeit('first_reset'):
            env.reset()

        # manual accumulators for reset/step time, started at a tiny non-zero value
        t.reset = t.step = 1e-9
        num_resets = 0
        with t.timeit('experience'):
            while frames < total_num_frames:
                done = False

                start_reset = time.time()
                env.reset()

                t.reset += time.time() - start_reset
                num_resets += 1

                while not done and frames < total_num_frames:
                    start_step = time.time()
                    if verbose:
                        env.render()
                        time.sleep(1.0 / 40)

                    obs, rew, done, info = env.step(env.action_space.sample())
                    if verbose:
                        log.info('Received reward %.3f', rew)

                    t.step += time.time() - start_step
                    frames += num_env_steps([info])

        # `frames` advances by num_env_steps(...) per step and can overshoot the
        # target, so report the frames actually collected
        fps = frames / t.experience
        log.debug('%s performance:', env_type)
        log.debug('Took %.3f sec to collect %d frames on one CPU, %.1f FPS',
                  t.experience, frames, fps)
        log.debug('Avg. reset time %.3f s', t.reset / num_resets)
        log.debug('Timing: %s', t)
    finally:
        env.close()
Esempio n. 10
0
    def doom_multiagent(make_multi_env, worker_index, num_steps=1000):
        """Step a multi-agent env with random actions and log its throughput.

        Every agent receives the same randomly sampled action each step. The env is
        reset once all agents are done, and step/observation/frame rates are logged
        at the end before the env is closed.

        Args:
            make_multi_env: callable taking an env config and returning a multi-agent env.
            worker_index: value passed as ``worker_index`` in the env config.
            num_steps: number of env steps to run.
        """
        multi_env = make_multi_env(AttrDict({
            'worker_index': worker_index,
            'vector_index': 0,
            'safe_init': False
        }))

        obs = multi_env.reset()

        visualize = False
        start = time.time()

        for step_idx in range(num_steps):
            # one random action, shared by all agents
            actions = [multi_env.action_space.sample()] * len(obs)
            obs, rew, dones, infos = multi_env.step(actions)

            if visualize:
                multi_env.render()

            should_log = step_idx % 100 == 0 or any(dones)
            if should_log:
                log.info('Rew %r done %r info %r', rew, dones, infos)

            if all(dones):
                multi_env.reset()

        elapsed = time.time() - start
        agents = multi_env.num_agents
        log.info('Took %.3f seconds for %d steps', elapsed, num_steps)
        log.info('Server steps per second: %.1f', num_steps / elapsed)
        log.info('Observations fps: %.1f', num_steps * agents / elapsed)
        log.info('Environment fps: %.1f', num_steps * agents * multi_env.skip_frames / elapsed)

        multi_env.close()
Esempio n. 11
0
def enjoy(cfg, max_num_frames=1e9):
    """Load a trained policy and play it in a single (possibly multi-agent) env.

    The saved experiment config is loaded via ``load_from_checkpoint`` and the
    latest checkpoint of policy ``cfg.policy_index`` is restored. The env is then
    stepped, and optionally rendered at ``cfg.fps``, until more than
    ``max_num_frames`` frames have been played. That limit is checked once per
    policy step, so up to ``render_action_repeat - 1`` extra frames may be played.
    Per-agent episode rewards are logged as episodes finish. The env is always
    closed, even on errors or interruption.

    Args:
        cfg: current config namespace passed to ``load_from_checkpoint``.
        max_num_frames: frame budget; ``None`` means no limit.

    Returns:
        Tuple ``(ExperimentStatus.SUCCESS, mean_reward)``. ``mean_reward`` is the
        mean over the recorded episodes of all agents (up to 100 most recent per
        agent), or ``nan`` if no episode finished.
    """
    cfg = load_from_checkpoint(cfg)

    render_action_repeat = cfg.render_action_repeat if cfg.render_action_repeat is not None else cfg.env_frameskip
    if render_action_repeat is None:
        log.warning('Not using action repeat!')
        render_action_repeat = 1
    log.debug('Using action repeat %d during evaluation', render_action_repeat)

    # actions are repeated manually in the loop below instead of by the env
    cfg.env_frameskip = 1  # for evaluation
    cfg.num_envs = 1

    def make_env_func(env_config):
        return create_env(cfg.env, cfg=cfg, env_config=env_config)

    env = make_env_func(AttrDict({'worker_index': 0, 'vector_index': 0}))
    # env.seed(0)

    # everything after env creation runs under try/finally so the env is closed
    # even if model/checkpoint loading fails or playback is interrupted (Ctrl+C)
    try:
        is_multiagent = is_multiagent_env(env)
        if not is_multiagent:
            env = MultiAgentWrapper(env)

        if hasattr(env.unwrapped, 'reset_on_init'):
            # reset call ruins the demo recording for VizDoom
            env.unwrapped.reset_on_init = False

        actor_critic = create_actor_critic(cfg, env.observation_space,
                                           env.action_space)

        device = torch.device('cpu' if cfg.device == 'cpu' else 'cuda')
        actor_critic.model_to_device(device)

        policy_id = cfg.policy_index
        checkpoints = LearnerWorker.get_checkpoints(
            LearnerWorker.checkpoint_dir(cfg, policy_id))
        checkpoint_dict = LearnerWorker.load_checkpoint(checkpoints, device)
        actor_critic.load_state_dict(checkpoint_dict['model'])

        # per-agent history of the last 100 episode rewards and `true_reward`s
        episode_rewards = [deque([], maxlen=100) for _ in range(env.num_agents)]
        true_rewards = [deque([], maxlen=100) for _ in range(env.num_agents)]
        num_frames = 0

        last_render_start = time.time()

        def max_frames_reached(frames):
            return max_num_frames is not None and frames > max_num_frames

        obs = env.reset()
        rnn_states = torch.zeros(
            [env.num_agents, get_hidden_size(cfg)],
            dtype=torch.float32,
            device=device)
        episode_reward = np.zeros(env.num_agents)
        finished_episode = [False] * env.num_agents

        with torch.no_grad():
            while not max_frames_reached(num_frames):
                obs_torch = AttrDict(transform_dict_observations(obs))
                for key, x in obs_torch.items():
                    obs_torch[key] = torch.from_numpy(x).to(device).float()

                policy_outputs = actor_critic(obs_torch,
                                              rnn_states,
                                              with_action_distribution=True)

                # sample actions from the distribution by default
                actions = policy_outputs.actions

                action_distribution = policy_outputs.action_distribution
                if isinstance(action_distribution, ContinuousActionDistribution):
                    if not cfg.continuous_actions_sample:  # TODO: add similar option for discrete actions
                        actions = action_distribution.means

                actions = actions.cpu().numpy()

                rnn_states = policy_outputs.rnn_states

                # repeat the same action for render_action_repeat env steps
                for _ in range(render_action_repeat):
                    if not cfg.no_render:
                        target_delay = 1.0 / cfg.fps if cfg.fps > 0 else 0
                        current_delay = time.time() - last_render_start
                        time_wait = target_delay - current_delay

                        if time_wait > 0:
                            # log.info('Wait time %.3f', time_wait)
                            time.sleep(time_wait)

                        last_render_start = time.time()
                        env.render()

                    obs, rew, done, infos = env.step(actions)

                    episode_reward += rew
                    num_frames += 1

                    for agent_i, done_flag in enumerate(done):
                        if done_flag:
                            finished_episode[agent_i] = True
                            episode_rewards[agent_i].append(
                                episode_reward[agent_i])
                            true_rewards[agent_i].append(infos[agent_i].get(
                                'true_reward', episode_reward[agent_i]))
                            log.info(
                                'Episode finished for agent %d at %d frames. Reward: %.3f, true_reward: %.3f',
                                agent_i, num_frames, episode_reward[agent_i],
                                true_rewards[agent_i][-1])
                            # reset the recurrent state of the agent whose episode ended
                            rnn_states[agent_i] = torch.zeros(
                                [get_hidden_size(cfg)],
                                dtype=torch.float32,
                                device=device)
                            episode_reward[agent_i] = 0

                    # if episode terminated synchronously for all agents, pause a bit before starting a new one
                    if all(done):
                        if not cfg.no_render:
                            env.render()
                        time.sleep(0.05)

                    if all(finished_episode):
                        finished_episode = [False] * env.num_agents
                        avg_episode_rewards_str, avg_true_reward_str = '', ''
                        for agent_i in range(env.num_agents):
                            avg_rew = np.mean(episode_rewards[agent_i])
                            avg_true_rew = np.mean(true_rewards[agent_i])
                            if not np.isnan(avg_rew):
                                if avg_episode_rewards_str:
                                    avg_episode_rewards_str += ', '
                                avg_episode_rewards_str += f'#{agent_i}: {avg_rew:.3f}'
                            if not np.isnan(avg_true_rew):
                                if avg_true_reward_str:
                                    avg_true_reward_str += ', '
                                avg_true_reward_str += f'#{agent_i}: {avg_true_rew:.3f}'

                        log.info('Avg episode rewards: %s, true rewards: %s',
                                 avg_episode_rewards_str, avg_true_reward_str)
                        log.info(
                            'Avg episode reward: %.3f, avg true_reward: %.3f',
                            np.mean([
                                np.mean(episode_rewards[i])
                                for i in range(env.num_agents)
                            ]),
                            np.mean([
                                np.mean(true_rewards[i])
                                for i in range(env.num_agents)
                            ]))

                    # VizDoom multiplayer stuff
                    # for player in [1, 2, 3, 4, 5, 6, 7, 8]:
                    #     key = f'PLAYER{player}_FRAGCOUNT'
                    #     if key in infos[0]:
                    #         log.debug('Score for player %d: %r', player, infos[0][key])
    finally:
        env.close()

    # Flatten before averaging: agents can finish different numbers of episodes,
    # and np.mean over a ragged list of deques fails. For equal lengths this is
    # the same value as the mean over the 2-D array.
    all_episode_rewards = [r for agent_rewards in episode_rewards for r in agent_rewards]
    return ExperimentStatus.SUCCESS, np.mean(all_episode_rewards)
Esempio n. 12
0
    def sample(self, proc_idx):
        """Sampling-benchmark worker: step this process's envs with pre-sampled random actions.

        Runs as process ``proc_idx``. It optionally pins GPUs and CPU affinity,
        creates ``cfg.num_envs_per_worker`` envs and resets them, then reports
        ``finished_reset`` on ``self.report_queue`` and waits for
        ``self.start_event``. After that it steps all envs round-robin until
        ``self.terminate.value`` is set or ``cfg.sample_env_frames_per_worker``
        frames were collected, and periodically reports the frame count. Exceptions
        raised after env creation are logged and reported with ``crash=True``.
        Exceptions raised during env creation are not caught here. Average episode
        lengths are logged and all envs are closed at the end.

        Args:
            proc_idx: index of this worker process.
        """
        # workers should ignore Ctrl+C because the termination is handled in the event loop by a special msg
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        if self.cfg.sampler_worker_gpus:
            set_gpus_for_process(
                proc_idx,
                num_gpus_per_process=1,
                process_type='sampler_proc',
                gpu_mask=self.cfg.sampler_worker_gpus,
            )

        timing = Timing()

        # limit native thread pools (BLAS/OpenMP) to a single thread inside this worker
        from threadpoolctl import threadpool_limits
        with threadpool_limits(limits=1, user_api=None):
            if self.cfg.set_workers_cpu_affinity:
                set_process_cpu_affinity(proc_idx, self.cfg.num_workers)

            # remembered so the check at the end can detect affinity changes (not available on macOS)
            initial_cpu_affinity = psutil.Process().cpu_affinity(
            ) if platform != 'darwin' else None
            psutil.Process().nice(10)

            with timing.timeit('env_init'):
                envs = []
                # per-env label used as the timing key; replaced by the DMLab level name when available
                env_key = ['env' for _ in range(self.cfg.num_envs_per_worker)]

                for env_idx in range(self.cfg.num_envs_per_worker):
                    global_env_id = proc_idx * self.cfg.num_envs_per_worker + env_idx
                    env_config = AttrDict(worker_index=proc_idx,
                                          vector_index=env_idx,
                                          env_id=global_env_id)

                    env = make_env_func(cfg=self.cfg, env_config=env_config)
                    log.debug(
                        'CPU affinity after create_env: %r',
                        psutil.Process().cpu_affinity()
                        if platform != 'darwin' else 'MacOS - None')
                    env.seed(global_env_id)
                    envs.append(env)

                    # this is to track the performance for individual DMLab levels
                    if hasattr(env.unwrapped, 'level_name'):
                        env_key[env_idx] = env.unwrapped.level_name

                # current episode length and a window of the last 20 lengths, per env
                episode_length = [0 for _ in envs]
                episode_lengths = [deque([], maxlen=20) for _ in envs]

            # sample a lot of random actions once, otherwise it is pretty slow in Python
            # NOTE(review): uses `env` (the last env created above) for action space and
            # num_agents, so this assumes all envs share them and num_envs_per_worker >= 1
            total_random_actions = 500
            actions = [[
                env.action_space.sample() for _ in range(env.num_agents)
            ] for _ in range(total_random_actions)]
            action_i = 0

            try:
                with timing.timeit('first_reset'):
                    for env_idx, env in enumerate(envs):
                        env.reset()
                        log.info('Process %d finished resetting %d/%d envs',
                                 proc_idx, env_idx + 1, len(envs))

                    self.report_queue.put(
                        dict(proc_idx=proc_idx, finished_reset=True))

                self.start_event.wait()

                with timing.timeit('work'):
                    last_report = last_report_frames = total_env_frames = 0
                    while not self.terminate.value and total_env_frames < self.cfg.sample_env_frames_per_worker:
                        for env_idx, env in enumerate(envs):
                            with timing.add_time(f'{env_key[env_idx]}.step'):
                                obs, rewards, dones, infos = env.step(
                                    actions[action_i])
                                # cycle through the pre-sampled actions
                                action_i = (action_i +
                                            1) % total_random_actions

                            # frames per agent default to 1 when `num_frames` is absent from info
                            num_frames = sum(
                                [info.get('num_frames', 1) for info in infos])
                            total_env_frames += num_frames
                            episode_length[env_idx] += num_frames

                            if all(dones):
                                # store per-agent episode length
                                episode_lengths[env_idx].append(
                                    episode_length[env_idx] / env.num_agents)
                                episode_length[env_idx] = 0

                        with timing.add_time('report'):
                            now = time.time()
                            if now - last_report > self.report_every_sec:
                                last_report = now
                                frames_since_last_report = total_env_frames - last_report_frames
                                last_report_frames = total_env_frames
                                self.report_queue.put(
                                    dict(proc_idx=proc_idx,
                                         env_frames=frames_since_last_report))

                                if proc_idx == 0:
                                    log.debug('Memory usage: %.4f Mb',
                                              memory_consumption_mb())

                # Extra check to make sure cpu affinity is preserved throughout the execution.
                # I observed weird effect when some environments tried to alter affinity of the current process, leading
                # to decreased performance.
                # This can be caused by some interactions between deep learning libs, OpenCV, MKL, OpenMP, etc.
                # At least user should know about it if this is happening.
                # NOTE(review): the two f-string parts concatenate without a space ("...!This can...")
                cpu_affinity = psutil.Process().cpu_affinity(
                ) if platform != 'darwin' else None
                assert initial_cpu_affinity == cpu_affinity, \
                    f'Worker CPU affinity was changed from {initial_cpu_affinity} to {cpu_affinity}!' \
                    f'This can significantly affect performance!'

            # NOTE(review): bare except also catches BaseException subclasses; any failure
            # (including the affinity assertion) is reported to the main process as a crash
            except:
                log.exception('Unknown exception')
                log.error('Unknown exception in worker %d, terminating...',
                          proc_idx)
                self.report_queue.put(dict(proc_idx=proc_idx, crash=True))

            # staggered per-process delay before printing the final timing
            time.sleep(proc_idx * 0.01 + 0.01)
            log.info('Process %d finished sampling. Timing: %s', proc_idx,
                     timing)

            for env_idx, env in enumerate(envs):
                if len(episode_lengths[env_idx]) > 0:
                    log.warning('Level %s avg episode len %d',
                                env_key[env_idx],
                                np.mean(episode_lengths[env_idx]))

            for env in envs:
                env.close()
Esempio n. 13
0
    def sample(self, proc_idx):
        """Worker that fills the DMLab level cache for the single env of this process.

        Each iteration takes one random step and then resets the env, which
        generates a new level. After ``env_desired_num_levels`` resets the worker
        sleeps between iterations to free CPU for other workers. Envs that do not
        use the level cache also just sleep. Frame counts are reported periodically
        on ``self.report_queue``. Any exception raised in the work loop is logged
        and reported with ``crash=True``.

        Args:
            proc_idx: index of this worker process.
        """
        # workers should ignore Ctrl+C because the termination is handled in the event loop by a special msg
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        timing = Timing()

        psutil.Process().nice(10)

        num_envs = len(DMLAB30_LEVELS_THAT_USE_LEVEL_CACHE)
        assert self.cfg.num_workers % num_envs == 0, f'should have an integer number of workers per env, e.g. {1 * num_envs}, {2 * num_envs}, etc...'
        assert self.cfg.num_envs_per_worker == 1, 'use populate_cache with 1 env per worker'

        with timing.timeit('env_init'):
            env_key = 'env'
            env_desired_num_levels = 0
            # Must be initialized here: it is only recomputed below for envs that
            # expose `level_name`, but the work loop increments it unconditionally
            # (previously a NameError for other envs).
            env_num_levels_generated = 0

            global_env_id = proc_idx * self.cfg.num_envs_per_worker
            env_config = AttrDict(worker_index=proc_idx, vector_index=0, env_id=global_env_id)
            env = create_env(self.cfg.env, cfg=self.cfg, env_config=env_config)
            env.seed(global_env_id)

            # this is to track the performance for individual DMLab levels
            if hasattr(env.unwrapped, 'level_name'):
                env_key = env.unwrapped.level_name
                env_level = env.unwrapped.level

                # levels this worker should generate: its share of the episodes
                # expected over DESIRED_TRAINING_LENGTH frames
                approx_num_episodes_per_1b_frames = DMLAB30_APPROX_NUM_EPISODES_PER_BILLION_FRAMES[env_key]
                num_billions = DESIRED_TRAINING_LENGTH / int(1e9)
                num_workers_for_env = self.cfg.num_workers // num_envs
                env_desired_num_levels = int((approx_num_episodes_per_1b_frames * num_billions) / num_workers_for_env)

                # this worker's share of the seeds already present in the global cache
                env_num_levels_generated = len(dmlab_level_cache.DMLAB_GLOBAL_LEVEL_CACHE[0].all_seeds[env_level]) // num_workers_for_env

                log.warning('Worker %d (env %s) generated %d/%d levels!', proc_idx, env_key, env_num_levels_generated, env_desired_num_levels)
                time.sleep(4)

            env.reset()
            env_uses_level_cache = env.unwrapped.env_uses_level_cache

            self.report_queue.put(dict(proc_idx=proc_idx, finished_reset=True))

        self.start_event.wait()

        try:
            with timing.timeit('work'):
                last_report = last_report_frames = total_env_frames = 0
                while not self.terminate.value and total_env_frames < self.cfg.sample_env_frames_per_worker:
                    action = env.action_space.sample()
                    with timing.add_time(f'{env_key}.step'):
                        env.step(action)

                    total_env_frames += 1

                    # every reset produces a new level
                    with timing.add_time(f'{env_key}.reset'):
                        env.reset()
                        env_num_levels_generated += 1
                        log.debug('Env %s done %d/%d resets', env_key, env_num_levels_generated, env_desired_num_levels)

                    if env_num_levels_generated >= env_desired_num_levels:
                        log.debug('%s finished %d/%d resets, sleeping...', env_key, env_num_levels_generated, env_desired_num_levels)
                        time.sleep(30)  # free up CPU time for other envs

                    # if env does not use level cache, there is no need to run it
                    # let other workers proceed
                    if not env_uses_level_cache:
                        log.debug('Env %s does not require cache, sleeping...', env_key)
                        time.sleep(200)

                    with timing.add_time('report'):
                        now = time.time()
                        if now - last_report > self.report_every_sec:
                            last_report = now
                            frames_since_last_report = total_env_frames - last_report_frames
                            last_report_frames = total_env_frames
                            self.report_queue.put(dict(proc_idx=proc_idx, env_frames=frames_since_last_report))

                            if get_free_disk_space_mb(self.cfg) < 3 * 1024:
                                log.error('Not enough disk space! %d', get_free_disk_space_mb(self.cfg))
                                time.sleep(200)
        # deliberately broad: any failure is reported to the main process as a crash
        except:
            log.exception('Unknown exception')
            log.error('Unknown exception in worker %d, terminating...', proc_idx)
            self.report_queue.put(dict(proc_idx=proc_idx, crash=True))

        time.sleep(proc_idx * 0.1 + 0.1)
        log.info('Process %d finished sampling. Timing: %s', proc_idx, timing)

        env.close()
Esempio n. 14
0
def multi_agent_match(policy_indices, max_num_episodes=int(1e9), max_num_frames=1e10):
    """Play trained policies against each other and record the match results.

    ``RIVALS[i]`` is assigned ``policy_indices[i]``, and each rival's actor-critic
    is restored from its latest checkpoint on a CUDA device. Episodes are played
    until ``max_num_episodes`` have run or more than ``max_num_frames`` frames have
    been played. At the end of each episode the frag difference
    ``PLAYER1_FRAGCOUNT - PLAYER2_FRAGCOUNT`` (read from ``infos[0]``) counts as a
    win for rival 0 (positive), a win for rival 1 (negative) or a tie. The running
    totals are appended to ``eval_<i>vs<j>....txt`` next to this file.

    Args:
        policy_indices: policy index for each entry of ``RIVALS``, in order.
        max_num_episodes: maximum number of episodes to play.
        max_num_frames: frame budget; ``None`` means no limit.
    """
    log.debug('Starting eval process with policies %r', policy_indices)
    for i, rival in enumerate(RIVALS):
        rival.policy_index = policy_indices[i]

    curr_dir = os.path.dirname(os.path.abspath(__file__))
    evaluation_filename = join(curr_dir, f'eval_{"vs".join([str(pi) for pi in policy_indices])}.txt')
    # truncate/create the results file
    with open(evaluation_filename, 'w') as fobj:
        fobj.write('start\n')

    # the env is configured from the first rival's config
    common_config = RIVALS[0].cfg

    render_action_repeat = common_config.render_action_repeat if common_config.render_action_repeat is not None else common_config.env_frameskip
    if render_action_repeat is None:
        log.warning('Not using action repeat!')
        render_action_repeat = 1
    log.debug('Using action repeat %d during evaluation', render_action_repeat)

    # actions are repeated manually in the loop below instead of by the env
    common_config.env_frameskip = 1  # for evaluation
    common_config.num_envs = 1
    common_config.timelimit = 4.0  # for faster evaluation

    def make_env_func(env_config):
        return create_env(ENV_NAME, cfg=common_config, env_config=env_config)

    env = make_env_func(AttrDict({'worker_index': 0, 'vector_index': 0}))
    env.seed(0)

    is_multiagent = is_multiagent_env(env)
    if not is_multiagent:
        env = MultiAgentWrapper(env)
    else:
        assert env.num_agents == len(RIVALS)

    # NOTE(review): hard-coded; requires a CUDA device
    device = torch.device('cuda')
    for rival in RIVALS:
        rival.actor_critic = create_actor_critic(rival.cfg, env.observation_space, env.action_space)
        rival.actor_critic.model_to_device(device)

        policy_id = rival.policy_index
        checkpoints = LearnerWorker.get_checkpoints(LearnerWorker.checkpoint_dir(rival.cfg, policy_id))
        checkpoint_dict = LearnerWorker.load_checkpoint(checkpoints, device)
        rival.actor_critic.load_state_dict(checkpoint_dict['model'])

    episode_rewards = []
    num_frames = 0

    last_render_start = time.time()

    def max_frames_reached(frames):
        return max_num_frames is not None and frames > max_num_frames

    # NOTE(review): win bookkeeping below only ever updates wins[0] and wins[1],
    # i.e. it assumes exactly two rivals
    wins = [0 for _ in RIVALS]
    ties = 0
    frag_differences = []

    with torch.no_grad():
        for _ in range(max_num_episodes):
            obs = env.reset()
            # reused for every agent; each agent overwrites the same observation keys
            obs_dict_torch = dict()

            done = [False] * len(obs)
            # NOTE(review): assumes rival.cfg.hidden_size is the full recurrent state size — confirm
            for rival in RIVALS:
                rival.rnn_states = torch.zeros([1, rival.cfg.hidden_size], dtype=torch.float32, device=device)

            episode_reward = 0
            prev_frame = time.time()

            while True:
                # agent i is controlled by RIVALS[i]; each runs a batch of size 1
                actions = []
                for i, obs_dict in enumerate(obs):
                    for key, x in obs_dict.items():
                        obs_dict_torch[key] = torch.from_numpy(x).to(device).float().view(1, *x.shape)

                    rival = RIVALS[i]
                    policy_outputs = rival.actor_critic(obs_dict_torch, rival.rnn_states)
                    rival.rnn_states = policy_outputs.rnn_states
                    actions.append(policy_outputs.actions[0].cpu().numpy())

                # repeat the same actions for render_action_repeat env steps
                for _ in range(render_action_repeat):
                    if not NO_RENDER:
                        target_delay = 1.0 / FPS if FPS > 0 else 0
                        current_delay = time.time() - last_render_start
                        time_wait = target_delay - current_delay

                        if time_wait > 0:
                            # log.info('Wait time %.3f', time_wait)
                            time.sleep(time_wait)

                        last_render_start = time.time()
                        env.render()

                    obs, rew, done, infos = env.step(actions)
                    if all(done):
                        log.debug('Finished episode!')

                        # positive: player 1 (rival 0) won, negative: player 2 (rival 1) won
                        frag_diff = infos[0]['PLAYER1_FRAGCOUNT'] - infos[0]['PLAYER2_FRAGCOUNT']
                        if frag_diff > 0:
                            wins[0] += 1
                        elif frag_diff < 0:
                            wins[1] += 1
                        else:
                            ties += 1

                        frag_differences.append(frag_diff)
                        avg_frag_diff = np.mean(frag_differences)

                        report = f'wins: {wins}, ties: {ties}, avg_frag_diff: {avg_frag_diff}'
                        with open(evaluation_filename, 'a') as fobj:
                            fobj.write(report + '\n')

                    # log.info('%d:%d', infos[0]['PLAYER1_FRAGCOUNT'], infos[0]['PLAYER2_FRAGCOUNT'])

                    # episode reward is the mean reward over agents, summed over steps
                    episode_reward += np.mean(rew)
                    num_frames += 1

                    if num_frames % 100 == 0:
                        log.debug('%.1f', render_action_repeat / (time.time() - prev_frame))
                    prev_frame = time.time()

                    if all(done):
                        log.info('Episode finished at %d frames', num_frames)
                        break

                if all(done) or max_frames_reached(num_frames):
                    break

            if not NO_RENDER:
                env.render()
            time.sleep(0.01)

            episode_rewards.append(episode_reward)
            last_episodes = episode_rewards[-100:]
            avg_reward = sum(last_episodes) / len(last_episodes)
            log.info(
                'Episode reward: %f, avg reward for %d episodes: %f', episode_reward, len(last_episodes), avg_reward,
            )

            if max_frames_reached(num_frames):
                break

    # NOTE(review): not reached if an exception is raised above, leaving the env open
    env.close()