Example #1
    def forward_tail(self, core_output, with_action_distribution=False):

        self.termination_prob = self.termination(core_output)
        self.termination_mask = torch.where(
            self.termination_prob > torch.rand_like(self.termination_prob),
            torch.ones(1, device=self.termination_prob.device),
            torch.zeros(1, device=self.termination_prob.device))

        values = self.critic_linear(core_output)
        action_distribution_params, action_distribution = self.action_parameterization(core_output)

        # for non-trivial action spaces it is faster to do these together
        actions, log_prob_actions = sample_actions_log_probs(action_distribution)

        # perhaps `action_logits` is not the best name here since we now support continuous actions
        result = AttrDict(
            dict(
                actions=actions,  # (B * O) x (num_actions/D)
                # B x num_action_logits x O -> (B * O) x num_action_logits
                action_logits=action_distribution_params.reshape(-1,
                                                                 action_distribution.num_actions),
                log_prob_actions=log_prob_actions,  # (B * O) x 1
                values=values,
                termination_prob=self.termination_prob,
                termination_mask=self.termination_mask,
            ))

        if with_action_distribution:
            result.action_distribution = action_distribution

        return result
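The `sample_actions_log_probs` helper used above is defined elsewhere in the codebase. A minimal sketch of what it plausibly does, assuming the distribution object follows the usual PyTorch `sample()`/`log_prob()` interface (an illustration, not the actual implementation):

def sample_actions_log_probs(action_distribution):
    # hypothetical sketch: sampling and log-prob evaluation can reuse the same
    # intermediate tensors, which is why doing them together is faster for
    # composite (tuple) action spaces
    actions = action_distribution.sample()
    log_prob_actions = action_distribution.log_prob(actions)
    return actions, log_prob_actions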
Example #2
    def forward_tail(self, core_output, with_action_distribution=False):
        core_outputs = core_output.chunk(len(self.cores), dim=1)

        # first core output corresponds to the actor
        action_distribution_params, action_distribution = self.action_parameterization(
            core_outputs[0])
        # for non-trivial action spaces it is faster to do these together
        actions, log_prob_actions = sample_actions_log_probs(
            action_distribution)

        # second core output corresponds to the critic
        values = self.critic_linear(core_outputs[1])

        result = AttrDict(
            dict(
                actions=actions,
                action_logits=action_distribution_params,
                log_prob_actions=log_prob_actions,
                values=values,
            ))

        if with_action_distribution:
            result.action_distribution = action_distribution

        return result
Example #3
def get_obs_shape(obs_space):
    obs_shape = AttrDict()
    if hasattr(obs_space, 'spaces'):
        for key, space in obs_space.spaces.items():
            obs_shape[key] = space.shape
    else:
        obs_shape.obs = obs_space.shape

    return obs_shape
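A quick usage sketch for `get_obs_shape`, assuming standard `gym.spaces` objects (illustrative only):

from gym import spaces

dict_space = spaces.Dict({
    'obs': spaces.Box(low=0, high=255, shape=(84, 84, 3)),
    'measurements': spaces.Box(low=-1.0, high=1.0, shape=(8,)),
})
print(get_obs_shape(dict_space))  # prints something like {'obs': (84, 84, 3), 'measurements': (8,)}
print(get_obs_shape(spaces.Box(low=-1.0, high=1.0, shape=(4,))))  # {'obs': (4,)}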
Example #4
    def _handle_policy_steps(self, timing):
        with torch.no_grad():
            with timing.add_time('deserialize'):
                observations = AttrDict()
                rnn_states = []

                traj_tensors = self.shared_buffers.tensors_individual_transitions
                for request in self.requests:
                    actor_idx, split_idx, request_data = request

                    for env_idx, agent_idx, traj_buffer_idx, rollout_step in request_data:
                        index = actor_idx, split_idx, env_idx, agent_idx, traj_buffer_idx, rollout_step
                        dict_of_lists_append(observations, traj_tensors['obs'],
                                             index)
                        rnn_states.append(traj_tensors['rnn_states'][index])
                        self.total_num_samples += 1

            with timing.add_time('stack'):
                for key, x in observations.items():
                    observations[key] = torch.stack(x)
                rnn_states = torch.stack(rnn_states)
                num_samples = rnn_states.shape[0]

            with timing.add_time('obs_to_device'):
                for key, x in observations.items():
                    device, dtype = self.actor_critic.device_and_type_for_input_tensor(
                        key)
                    observations[key] = x.to(device).type(dtype)
                rnn_states = rnn_states.to(self.device).float()

            with timing.add_time('forward'):
                policy_outputs = self.actor_critic(observations, rnn_states)

            with timing.add_time('to_cpu'):
                for key, output_value in policy_outputs.items():
                    policy_outputs[key] = output_value.cpu()

            with timing.add_time('format_outputs'):
                policy_outputs.policy_version = torch.empty(
                    [num_samples]).fill_(self.latest_policy_version)

                # concat all tensors into a single tensor for performance
                output_tensors = []
                for policy_output in self.shared_buffers.policy_outputs:
                    tensor_name = policy_output.name
                    output_value = policy_outputs[tensor_name].float()
                    if len(output_value.shape) == 1:
                        output_value.unsqueeze_(dim=1)
                    output_tensors.append(output_value)

                output_tensors = torch.cat(output_tensors, dim=1)

            with timing.add_time('postprocess'):
                self._enqueue_policy_outputs(self.requests, output_tensors)

        self.requests = []
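`dict_of_lists_append`, used in the 'deserialize' section above, is a small helper defined elsewhere; a plausible sketch, with the behavior assumed from the call site:

def dict_of_lists_append(dict_of_lists, dict_of_tensors, index):
    # hypothetical sketch: append the indexed slice of every keyed tensor
    for key, tensor in dict_of_tensors.items():
        if key not in dict_of_lists:
            dict_of_lists[key] = []
        dict_of_lists[key].append(tensor[index])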
Example #5
    def _prepare_train_buffer(self, rollouts, macro_batch_size, timing):
        trajectories = [AttrDict(r['t']) for r in rollouts]

        with timing.add_time('buffers'):
            buffer = AttrDict()

            # by the end of this loop the buffer is a dictionary containing lists of numpy arrays
            for i, t in enumerate(trajectories):
                for key, x in t.items():
                    if key not in buffer:
                        buffer[key] = []
                    buffer[key].append(x)

            # convert lists of dict observations to a single dictionary of lists
            for key, x in buffer.items():
                if isinstance(x[0], (dict, OrderedDict)):
                    buffer[key] = list_of_dicts_to_dict_of_lists(x)

        if not self.cfg.with_vtrace:
            with timing.add_time('calc_gae'):
                buffer = self._calculate_gae(buffer)

        with timing.add_time('batching'):
            # concatenate rollouts from different workers into a single batch efficiently
            # that is, if we already have memory for the buffers allocated, we can just copy the data into
            # existing cached tensors instead of creating new ones. This is a performance optimization.
            use_pinned_memory = self.cfg.device == 'gpu'
            buffer = self.tensor_batcher.cat(buffer, macro_batch_size,
                                             use_pinned_memory, timing)

        with timing.add_time('buff_ready'):
            for r in rollouts:
                self._mark_rollout_buffer_free(r)

        with timing.add_time('tensors_gpu_float'):
            device_buffer = self._copy_train_data_to_device(buffer)

        with timing.add_time('squeeze'):
            # will squeeze actions only in simple categorical case
            tensors_to_squeeze = [
                'actions', 'log_prob_actions', 'policy_version', 'values',
                'rewards', 'dones'
            ]
            for tensor_name in tensors_to_squeeze:
                device_buffer[tensor_name].squeeze_()

        # we no longer need the cached buffer, and can put it back into the pool
        self.tensor_batch_pool.put(buffer)
        return device_buffer
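`list_of_dicts_to_dict_of_lists` is another helper assumed from its call site; a minimal sketch, assuming all dicts share the same keys:

def list_of_dicts_to_dict_of_lists(list_of_dicts):
    # hypothetical sketch: [{'a': 1}, {'a': 2}] -> {'a': [1, 2]}
    dict_of_lists = dict()
    for d in list_of_dicts:
        for key, x in d.items():
            if key not in dict_of_lists:
                dict_of_lists[key] = []
            dict_of_lists[key].append(x)
    return dict_of_lists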
Example #6
    def init(self):
        for env_i in range(self.num_envs):
            vector_idx = self.split_idx * self.num_envs + env_i

            # global env id within the entire system
            env_id = self.worker_idx * self.cfg.num_envs_per_worker + vector_idx

            env_config = AttrDict(
                worker_index=self.worker_idx, vector_index=vector_idx, env_id=env_id,
            )

            # log.info('Creating env %r... %d-%d-%d', env_config, self.worker_idx, self.split_idx, env_i)
            env = make_env_func(self.cfg, env_config=env_config)

            env.seed(env_id)
            self.envs.append(env)

            actor_states_env, episode_rewards_env = [], []
            for agent_idx in range(self.num_agents):
                agent_traj_tensors = self.traj_tensors.index((env_i, agent_idx))
                actor_state = ActorState(
                    self.cfg, env, self.worker_idx, self.split_idx, env_i, agent_idx,
                    agent_traj_tensors, self.num_traj_buffers,
                    self.policy_outputs, self.policy_output_tensors[env_i, agent_idx],
                    self.pbt_reward_shaping, self.policy_mgr,
                )
                actor_states_env.append(actor_state)
                episode_rewards_env.append(0.0)

            self.actor_states.append(actor_states_env)
            self.episode_rewards.append(episode_rewards_env)
Example #7
    def _extract_rollouts(self, data):
        data = AttrDict(data)
        worker_idx, split_idx, traj_buffer_idx = data.worker_idx, data.split_idx, data.traj_buffer_idx

        rollouts = []
        for rollout_data in data.rollouts:
            env_idx, agent_idx = rollout_data['env_idx'], rollout_data['agent_idx']
            tensors = self.rollout_tensors.index((worker_idx, split_idx, env_idx, agent_idx, traj_buffer_idx))

            rollout_data['t'] = tensors
            rollout_data['worker_idx'] = worker_idx
            rollout_data['split_idx'] = split_idx
            rollout_data['traj_buffer_idx'] = traj_buffer_idx
            rollouts.append(AttrDict(rollout_data))

        return rollouts
Example #8
def load_from_checkpoint(cfg):
    filename = cfg_file(cfg)
    if not os.path.isfile(filename):
        raise Exception(
            f'Could not load saved parameters for experiment {cfg.experiment}')

    with open(filename, 'r') as json_file:
        json_params = json.load(json_file)
        log.warning('Loading existing experiment configuration from %s',
                    filename)
        loaded_cfg = AttrDict(json_params)

    # override the parameters in config file with values passed from command line
    for key, value in cfg.cli_args.items():
        if key in loaded_cfg and loaded_cfg[key] != value:
            log.debug(
                'Overriding arg %r with value %r passed from command line',
                key, value)
            loaded_cfg[key] = value

    # incorporate extra CLI parameters that were not present in JSON file
    for key, value in vars(cfg).items():
        if key not in loaded_cfg:
            log.debug(
                'Adding new argument %r=%r that is not in the saved config file!',
                key, value)
            loaded_cfg[key] = value

    return loaded_cfg
Example #9
    def print_stats(self, fps, sample_throughput, total_env_steps):
        fps_str = []
        for interval, fps_value in zip(self.avg_stats_intervals, fps):
            fps_str.append(
                f'{int(interval * self.report_interval)} sec: {fps_value:.1f}')
        fps_str = f'({", ".join(fps_str)})'

        samples_per_policy = ', '.join(
            [f'{p}: {s:.1f}' for p, s in sample_throughput.items()])

        lag_stats = self.policy_lag[0]
        lag = AttrDict()
        for key in ['min', 'avg', 'max']:
            lag[key] = lag_stats.get(f'version_diff_{key}', -1)
        policy_lag_str = f'min: {lag.min:.1f}, avg: {lag.avg:.1f}, max: {lag.max:.1f}'

        log.debug(
            'Fps is %s. Total num frames: %d. Throughput: %s. Samples: %d. Policy #0 lag: (%s)',
            fps_str,
            total_env_steps,
            samples_per_policy,
            sum(self.samples_collected),
            policy_lag_str,
        )

        if 'reward' in self.policy_avg_stats:
            policy_reward_stats = []
            for policy_id in range(self.cfg.num_policies):
                reward_stats = self.policy_avg_stats['reward'][policy_id]
                if len(reward_stats) > 0:
                    policy_reward_stats.append(
                        (policy_id, f'{np.mean(reward_stats):.3f}'))
            log.debug('Avg episode reward: %r', policy_reward_stats)
Example #10
    def _get_minibatch(buffer, indices):
        if indices is None:
            # handle the case of a single batch, where the entire buffer is a minibatch
            return buffer

        mb = AttrDict()

        for item, x in buffer.items():
            if isinstance(x, (dict, OrderedDict)):
                mb[item] = AttrDict()
                for key, x_elem in x.items():
                    mb[item][key] = x_elem[indices]
            else:
                mb[item] = x[indices]

        return mb
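A toy illustration of how `_get_minibatch` slices a buffer, calling the staticmethod as a plain function and using made-up shapes (nested observation dicts are sliced per key):

import numpy as np

buffer = AttrDict(
    obs={'obs': np.zeros((1024, 84, 84, 3))},
    actions=np.zeros(1024),
    rewards=np.zeros(1024),
)
indices = np.arange(256)  # first minibatch of 256 transitions
mb = _get_minibatch(buffer, indices)
assert mb.actions.shape[0] == 256 and mb.obs['obs'].shape[0] == 256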
Example #11
def maybe_load_from_checkpoint(cfg):
    filename = cfg_file(cfg)
    if not os.path.isfile(filename):
        log.warning('Saved parameter configuration for experiment %s not found!', cfg.experiment)
        log.warning('Starting experiment from scratch!')
        return AttrDict(vars(cfg))

    return load_from_checkpoint(cfg)
Example #12
    def forward_tail(self, core_output):
        q_values = self.q_tail(core_output)

        result = AttrDict(dict(
            q_values=q_values,
        ))

        return result
Example #13
    def _objectives(self):
        # model losses
        l2_loss_obs = tf.squared_difference(self.tgt_features, self.predicted_features)
        # one axis must have dimension None
        prediction_loss = tf.reduce_mean(l2_loss_obs, axis=-1)
        bonus = prediction_loss

        loss = prediction_loss * self.params.prediction_loss_scale
        return AttrDict(locals())
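The `AttrDict(locals())` idiom at the end exposes every local variable of the method (`l2_loss_obs`, `prediction_loss`, `bonus`, `loss`, ...) as an attribute of the returned object, so callers can pick the quantities they need without the function enumerating its outputs. A tiny self-contained illustration:

def _toy_objectives():
    prediction_loss = 0.5
    bonus = prediction_loss
    loss = prediction_loss * 10.0
    return AttrDict(locals())

obj = _toy_objectives()
print(obj.loss, obj.bonus)  # 5.0 0.5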
Example #14
    def forward_tail(self, core_output, with_action_distribution=False):
        values = self.critic_linear(core_output)

        action_distribution_params, action_distribution = self.action_parameterization(core_output)

        # for non-trivial action spaces it is faster to do these together
        actions, log_prob_actions = sample_actions_log_probs(action_distribution)

        result = AttrDict(dict(
            actions=actions,
            action_logits=action_distribution_params,  # perhaps `action_logits` is not the best name here since we now support continuous actions
            log_prob_actions=log_prob_actions,
            values=values,
        ))

        if with_action_distribution:
            result.action_distribution = action_distribution

        return result
Example #15
def generate_env_map(make_env_func):
    """
    Currently only Doom environments support this.
    We have to initialize the env instance in a separate process because otherwise Doom overrides the signal handler
    and we cannot do things like KeyboardInterrupt anymore.
    """
    manager = multiprocessing.Manager()
    return_dict = manager.dict()
    p = multiprocessing.Process(target=_generate_env_map_worker,
                                args=(make_env_func, return_dict))
    p.start()
    p.join()

    return_dict = AttrDict(return_dict)
    return return_dict.map_img, return_dict.coord_limits
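The worker target `_generate_env_map_worker` is not shown here. A plausible sketch, assuming the Doom env exposes `map_img` and `coord_limits` attributes mirroring the keys read above (hypothetical, not the actual implementation):

def _generate_env_map_worker(make_env_func, return_dict):
    # runs in a child process so Doom's signal-handler override stays contained
    env = make_env_func(AttrDict({'worker_index': 0, 'vector_index': 0}))
    return_dict['map_img'] = getattr(env.unwrapped, 'map_img', None)
    return_dict['coord_limits'] = getattr(env.unwrapped, 'coord_limits', None)
    env.close()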
Example #16
    def _init(self, envs):
        log.info('Initializing envs %s...', list_to_string(self.env_indices))
        worker_index = self.env_indices[0] // len(self.env_indices)
        for i in self.env_indices:
            env_config = AttrDict({
                'worker_index': worker_index,
                'vector_index': i - self.env_indices[0]
            })
            env = self.make_env_func(env_config)
            env.seed(i)
            env.reset()
            if hasattr(env, 'num_agents') and env.num_agents > 1:
                self.is_multiagent = True
            envs.append(env)
            time.sleep(0.01)
Example #17
def config(debug, cache={}):
    d = lambda _debug, _prod: _debug if debug else os.environ[_prod]
    if cache:
        return cache['config']
    else:
        cache['config'] = AttrDict.from_data({
            'DB': {
                'name': d('sourandperky', 'DB.name'),
                'user': d('sourandperky', 'DB.user'),
                'password': d('sourandperky', 'DB.password'),
                'host': d('127.0.0.1', 'DB.host'),
                'port': d('5432', 'DB.port'),
            },
            'SECRET_KEY': d('SECRET_KEY', 'SECRET_KEY'),
        })
        return cache['config']
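Note the `cache={}` default argument: it is evaluated once at function definition time, so the same dict persists across calls and acts as a per-process memo; the first call builds the config and every later call returns the cached one. An equivalent formulation with an explicit module-level cache, for comparison:

_config_cache = {}

def config_explicit(debug):
    # same memoization, just with the cache visible at module scope
    if 'config' not in _config_cache:
        _config_cache['config'] = {'debug': debug}  # stand-in for building the AttrDict as above
    return _config_cache['config']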
Example #18
    def add_ppo_objectives(actor_critic, actions, old_action_probs, advantages, returns, params, step):
        action_probs = actor_critic.actions_distribution.probability(actions)
        prob_ratio = action_probs / old_action_probs  # pi / pi_old

        clip_ratio = params.ppo_clip_ratio
        clipped_advantages = tf.where(advantages > 0, advantages * clip_ratio, advantages / clip_ratio)

        clipped = tf.logical_or(prob_ratio > clip_ratio, prob_ratio < 1.0 / clip_ratio)
        clipped = tf.cast(clipped, tf.float32)

        # PPO policy gradient loss
        ppo_loss = tf.reduce_mean(-tf.minimum(prob_ratio * advantages, clipped_advantages))

        # penalize for inaccurate value estimation
        value_loss = tf.reduce_mean(tf.square(returns - actor_critic.value))

        # penalize the agent for being "too sure" about its actions (to prevent converging to a suboptimal local
        # minimum too soon)
        entropy_losses = actor_critic.actions_distribution.entropy()

        # make sure entropy is maximized only for state-action pairs with non-clipped advantage
        entropy_losses = (1.0 - clipped) * entropy_losses
        entropy_loss = -tf.reduce_mean(entropy_losses)
        entropy_loss_coeff = tf.train.exponential_decay(
            params.initial_entropy_loss_coeff, tf.cast(step, tf.float32), 10.0, 0.95, staircase=True,
        )
        entropy_loss_coeff = tf.maximum(entropy_loss_coeff, params.min_entropy_loss_coeff)
        entropy_loss = entropy_loss_coeff * entropy_loss

        # auxiliary quantities (for tensorboard, logging, early stopping)
        log_p_old = tf.log(old_action_probs + EPS)
        log_p = tf.log(action_probs + EPS)
        sample_kl = tf.reduce_mean(log_p_old - log_p)
        sample_entropy = tf.reduce_mean(-log_p)
        clipped_fraction = tf.reduce_mean(clipped)

        # only use entropy bonus if the policy is not close to max entropy
        max_entropy = actor_critic.actions_distribution.max_entropy()
        entropy_loss = tf.cond(sample_entropy > 0.8 * max_entropy, lambda: 0.0, lambda: entropy_loss)

        # final losses to optimize
        actor_loss = ppo_loss + entropy_loss
        critic_loss = value_loss

        return AttrDict(locals())
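The clipping above is the standard PPO surrogate written without an explicit clip op: positive advantages are capped at `advantages * clip_ratio`, negative ones at `advantages / clip_ratio`, which is identical to `min(r * A, clip(r, 1/c, c) * A)` with multiplicative bounds. A small numpy check of that identity (illustrative only):

import numpy as np

c = 1.1  # clip_ratio
r = np.array([0.5, 0.95, 1.0, 1.05, 2.0])   # prob_ratio
A = np.array([1.0, -1.0, 1.0, -1.0, 1.0])   # advantages
lhs = np.minimum(r * A, np.where(A > 0, A * c, A / c))   # formulation used above
rhs = np.minimum(r * A, np.clip(r, 1.0 / c, c) * A)      # textbook PPO clip
assert np.allclose(lhs, rhs)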
Example #19
def test_env_performance(make_env, env_type, verbose=False):
    t = Timing()
    with t.timeit('init'):
        env = make_env(AttrDict({'worker_index': 0, 'vector_index': 0}))
        total_num_frames, frames = 10000, 0

    with t.timeit('first_reset'):
        env.reset()

    t.reset = t.step = 1e-9
    num_resets = 0
    with t.timeit('experience'):
        while frames < total_num_frames:
            done = False

            start_reset = time.time()
            env.reset()

            t.reset += time.time() - start_reset
            num_resets += 1

            while not done and frames < total_num_frames:
                start_step = time.time()
                if verbose:
                    env.render()
                    time.sleep(1.0 / 40)

                obs, rew, done, info = env.step(env.action_space.sample())
                if verbose:
                    log.info('Received reward %.3f', rew)

                t.step += time.time() - start_step
                frames += num_env_steps([info])

    fps = total_num_frames / t.experience
    log.debug('%s performance:', env_type)
    log.debug('Took %.3f sec to collect %d frames on one CPU, %.1f FPS',
              t.experience,
              total_num_frames,
              fps)
    log.debug('Avg. reset time %.3f s', t.reset / num_resets)
    log.debug('Timing: %s', t)
    env.close()
Example #20
    def _objectives(self):
        # model losses
        forward_loss_batch = 0.5 * tf.square(self.encoded_next_obs -
                                             self.predicted_features)
        forward_loss_batch = tf.reduce_mean(forward_loss_batch,
                                            axis=1) * self.feature_vector_size
        forward_loss = tf.reduce_mean(forward_loss_batch)

        bonus = self.params.prediction_bonus_coeff * forward_loss_batch
        self.prediction_curiosity_bonus = tf.clip_by_value(
            bonus, -self.params.clip_bonus, self.params.clip_bonus)

        inverse_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=self.predicted_actions,
                labels=self.ph_actions,
            ))

        cm_beta = self.params.cm_beta
        loss = forward_loss * cm_beta + inverse_loss * (1.0 - cm_beta)
        loss = self.params.cm_lr_scale * loss
        return AttrDict(locals())
Example #21
def load_from_checkpoint(cfg):
    filename = cfg_file(cfg)
    if not os.path.isfile(filename):
        raise Exception(
            f'Could not load saved parameters for experiment {cfg.experiment}')

    with open(filename, 'r') as json_file:
        json_params = json.load(json_file)
        log.warning('Loading existing experiment configuration from %s',
                    filename)
        log.warning(
            'Command-line parameters will be ignored!\n'
            'If you want to resume experiment with different parameters, you should edit %s!',
            filename,
        )
        loaded_cfg = AttrDict(json_params)

    # incorporate extra CLI parameters that were not present in JSON file
    for key, value in vars(cfg).items():
        if key not in loaded_cfg:
            loaded_cfg[key] = value

    return loaded_cfg
Example #22
    def doom_multiagent(make_multi_env, worker_index, num_steps=1000):
        env_config = AttrDict({
            'worker_index': worker_index,
            'vector_index': 0,
            'safe_init': False
        })
        multi_env = make_multi_env(env_config)

        obs = multi_env.reset()

        visualize = False
        start = time.time()

        for i in range(num_steps):
            actions = [multi_env.action_space.sample()] * len(obs)
            obs, rew, dones, infos = multi_env.step(actions)

            if visualize:
                multi_env.render()

            if i % 100 == 0 or any(dones):
                log.info('Rew %r done %r info %r', rew, dones, infos)

            if all(dones):
                multi_env.reset()

        took = time.time() - start
        log.info('Took %.3f seconds for %d steps', took, num_steps)
        log.info('Server steps per second: %.1f', num_steps / took)
        log.info('Observations fps: %.1f',
                 num_steps * multi_env.num_agents / took)
        log.info(
            'Environment fps: %.1f',
            num_steps * multi_env.num_agents * multi_env.skip_frames / took)

        multi_env.close()
Example #23
    def setup_graph(env, params, use_dataset):
        tf.reset_default_graph()

        step = tf.Variable(0, trainable=False, dtype=tf.int64, name='step')

        ph_observations = placeholder_from_space(env.observation_space)
        ph_actions = placeholder_from_space(env.action_space)
        ph_old_actions_probs, ph_advantages, ph_returns = placeholders(
            None, None, None)

        if use_dataset:
            dataset = tf.data.Dataset.from_tensor_slices((
                ph_observations,
                ph_actions,
                ph_old_actions_probs,
                ph_advantages,
                ph_returns,
            ))
            dataset = dataset.batch(params.batch_size)
            dataset = dataset.prefetch(10)
            iterator = dataset.make_initializable_iterator()
            observations, act, old_action_probs, adv, ret = iterator.get_next()
        else:
            observations = ph_observations
            act, old_action_probs, adv, ret = ph_actions, ph_old_actions_probs, ph_advantages, ph_returns

        actor_critic = ActorCritic(env, observations, params)
        env.close()

        objectives = AgentPPO.add_ppo_objectives(actor_critic, act,
                                                 old_action_probs, adv, ret,
                                                 params, step)
        train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(
            objectives.actor_loss, global_step=step)

        return AttrDict(locals())
Example #24
    def test_performance(self):
        params = AgentPPO.Params('test_performance')
        params.ppo_epochs = 2
        params.rollout = 16
        env = make_doom_env(doom_env_by_name(TEST_ENV_NAME))

        observation_shape = env.observation_space.shape
        experience_size = params.num_envs * params.rollout

        # generate random data
        data = AttrDict()
        data.obs = np.random.normal(size=(experience_size, ) +
                                    observation_shape)
        data.act = np.random.randint(0, 3, size=[experience_size])
        data.old_prob = np.random.uniform(0, 1, size=[experience_size])
        data.adv = np.random.normal(size=[experience_size])
        data.ret = np.random.normal(size=[experience_size])

        self.train_feed_dict(env, data, params, use_gpu=False)
        self.train_feed_dict(env, data, params, use_gpu=True)
        self.train_dataset(env, data, params, use_gpu=False)
        self.train_dataset(env, data, params, use_gpu=True)

        env.close()
Example #25
def enjoy(cfg, max_num_frames=1e9):
    cfg = load_from_checkpoint(cfg)

    render_action_repeat = cfg.render_action_repeat if cfg.render_action_repeat is not None else cfg.env_frameskip
    if render_action_repeat is None:
        log.warning('Not using action repeat!')
        render_action_repeat = 1
    log.debug('Using action repeat %d during evaluation', render_action_repeat)

    cfg.env_frameskip = 1  # for evaluation
    cfg.num_envs = 1

    if cfg.record_to:
        tstamp = datetime.datetime.now().strftime('%Y_%m_%d__%H_%M_%S')
        cfg.record_to = join(cfg.record_to, f'{cfg.experiment}', tstamp)
        if not os.path.isdir(cfg.record_to):
            os.makedirs(cfg.record_to)
    else:
        cfg.record_to = None

    def make_env_func(env_config):
        return create_env(cfg.env, cfg=cfg, env_config=env_config)

    env = make_env_func(AttrDict({'worker_index': 0, 'vector_index': 0}))
    # env.seed(0)

    is_multiagent = is_multiagent_env(env)
    if not is_multiagent:
        env = MultiAgentWrapper(env)

    if hasattr(env.unwrapped, 'reset_on_init'):
        # reset call ruins the demo recording for VizDoom
        env.unwrapped.reset_on_init = False

    actor_critic = create_actor_critic(cfg, env.observation_space,
                                       env.action_space)

    device = torch.device('cpu' if cfg.device == 'cpu' else 'cuda')
    actor_critic.model_to_device(device)

    policy_id = cfg.policy_index
    checkpoints = LearnerWorker.get_checkpoints(
        LearnerWorker.checkpoint_dir(cfg, policy_id))
    checkpoint_dict = LearnerWorker.load_checkpoint(checkpoints, device)
    actor_critic.load_state_dict(checkpoint_dict['model'])

    episode_rewards = [deque([], maxlen=100) for _ in range(env.num_agents)]
    true_rewards = [deque([], maxlen=100) for _ in range(env.num_agents)]
    num_frames = 0

    last_render_start = time.time()

    def max_frames_reached(frames):
        return max_num_frames is not None and frames > max_num_frames

    obs = env.reset()
    rnn_states = torch.zeros(
        [env.num_agents, get_hidden_size(cfg)],
        dtype=torch.float32,
        device=device)
    episode_reward = np.zeros(env.num_agents)
    finished_episode = [False] * env.num_agents

    with torch.no_grad():
        while not max_frames_reached(num_frames):
            obs_torch = AttrDict(transform_dict_observations(obs))
            for key, x in obs_torch.items():
                obs_torch[key] = torch.from_numpy(x).to(device).float()

            policy_outputs = actor_critic(obs_torch,
                                          rnn_states,
                                          with_action_distribution=True)

            # sample actions from the distribution by default
            actions = policy_outputs.actions

            action_distribution = policy_outputs.action_distribution
            if isinstance(action_distribution, ContinuousActionDistribution):
                if not cfg.continuous_actions_sample:  # TODO: add similar option for discrete actions
                    actions = action_distribution.means

            actions = actions.cpu().numpy()

            rnn_states = policy_outputs.rnn_states

            for _ in range(render_action_repeat):
                if not cfg.no_render:
                    target_delay = 1.0 / cfg.fps if cfg.fps > 0 else 0
                    current_delay = time.time() - last_render_start
                    time_wait = target_delay - current_delay

                    if time_wait > 0:
                        # log.info('Wait time %.3f', time_wait)
                        time.sleep(time_wait)

                    last_render_start = time.time()
                    env.render()

                obs, rew, done, infos = env.step(actions)

                episode_reward += rew
                num_frames += 1

                for agent_i, done_flag in enumerate(done):
                    if done_flag:
                        finished_episode[agent_i] = True
                        episode_rewards[agent_i].append(
                            episode_reward[agent_i])
                        true_rewards[agent_i].append(infos[agent_i].get(
                            'true_reward', math.nan))
                        log.info(
                            'Episode finished for agent %d at %d frames. Reward: %.3f, true_reward: %.3f',
                            agent_i, num_frames, episode_reward[agent_i],
                            true_rewards[agent_i][-1])
                        rnn_states[agent_i] = torch.zeros(
                            [get_hidden_size(cfg)],
                            dtype=torch.float32,
                            device=device)
                        episode_reward[agent_i] = 0

                # if episode terminated synchronously for all agents, pause a bit before starting a new one
                if all(done):
                    if not cfg.no_render:
                        env.render()
                    time.sleep(0.05)

                if all(finished_episode):
                    finished_episode = [False] * env.num_agents
                    avg_episode_rewards_str, avg_true_reward_str = '', ''
                    for agent_i in range(env.num_agents):
                        avg_rew = np.mean(episode_rewards[agent_i])
                        avg_true_rew = np.mean(true_rewards[agent_i])
                        if not np.isnan(avg_rew):
                            if avg_episode_rewards_str:
                                avg_episode_rewards_str += ', '
                            avg_episode_rewards_str += f'#{agent_i}: {avg_rew:.3f}'
                        if not np.isnan(avg_true_rew):
                            if avg_true_reward_str:
                                avg_true_reward_str += ', '
                            avg_true_reward_str += f'#{agent_i}: {avg_true_rew:.3f}'

                    log.info('Avg episode rewards: %s, true rewards: %s',
                             avg_episode_rewards_str, avg_true_reward_str)
                    log.info(
                        'Avg episode reward: %.3f, avg true_reward: %.3f',
                        np.mean([
                            np.mean(episode_rewards[i])
                            for i in range(env.num_agents)
                        ]),
                        np.mean([
                            np.mean(true_rewards[i])
                            for i in range(env.num_agents)
                        ]))

                # VizDoom multiplayer stuff
                # for player in [1, 2, 3, 4, 5, 6, 7, 8]:
                #     key = f'PLAYER{player}_FRAGCOUNT'
                #     if key in infos[0]:
                #         log.debug('Score for player %d: %r', player, infos[0][key])

    env.close()

    return ExperimentStatus.SUCCESS, np.mean(episode_rewards)
Example #26
    def sample(self, proc_idx):
        # workers should ignore Ctrl+C because the termination is handled in the event loop by a special msg
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        timing = Timing()

        from threadpoolctl import threadpool_limits
        with threadpool_limits(limits=1, user_api=None):
            if self.cfg.set_workers_cpu_affinity:
                set_process_cpu_affinity(proc_idx, self.cfg.num_workers)

            initial_cpu_affinity = psutil.Process().cpu_affinity() if platform != 'darwin' else None
            psutil.Process().nice(10)

            with timing.timeit('env_init'):
                envs = []
                env_key = ['env' for _ in range(self.cfg.num_envs_per_worker)]

                for env_idx in range(self.cfg.num_envs_per_worker):
                    global_env_id = proc_idx * self.cfg.num_envs_per_worker + env_idx
                    env_config = AttrDict(worker_index=proc_idx,
                                          vector_index=env_idx,
                                          env_id=global_env_id)
                    env = create_env(self.cfg.env,
                                     cfg=self.cfg,
                                     env_config=env_config)
                    log.debug(
                        'CPU affinity after create_env: %r',
                        psutil.Process().cpu_affinity()
                        if platform != 'darwin' else 'MacOS - None')
                    env.seed(global_env_id)
                    envs.append(env)

                    # this is to track the performance for individual DMLab levels
                    if hasattr(env.unwrapped, 'level_name'):
                        env_key[env_idx] = env.unwrapped.level_name

                episode_length = [0 for _ in envs]
                episode_lengths = [deque([], maxlen=20) for _ in envs]

            try:
                with timing.timeit('first_reset'):
                    for env_idx, env in enumerate(envs):
                        env.reset()
                        log.info('Process %d finished resetting %d/%d envs',
                                 proc_idx, env_idx + 1, len(envs))

                    self.report_queue.put(
                        dict(proc_idx=proc_idx, finished_reset=True))

                self.start_event.wait()

                with timing.timeit('work'):
                    last_report = last_report_frames = total_env_frames = 0
                    while not self.terminate.value and total_env_frames < self.cfg.sample_env_frames_per_worker:
                        for env_idx, env in enumerate(envs):
                            action = env.action_space.sample()
                            with timing.add_time(f'{env_key[env_idx]}.step'):
                                obs, reward, done, info = env.step(action)

                            num_frames = info.get('num_frames', 1)
                            total_env_frames += num_frames
                            episode_length[env_idx] += num_frames

                            if done:
                                with timing.add_time(
                                        f'{env_key[env_idx]}.reset'):
                                    env.reset()

                                episode_lengths[env_idx].append(
                                    episode_length[env_idx])
                                episode_length[env_idx] = 0

                        with timing.add_time('report'):
                            now = time.time()
                            if now - last_report > self.report_every_sec:
                                last_report = now
                                frames_since_last_report = total_env_frames - last_report_frames
                                last_report_frames = total_env_frames
                                self.report_queue.put(
                                    dict(proc_idx=proc_idx,
                                         env_frames=frames_since_last_report))

                # Extra check to make sure cpu affinity is preserved throughout the execution.
                # I observed a weird effect where some environments tried to alter the affinity of the current process, leading
                # to decreased performance.
                # This can be caused by some interactions between deep learning libs, OpenCV, MKL, OpenMP, etc.
                # At least user should know about it if this is happening.
                cpu_affinity = psutil.Process().cpu_affinity() if platform != 'darwin' else None
                assert initial_cpu_affinity == cpu_affinity, \
                    f'Worker CPU affinity was changed from {initial_cpu_affinity} to {cpu_affinity}! ' \
                    f'This can significantly affect performance!'

            except Exception:
                log.exception('Unknown exception')
                log.error('Unknown exception in worker %d, terminating...',
                          proc_idx)
                self.report_queue.put(dict(proc_idx=proc_idx, crash=True))

            time.sleep(proc_idx * 0.01 + 0.01)
            log.info('Process %d finished sampling. Timing: %s', proc_idx,
                     timing)

            for env_idx, env in enumerate(envs):
                if len(episode_lengths[env_idx]) > 0:
                    log.warning('Level %s avg episode len %d',
                                env_key[env_idx],
                                np.mean(episode_lengths[env_idx]))

            for env in envs:
                env.close()
Example #27
    def _learn_loop(self, multi_env, step_callback=None):
        """
        Main training loop.
        :param step_callback: a hacky callback that takes a dictionary with all local variables as an argument.
        Allows you to look inside the training process.
        """
        step = initial_step = tf.train.global_step(self.session, tf.train.get_global_step())
        env_steps = self.total_env_steps.eval(session=self.session)
        batch_size = self.params.rollout * self.params.num_envs

        img_obs, timer_obs = extract_keys(multi_env.initial_obs(), 'obs', 'timer')

        adv_running_mean_std = RunningMeanStd(max_past_samples=10000)

        def end_of_training(s, es):
            return s >= self.params.train_for_steps or es > self.params.train_for_env_steps

        while not end_of_training(step, env_steps):
            timing = AttrDict({'experience': time.time(), 'batch': time.time()})
            experience_start = time.time()

            env_steps_before_batch = env_steps
            batch_obs, batch_timer = [img_obs], [timer_obs]
            env_steps += len(img_obs)
            batch_actions, batch_values, batch_rewards, batch_dones, batch_next_obs = [], [], [], [], []
            for rollout_step in range(self.params.rollout):
                actions, values = self._policy_step_timer(img_obs, timer_obs)
                batch_actions.append(actions)
                batch_values.append(values)

                # wait for all the workers to complete an environment step
                next_obs, rewards, dones, infos = multi_env.step(actions)
                next_img_obs, next_timer = extract_keys(next_obs, 'obs', 'timer')

                # calculate curiosity bonus
                bonuses = self._prediction_curiosity_bonus(img_obs, actions, next_img_obs)
                rewards += bonuses

                batch_rewards.append(rewards)
                batch_dones.append(dones)
                batch_next_obs.append(next_img_obs)

                img_obs = next_img_obs
                timer_obs = next_timer

                if infos is not None and 'num_frames' in infos[0]:
                    env_steps += sum((info['num_frames'] for info in infos))
                else:
                    env_steps += multi_env.num_envs

                if rollout_step != self.params.rollout - 1:
                    # we don't need the newest observation in the training batch, already have enough
                    batch_obs.append(img_obs)
                    batch_timer.append(timer_obs)

            assert len(batch_obs) == len(batch_rewards)
            assert len(batch_obs) == len(batch_next_obs)

            batch_rewards = np.asarray(batch_rewards, np.float32).swapaxes(0, 1)
            batch_dones = np.asarray(batch_dones, bool).swapaxes(0, 1)
            batch_values = np.asarray(batch_values, np.float32).swapaxes(0, 1)

            # Last value won't be valid for envs with done=True (because the env automatically resets and shows the 1st
            # observation of the next episode). But that's okay, because we should never use last_value in this case.
            last_values = self._estimate_values_timer(img_obs, timer_obs)

            gamma = self.params.gamma
            disc_rewards = []
            for i in range(len(batch_rewards)):
                env_rewards = self._calc_discounted_rewards(gamma, batch_rewards[i], batch_dones[i], last_values[i])
                disc_rewards.extend(env_rewards)
            disc_rewards = np.asarray(disc_rewards, np.float32)

            # convert observations and estimations to meaningful n-step batches
            batch_obs_shape = (self.params.rollout * multi_env.num_envs,) + img_obs[0].shape
            batch_obs = np.asarray(batch_obs, np.float32).swapaxes(0, 1).reshape(batch_obs_shape)
            batch_next_obs = np.asarray(batch_next_obs, np.float32).swapaxes(0, 1).reshape(batch_obs_shape)
            batch_actions = np.asarray(batch_actions, np.int32).swapaxes(0, 1).flatten()
            batch_timer = np.asarray(batch_timer, np.float32).swapaxes(0, 1).flatten()
            batch_values = batch_values.flatten()

            advantages = disc_rewards - batch_values
            if self.params.normalize_adv:
                adv_running_mean_std.update(advantages)
                advantages = (advantages - adv_running_mean_std.mean) / (np.sqrt(adv_running_mean_std.var) + EPS)
            advantages = np.clip(advantages, -self.params.clip_advantage, self.params.clip_advantage)

            timing.experience = time.time() - timing.experience
            timing.train = time.time()

            step = self._curious_train_step(
                step,
                env_steps,
                batch_obs,
                batch_timer,
                batch_actions,
                batch_values,
                disc_rewards,
                advantages,
                batch_next_obs,
            )
            self._maybe_save(step, env_steps)

            timing.train = time.time() - timing.train

            avg_reward = multi_env.calc_avg_rewards(n=self.params.stats_episodes)
            avg_length = multi_env.calc_avg_episode_lengths(n=self.params.stats_episodes)
            fps = (env_steps - env_steps_before_batch) / (time.time() - timing.batch)

            self._maybe_print(step, avg_reward, avg_length, fps, timing)
            self._maybe_aux_summaries(step, env_steps, avg_reward, avg_length)
            self._maybe_update_avg_reward(avg_reward, multi_env.stats_num_episodes())

            if step_callback is not None:
                step_callback(locals(), globals())
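`_calc_discounted_rewards` is assumed to compute n-step discounted returns bootstrapped from `last_value`; a minimal single-environment sketch (hypothetical, inferred from the call site):

def _calc_discounted_rewards(gamma, rewards, dones, last_value):
    # iterate backwards; a done flag cuts the bootstrap across episode boundaries
    returns = []
    ret = last_value
    for reward, done in zip(reversed(rewards), reversed(dones)):
        ret = reward + gamma * ret * (1.0 - done)
        returns.append(ret)
    return list(reversed(returns))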
Example #28
    def _train(self, gpu_buffer, batch_size, experience_size, timing):
        with torch.no_grad():
            early_stopping_tolerance = 1e-6
            early_stop = False
            prev_epoch_actor_loss = 1e9
            epoch_actor_losses = []

            # V-trace parameters
            # noinspection PyArgumentList
            rho_hat = torch.Tensor([self.cfg.vtrace_rho])
            # noinspection PyArgumentList
            c_hat = torch.Tensor([self.cfg.vtrace_c])

            clip_ratio_high = 1.0 + self.cfg.ppo_clip_ratio  # e.g. 1.1
            # this still works with e.g. clip_ratio = 2, while PPO's 1-r would give negative ratio
            clip_ratio_low = 1.0 / clip_ratio_high

            clip_value = self.cfg.ppo_clip_value
            gamma = self.cfg.gamma
            recurrence = self.cfg.recurrence

            if self.cfg.with_vtrace:
                assert recurrence == self.cfg.rollout and recurrence > 1, \
                    'V-trace requires recurrence and rollout to be equal'

            num_sgd_steps = 0

            stats_and_summaries = None
            if not self.with_training:
                return stats_and_summaries

        for epoch in range(self.cfg.ppo_epochs):
            with timing.add_time('epoch_init'):
                if early_stop or self.terminate:
                    break

                summary_this_epoch = force_summaries = False

                minibatches = self._get_minibatches(batch_size,
                                                    experience_size)

            for batch_num in range(len(minibatches)):
                with timing.add_time('minibatch_init'):
                    indices = minibatches[batch_num]

                    # current minibatch consisting of short trajectory segments with length == recurrence
                    mb = self._get_minibatch(gpu_buffer, indices)

                # calculate policy head outside of recurrent loop
                with timing.add_time('forward_head'):
                    head_outputs = self.actor_critic.forward_head(mb.obs)

                # initial rnn states
                with timing.add_time('bptt_initial'):
                    rnn_states = mb.rnn_states[::recurrence]
                    is_same_episode = 1.0 - mb.dones.unsqueeze(dim=1)

                # calculate RNN outputs for each timestep in a loop
                with timing.add_time('bptt'):
                    core_outputs = []
                    for i in range(recurrence):
                        # indices of head outputs corresponding to the current timestep
                        step_head_outputs = head_outputs[i::recurrence]

                        with timing.add_time('bptt_forward_core'):
                            core_output, rnn_states = self.actor_critic.forward_core(
                                step_head_outputs, rnn_states)
                            core_outputs.append(core_output)

                        if self.cfg.use_rnn:
                            # zero-out RNN states on the episode boundary
                            with timing.add_time('bptt_rnn_states'):
                                is_same_episode_step = is_same_episode[
                                    i::recurrence]
                                rnn_states = rnn_states * is_same_episode_step

                with timing.add_time('tail'):
                    # transform core outputs from [T, Batch, D] to [Batch, T, D] and then to [Batch x T, D]
                    # which is the same shape as the minibatch
                    core_outputs = torch.stack(core_outputs)

                    num_timesteps, num_trajectories = core_outputs.shape[:2]
                    assert num_timesteps == recurrence
                    assert num_timesteps * num_trajectories == batch_size
                    core_outputs = core_outputs.transpose(0, 1).reshape(
                        -1, *core_outputs.shape[2:])
                    assert core_outputs.shape[0] == head_outputs.shape[0]

                    # calculate policy tail outside of recurrent loop
                    result = self.actor_critic.forward_tail(
                        core_outputs, with_action_distribution=True)

                    action_distribution = result.action_distribution
                    log_prob_actions = action_distribution.log_prob(mb.actions)
                    ratio = torch.exp(log_prob_actions -
                                      mb.log_prob_actions)  # pi / pi_old

                    # super large/small values can cause numerical problems and are probably noise anyway
                    ratio = torch.clamp(ratio, 0.05, 20.0)

                    values = result.values.squeeze()

                with torch.no_grad():  # these computations are not part of the computation graph
                    if self.cfg.with_vtrace:
                        ratios_cpu = ratio.cpu()
                        values_cpu = values.cpu()
                        rewards_cpu = mb.rewards.cpu()  # we only need this on CPU, potential minor optimization
                        dones_cpu = mb.dones.cpu()

                        vtrace_rho = torch.min(rho_hat, ratios_cpu)
                        vtrace_c = torch.min(c_hat, ratios_cpu)

                        vs = torch.zeros((num_trajectories * recurrence))
                        adv = torch.zeros((num_trajectories * recurrence))

                        next_values = (
                            values_cpu[recurrence - 1::recurrence] -
                            rewards_cpu[recurrence - 1::recurrence]) / gamma
                        next_vs = next_values

                        with timing.add_time('vtrace'):
                            for i in reversed(range(self.cfg.recurrence)):
                                rewards = rewards_cpu[i::recurrence]
                                dones = dones_cpu[i::recurrence]
                                not_done = 1.0 - dones
                                not_done_times_gamma = not_done * gamma

                                curr_values = values_cpu[i::recurrence]
                                curr_vtrace_rho = vtrace_rho[i::recurrence]
                                curr_vtrace_c = vtrace_c[i::recurrence]

                                delta_s = curr_vtrace_rho * (
                                    rewards + not_done_times_gamma *
                                    next_values - curr_values)
                                adv[i::recurrence] = curr_vtrace_rho * (
                                    rewards + not_done_times_gamma * next_vs -
                                    curr_values)
                                next_vs = curr_values + delta_s + not_done_times_gamma * curr_vtrace_c * (
                                    next_vs - next_values)
                                vs[i::recurrence] = next_vs

                                next_values = curr_values

                        targets = vs
                    else:
                        # using regular GAE
                        adv = mb.advantages
                        targets = mb.returns

                    adv_mean = adv.mean()
                    adv_std = adv.std()
                    adv = (adv - adv_mean) / max(1e-3, adv_std)  # normalize advantage
                    adv = adv.to(self.device)

                with timing.add_time('losses'):
                    policy_loss = self._policy_loss(ratio, adv, clip_ratio_low,
                                                    clip_ratio_high)

                    entropy = action_distribution.entropy()
                    if self.cfg.entropy_loss_coeff > 0.0:
                        entropy_loss = -self.cfg.entropy_loss_coeff * entropy.mean()
                    else:
                        entropy_loss = 0.0

                    actor_loss = policy_loss + entropy_loss
                    epoch_actor_losses.append(actor_loss.item())

                    targets = targets.to(self.device)
                    old_values = mb.values
                    value_loss = self._value_loss(values, old_values, targets,
                                                  clip_value)
                    critic_loss = value_loss

                    loss = actor_loss + critic_loss

                    high_loss = 30.0
                    if (abs(to_scalar(policy_loss)) > high_loss
                            or abs(to_scalar(value_loss)) > high_loss
                            or abs(to_scalar(entropy_loss)) > high_loss):
                        log.warning(
                            'High loss value: %.4f %.4f %.4f %.4f',
                            to_scalar(loss),
                            to_scalar(policy_loss),
                            to_scalar(value_loss),
                            to_scalar(entropy_loss),
                        )
                        force_summaries = True

                with timing.add_time('update'):
                    # update the weights
                    self.optimizer.zero_grad()
                    loss.backward()

                    if self.cfg.max_grad_norm > 0.0:
                        with timing.add_time('clip'):
                            torch.nn.utils.clip_grad_norm_(
                                self.actor_critic.parameters(),
                                self.cfg.max_grad_norm)

                    curr_policy_version = self.train_step  # policy version before the weight update
                    with self.policy_lock:
                        self.optimizer.step()

                    num_sgd_steps += 1

                with torch.no_grad():
                    with timing.add_time('after_optimizer'):
                        self._after_optimizer_step()

                        # collect and report summaries
                        with_summaries = self._should_save_summaries() or force_summaries
                        if with_summaries and not summary_this_epoch:
                            stats_and_summaries = self._record_summaries(
                                AttrDict(locals()))
                            summary_this_epoch = True
                            force_summaries = False

            # end of an epoch
            # this will force policy update on the inference worker (policy worker)
            self.policy_versions[self.policy_id] = self.train_step

            new_epoch_actor_loss = np.mean(epoch_actor_losses)
            loss_delta_abs = abs(prev_epoch_actor_loss - new_epoch_actor_loss)
            if loss_delta_abs < early_stopping_tolerance:
                early_stop = True
                log.debug(
                    'Early stopping after %d epochs (%d sgd steps), loss delta %.7f',
                    epoch + 1,
                    num_sgd_steps,
                    loss_delta_abs,
                )
                break

            prev_epoch_actor_loss = new_epoch_actor_loss
            epoch_actor_losses = []

        return stats_and_summaries
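The V-trace block above vectorizes the recursion across all trajectories in the minibatch. For a single trajectory the same computation reduces to the following minimal numpy sketch (after Espeholt et al., 2018; done flags omitted for brevity, and this is an illustration rather than the code used here):

import numpy as np

def vtrace_targets(rewards, values, next_value, ratios, gamma=0.99, rho_hat=1.0, c_hat=1.0):
    T = len(rewards)
    vs = np.zeros(T)
    next_vs = next_value  # bootstrap from the value of the state after the rollout
    for t in reversed(range(T)):
        rho = min(rho_hat, ratios[t])   # clipped importance weight for the TD error
        c = min(c_hat, ratios[t])       # clipped 'trace cutting' coefficient
        v_next = values[t + 1] if t < T - 1 else next_value
        delta = rho * (rewards[t] + gamma * v_next - values[t])
        next_vs = values[t] + delta + gamma * c * (next_vs - v_next)
        vs[t] = next_vs
    return vs  # regression targets for the value head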
Example #29
    def _record_summaries(self, train_loop_vars):
        var = train_loop_vars

        self.last_summary_time = time.time()
        stats = AttrDict()

        grad_norm = sum(
            p.grad.data.norm(2).item()**2
            for p in self.actor_critic.parameters() if p.grad is not None)**0.5
        stats.grad_norm = grad_norm
        stats.loss = var.loss
        stats.value = var.result.values.mean()
        stats.entropy = var.action_distribution.entropy().mean()
        stats.policy_loss = var.policy_loss
        stats.value_loss = var.value_loss
        stats.entropy_loss = var.entropy_loss
        stats.adv_min = var.adv.min()
        stats.adv_max = var.adv.max()
        stats.adv_std = var.adv_std
        stats.max_abs_logprob = torch.abs(var.mb.action_logits).max()

        if hasattr(var.action_distribution, 'summaries'):
            stats.update(var.action_distribution.summaries())

        if var.epoch == self.cfg.ppo_epochs - 1 and var.batch_num == len(
                var.minibatches) - 1:
            # we collect these stats only for the last PPO batch, or every time if we're only doing one batch, IMPALA-style
            ratio_mean = torch.abs(1.0 - var.ratio).mean().detach()
            ratio_min = var.ratio.min().detach()
            ratio_max = var.ratio.max().detach()
            # log.debug('Learner %d ratio mean min max %.4f %.4f %.4f', self.policy_id, ratio_mean.cpu().item(), ratio_min.cpu().item(), ratio_max.cpu().item())

            value_delta = torch.abs(var.values - var.old_values)
            value_delta_avg, value_delta_max = value_delta.mean(), value_delta.max()

            # calculate KL-divergence with the behaviour policy action distribution
            old_action_distribution = get_action_distribution(
                self.actor_critic.action_space,
                var.mb.action_logits,
            )
            kl_old = var.action_distribution.kl_divergence(
                old_action_distribution)
            kl_old_mean = kl_old.mean()

            stats.kl_divergence = kl_old_mean
            stats.value_delta = value_delta_avg
            stats.value_delta_max = value_delta_max
            stats.fraction_clipped = (
                (var.ratio < var.clip_ratio_low).float() +
                (var.ratio > var.clip_ratio_high).float()).mean()
            stats.ratio_mean = ratio_mean
            stats.ratio_min = ratio_min
            stats.ratio_max = ratio_max
            stats.num_sgd_steps = var.num_sgd_steps

        # this caused numerical issues on some versions of PyTorch with second moment reaching infinity
        adam_max_second_moment = 0.0
        for key, tensor_state in self.optimizer.state.items():
            adam_max_second_moment = max(
                tensor_state['exp_avg_sq'].max().item(),
                adam_max_second_moment)
        stats.adam_max_second_moment = adam_max_second_moment

        version_diff = var.curr_policy_version - var.mb.policy_version
        stats.version_diff_avg = version_diff.mean()
        stats.version_diff_min = version_diff.min()
        stats.version_diff_max = version_diff.max()

        for key, value in stats.items():
            stats[key] = to_scalar(value)

        return stats
Example #30
def multi_agent_match(policy_indices, max_num_episodes=int(1e9), max_num_frames=1e10):
    log.debug('Starting eval process with policies %r', policy_indices)
    for i, rival in enumerate(RIVALS):
        rival.policy_index = policy_indices[i]

    curr_dir = os.path.dirname(os.path.abspath(__file__))
    evaluation_filename = join(curr_dir,
                               f'eval_{"vs".join([str(pi) for pi in policy_indices])}.txt')
    with open(evaluation_filename, 'w') as fobj:
        fobj.write('start\n')

    common_config = RIVALS[0].cfg

    render_action_repeat = common_config.render_action_repeat if common_config.render_action_repeat is not None else common_config.env_frameskip
    if render_action_repeat is None:
        log.warning('Not using action repeat!')
        render_action_repeat = 1
    log.debug('Using action repeat %d during evaluation', render_action_repeat)

    common_config.env_frameskip = 1  # for evaluation
    common_config.num_envs = 1
    common_config.timelimit = 4.0  # for faster evaluation

    def make_env_func(env_config):
        return create_env(ENV_NAME, cfg=common_config, env_config=env_config)

    env = make_env_func(AttrDict({'worker_index': 0, 'vector_index': 0}))
    env.seed(0)

    is_multiagent = is_multiagent_env(env)
    if not is_multiagent:
        env = MultiAgentWrapper(env)
    else:
        assert env.num_agents == len(RIVALS)

    device = torch.device('cuda')
    for rival in RIVALS:
        rival.actor_critic = create_actor_critic(rival.cfg, env.observation_space, env.action_space)
        rival.actor_critic.model_to_device(device)

        policy_id = rival.policy_index
        checkpoints = LearnerWorker.get_checkpoints(
            LearnerWorker.checkpoint_dir(rival.cfg, policy_id))
        checkpoint_dict = LearnerWorker.load_checkpoint(checkpoints, device)
        rival.actor_critic.load_state_dict(checkpoint_dict['model'])

    episode_rewards = []
    num_frames = 0

    last_render_start = time.time()

    def max_frames_reached(frames):
        return max_num_frames is not None and frames > max_num_frames

    wins = [0 for _ in RIVALS]
    ties = 0
    frag_differences = []

    with torch.no_grad():
        for _ in range(max_num_episodes):
            obs = env.reset()
            obs_dict_torch = dict()

            done = [False] * len(obs)
            for rival in RIVALS:
                rival.rnn_states = torch.zeros([1, rival.cfg.hidden_size],
                                               dtype=torch.float32,
                                               device=device)

            episode_reward = 0
            prev_frame = time.time()

            while True:
                actions = []
                for i, obs_dict in enumerate(obs):
                    for key, x in obs_dict.items():
                        obs_dict_torch[key] = torch.from_numpy(x).to(device).float().view(
                            1, *x.shape)

                    rival = RIVALS[i]
                    policy_outputs = rival.actor_critic(obs_dict_torch, rival.rnn_states)
                    rival.rnn_states = policy_outputs.rnn_states
                    actions.append(policy_outputs.actions[0].cpu().numpy())

                for _ in range(render_action_repeat):
                    if not NO_RENDER:
                        target_delay = 1.0 / FPS if FPS > 0 else 0
                        current_delay = time.time() - last_render_start
                        time_wait = target_delay - current_delay

                        if time_wait > 0:
                            # log.info('Wait time %.3f', time_wait)
                            time.sleep(time_wait)

                        last_render_start = time.time()
                        env.render()

                    obs, rew, done, infos = env.step(actions)
                    if all(done):
                        log.debug('Finished episode!')

                        frag_diff = infos[0]['PLAYER1_FRAGCOUNT'] - infos[0]['PLAYER2_FRAGCOUNT']
                        if frag_diff > 0:
                            wins[0] += 1
                        elif frag_diff < 0:
                            wins[1] += 1
                        else:
                            ties += 1

                        frag_differences.append(frag_diff)
                        avg_frag_diff = np.mean(frag_differences)

                        report = f'wins: {wins}, ties: {ties}, avg_frag_diff: {avg_frag_diff}'
                        with open(evaluation_filename, 'a') as fobj:
                            fobj.write(report + '\n')

                    # log.info('%d:%d', infos[0]['PLAYER1_FRAGCOUNT'], infos[0]['PLAYER2_FRAGCOUNT'])

                    episode_reward += np.mean(rew)
                    num_frames += 1

                    if num_frames % 100 == 0:
                        log.debug('%.1f', render_action_repeat / (time.time() - prev_frame))
                    prev_frame = time.time()

                    if all(done):
                        log.info('Episode finished at %d frames', num_frames)
                        break

                if all(done) or max_frames_reached(num_frames):
                    break

            if not NO_RENDER:
                env.render()
            time.sleep(0.01)

            episode_rewards.append(episode_reward)
            last_episodes = episode_rewards[-100:]
            avg_reward = sum(last_episodes) / len(last_episodes)
            log.info(
                'Episode reward: %f, avg reward for %d episodes: %f',
                episode_reward,
                len(last_episodes),
                avg_reward,
            )

            if max_frames_reached(num_frames):
                break

    env.close()