Example no. 1
def train(_run, exp_name, save_rate, display, restore_fp, hard_max,
          max_episode_len, num_episodes, batch_size, update_rate,
          use_target_action):
    """
    This is the main training function, which includes the setup and training loop.
    It is meant to be called automatically by sacred, but can be used without it as well.

    :param _run:            Sacred _run object for logging
    :param exp_name:        (str) Name of the experiment
    :param save_rate:       (int) Frequency to save networks at
    :param display:         (bool) Render the environment
    :param restore_fp:      (str) File path of the policy to restore, or None if no restore is wanted
    :param hard_max:        (bool) Convert soft (softmax) actions to argmax actions before stepping the env
    :param max_episode_len: (int) maximum number of transitions per episode
    :param num_episodes:    (int) number of episodes
    :param batch_size:      (int) batch size for updates
    :param update_rate:     (int) perform critic update every n environment steps
    :param use_target_action:   (bool) use action from target network
    :return:    List of episodic rewards
    """
    # Create environment
    print(_run.config)
    env = make_env()

    # Create agents
    agents = get_agents(_run, env, env.n_adversaries)

    logger = RLLogger(exp_name, _run, len(agents), env.n_adversaries,
                      save_rate)

    # Load previous results, if necessary
    if restore_fp is not None:
        print('Loading previous state...')
        for ag_idx, agent in enumerate(agents):
            fp = os.path.join(restore_fp, 'agent_{}'.format(ag_idx))
            agent.load(fp)

    obs_n = env.reset()

    print('Starting iterations...')
    while True:
        # get action
        if use_target_action:
            action_n = [
                agent.target_action(obs.astype(np.float32)[None])[0]
                for agent, obs in zip(agents, obs_n)
            ]
        else:
            action_n = [
                agent.action(obs.astype(np.float32))
                for agent, obs in zip(agents, obs_n)
            ]
        # environment step
        if hard_max:
            hard_action_n = softmax_to_argmax(action_n, agents)
            new_obs_n, rew_n, done_n, info_n = env.step(hard_action_n)
        else:
            action_n = [action.numpy() for action in action_n]
            new_obs_n, rew_n, done_n, info_n = env.step(action_n)

        logger.episode_step += 1

        done = all(done_n)
        terminal = (logger.episode_step >= max_episode_len)
        done = done or terminal

        # collect experience
        for i, agent in enumerate(agents):
            agent.add_transition(obs_n, action_n, rew_n[i], new_obs_n, done)
        obs_n = new_obs_n

        for ag_idx, rew in enumerate(rew_n):
            logger.cur_episode_reward += rew
            logger.agent_rewards[ag_idx][-1] += rew

        if done:
            obs_n = env.reset()
            episode_step = 0
            logger.record_episode_end(agents)

        logger.train_step += 1

        # policy updates
        train_cond = not display
        for agent in agents:
            if train_cond and len(
                    agent.replay_buffer) > batch_size * max_episode_len:
                if logger.train_step % update_rate == 0:  # only update every `update_rate` env steps
                    q_loss, pol_loss = agent.update(agents, logger.train_step)

        # for displaying learned policies
        if display:
            time.sleep(0.1)
            env.render()

        # save logger outputs to file, similar to the original MADDPG implementation
        if len(logger.episode_rewards) > num_episodes:
            logger.experiment_end()
            return logger.get_sacred_results()
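The docstring of this example notes that train() can also be invoked without Sacred. Below is a minimal sketch of such a call, assuming the module-level helpers (make_env, get_agents, RLLogger) are importable; the SimpleNamespace stand-in for _run and all hyperparameter values are illustrative assumptions, and depending on how RLLogger uses the _run object a fuller stub may be needed.

from types import SimpleNamespace

# Hypothetical stand-in for Sacred's _run object; train() itself only reads
# _run.config, but the logger may expect additional attributes or methods.
dummy_run = SimpleNamespace(config={'exp_name': 'demo'})

episode_rewards = train(
    _run=dummy_run,
    exp_name='demo',
    save_rate=1000,          # illustrative: save networks every 1000 episodes
    display=False,           # no rendering while training
    restore_fp=None,         # start from scratch
    hard_max=False,          # keep soft (softmax) actions
    max_episode_len=25,      # illustrative hyperparameters
    num_episodes=60000,
    batch_size=1024,
    update_rate=100,
    use_target_action=True,
)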
Example no. 2
def train(conf):
    env = make_env(conf.scenario)
    exp = SimpleExperiment(conf, 'tester', 12)
    logger = SimpleLogger('tester', exp, len(env.agents), env.n_adversaries,
                          conf.save_rate)

    agents = []
    for agent_idx in range(env.n_adversaries):
        agents.append(create_agent(conf.adv_policy, agent_idx, env, exp))
    for agent_idx in range(env.n_adversaries, env.n):
        agents.append(create_agent(conf.good_policy, agent_idx, env, exp))
    print(
        f'Using good policy {conf.good_policy} and adv policy {conf.adv_policy}'
    )

    # todo: Load previous results, if necessary

    obs_n = env.reset()

    print('Starting iterations...')
    while True:
        # get action
        if conf.use_target_action:  # use the action from the target policy network
            # adding an extra axis to the observation
            action_n = [
                agent.target_action(obs.astype(np.float32)[None])[0]
                for agent, obs in zip(agents, obs_n)
            ]
        else:
            action_n = [
                agent.action(obs.astype(np.float32))
                for agent, obs in zip(agents, obs_n)
            ]

        # environment step
        if conf.hard_max:
            action_n = softmax_to_argmax(action_n, agents)
        else:
            action_n = [action.numpy() for action in action_n]

        new_obs_n, rew_n, done_n, info_n = env.step(action_n)

        logger.episode_step += 1

        done = all(done_n)
        terminal = (logger.episode_step >= conf.max_episode_len)
        done = done or terminal

        # collect experience
        for i, agent in enumerate(agents):
            agent.add_transition(obs_n, action_n, rew_n[i], new_obs_n, done)
        obs_n = new_obs_n

        for ag_idx, rew in enumerate(rew_n):
            logger.cur_episode_reward += rew
            logger.agent_rewards[ag_idx][-1] += rew

        if done:
            obs_n = env.reset()
            episode_step = 0
            logger.record_episode_end(agents)

        logger.train_step += 1

        # policy updates
        train_cond = not conf.display
        for agent in agents:
            if train_cond and len(agent.replay_buffer
                                  ) > conf.batch_size * conf.max_episode_len:
                if logger.train_step % conf.update_rate == 0:  # only update every `update_rate` env steps
                    q_loss, pol_loss = agent.update(agents, logger.train_step)

        # for displaying learned policies
        if conf.display:
            time.sleep(0.1)
            env.render()

        # save logger outputs to file, similar to the original MADDPG implementation
        if len(logger.episode_rewards) > conf.num_episodes:
            logger.experiment_end()
            return logger.get_sacred_results()
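Example 2 pulls every setting from a single conf object. A minimal sketch of such a configuration is shown below, assuming a plain namespace is sufficient; the attribute names are exactly those the code above reads, while the scenario name, policy names and numeric values are illustrative placeholders.

from types import SimpleNamespace

# Hypothetical configuration; each attribute below is accessed by train(conf),
# the concrete values are placeholders.
conf = SimpleNamespace(
    scenario='simple_spread',   # assumed scenario name
    good_policy='maddpg',
    adv_policy='maddpg',
    use_target_action=True,
    hard_max=False,
    max_episode_len=25,
    num_episodes=60000,
    batch_size=1024,
    update_rate=100,
    save_rate=1000,
    display=False,
)

episode_rewards = train(conf)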