Example #1
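# NOTE: these examples are shown without their import headers. A plausible,
# assumed set of third-party imports for the snippets below would be:
#     import os
#     import time
#     import numpy as np
#     import torch
#     from pathlib import Path
#     from functools import reduce
#     from tensorboardX import SummaryWriter
# Project-local helpers such as get_config, Policy, PPO, RolloutStorage,
# make_parallel_env, update_linear_schedule, MultiDiscrete and the env classes
# come from the source repository and are not reproduced here.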
def main():
    args = get_config()

    # seed
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # cuda
    if args.cuda and torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.set_num_threads(args.n_training_threads)
        if args.cuda_deterministic:
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
    else:
        device = torch.device("cpu")
        torch.set_num_threads(args.n_training_threads)

    # path
    model_dir = Path('./results') / args.env_name / args.algorithm_name
    if not model_dir.exists():
        curr_run = 'run1'
    else:
        exst_run_nums = [
            int(str(folder.name).split('run')[1])
            for folder in model_dir.iterdir()
            if str(folder.name).startswith('run')
        ]
        if len(exst_run_nums) == 0:
            curr_run = 'run1'
        else:
            curr_run = 'run%i' % (max(exst_run_nums) + 1)

    run_dir = model_dir / curr_run
    log_dir = run_dir / 'logs'
    save_dir = run_dir / 'models'
    os.makedirs(str(log_dir))
    os.makedirs(str(save_dir))
    logger = SummaryWriter(str(log_dir))

    # env
    envs = make_parallel_env(args)
    # Policy network
    actor_critic = []
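    # With share_policy, every agent reuses the same Policy instance (the list
    # below ends up holding num_agents references to one network); otherwise
    # each agent gets its own independently initialized network.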
    if args.share_policy:
        ac = Policy(envs.observation_space[0],
                    envs.action_space[0],
                    num_agents=args.num_agents,
                    base_kwargs={
                        'lstm': args.lstm,
                        'naive_recurrent': args.naive_recurrent_policy,
                        'recurrent': args.recurrent_policy,
                        'hidden_size': args.hidden_size
                    })
        ac.to(device)
        for agent_id in range(args.num_agents):
            actor_critic.append(ac)
    else:
        for agent_id in range(args.num_agents):
            ac = Policy(envs.observation_space[0],
                        envs.action_space[0],
                        num_agents=args.num_agents,
                        base_kwargs={
                            'naive_recurrent': args.naive_recurrent_policy,
                            'recurrent': args.recurrent_policy,
                            'hidden_size': args.hidden_size
                        })
            ac.to(device)
            actor_critic.append(ac)

    agents = []
    rollouts = []
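    # One PPO trainer and one RolloutStorage per agent, even when the
    # underlying actor-critic network is shared.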
    for agent_id in range(args.num_agents):
        # algorithm
        agent = PPO(actor_critic[agent_id],
                    agent_id,
                    args.clip_param,
                    args.ppo_epoch,
                    args.num_mini_batch,
                    args.data_chunk_length,
                    args.value_loss_coef,
                    args.entropy_coef,
                    logger,
                    lr=args.lr,
                    eps=args.eps,
                    max_grad_norm=args.max_grad_norm,
                    use_clipped_value_loss=args.use_clipped_value_loss)
        # replay buffer
        ro = RolloutStorage(args.num_agents, agent_id, args.episode_length,
                            args.n_rollout_threads,
                            envs.observation_space[agent_id],
                            envs.action_space[agent_id],
                            actor_critic[agent_id].recurrent_hidden_state_size)
        agents.append(agent)
        rollouts.append(ro)

    # reset env
    obs = envs.reset()
    # rollout
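    # share_obs is the concatenation of all agents' observations per rollout
    # thread (used for the centralized value function); obs is agent i's own
    # observation.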
    for i in range(args.num_agents):
        rollouts[i].share_obs[0].copy_(
            torch.tensor(obs.reshape(args.n_rollout_threads, -1)))
        rollouts[i].obs[0].copy_(torch.tensor(obs[:, i, :]))
        rollouts[i].recurrent_hidden_states.zero_()
        rollouts[i].recurrent_hidden_states_critic.zero_()
        rollouts[i].recurrent_c_states.zero_()
        rollouts[i].recurrent_c_states_critic.zero_()
        rollouts[i].to(device)

    # run
    coop_num = []
    defect_num = []
    coopdefect_num = []
    defectcoop_num = []
    gore1_num = []
    gore2_num = []
    gore3_num = []
    hare1_num = []
    hare2_num = []
    hare3_num = []
    collective_return = []
    apple_consumption = []
    waste_cleared = []
    sustainability = []
    fire = []

    start = time.time()
    episodes = int(
        args.num_env_steps) // args.episode_length // args.n_rollout_threads
    all_episode = 0

    for episode in range(episodes):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            for i in range(args.num_agents):
                update_linear_schedule(agents[i].optimizer, episode, episodes,
                                       args.lr)

        for step in range(args.episode_length):
            # Sample actions
            values = []
            actions = []
            action_log_probs = []
            recurrent_hidden_statess = []
            recurrent_hidden_statess_critic = []
            recurrent_c_statess = []
            recurrent_c_statess_critic = []

            with torch.no_grad():
                for i in range(args.num_agents):
                    (value, action, action_log_prob,
                     recurrent_hidden_states, recurrent_hidden_states_critic,
                     recurrent_c_states, recurrent_c_states_critic) = actor_critic[i].act(
                         rollouts[i].share_obs[step],
                         rollouts[i].obs[step],
                         rollouts[i].recurrent_hidden_states[step],
                         rollouts[i].recurrent_hidden_states_critic[step],
                         rollouts[i].recurrent_c_states[step],
                         rollouts[i].recurrent_c_states_critic[step],
                         rollouts[i].masks[step])
                    values.append(value)
                    actions.append(action)
                    action_log_probs.append(action_log_prob)
                    recurrent_hidden_statess.append(recurrent_hidden_states)
                    recurrent_hidden_statess_critic.append(
                        recurrent_hidden_states_critic)
                    recurrent_c_statess.append(recurrent_c_states)
                    recurrent_c_statess_critic.append(
                        recurrent_c_states_critic)

            # rearrange action
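            # Convert each agent's sampled discrete action index into a one-hot
            # vector, grouped per rollout thread, in the layout the vectorized
            # env expects.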
            actions_env = []
            for i in range(args.n_rollout_threads):
                one_hot_action_env = []
                for k in range(args.num_agents):
                    one_hot_action = np.zeros(envs.action_space[0].n)
                    one_hot_action[actions[k][i]] = 1
                    one_hot_action_env.append(one_hot_action)
                actions_env.append(one_hot_action_env)

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(actions_env)

            # If done then clean the history of observations.
            # insert data in buffer
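            # masks[i] is 0.0 wherever agent i's env reported done (it resets the
            # recurrent state and stops bootstrapping across the reset); bad_masks
            # stays 1.0 here, so time-limit truncation is not treated specially.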
            masks = []
            bad_masks = []
            for i in range(args.num_agents):
                mask = []
                bad_mask = []
                for done_ in done:
                    if done_[i]:
                        mask.append([0.0])
                        bad_mask.append([1.0])
                    else:
                        mask.append([1.0])
                        bad_mask.append([1.0])
                masks.append(torch.FloatTensor(mask))
                bad_masks.append(torch.FloatTensor(bad_mask))

            for i in range(args.num_agents):
                rollouts[i].insert(
                    torch.tensor(obs.reshape(args.n_rollout_threads, -1)),
                    torch.tensor(obs[:, i, :]), recurrent_hidden_statess[i],
                    recurrent_hidden_statess_critic[i], recurrent_c_statess[i],
                    recurrent_c_statess_critic[i], actions[i],
                    action_log_probs[i], values[i],
                    torch.tensor(reward[:, i].reshape(-1, 1)),
                    masks[i], bad_masks[i])

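        # Bootstrap the value of the final state for each agent, then compute
        # returns/advantages (GAE if enabled) before running the PPO updates.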
        with torch.no_grad():
            next_values = []
            for i in range(args.num_agents):
                next_value = actor_critic[i].get_value(
                    rollouts[i].share_obs[-1], rollouts[i].obs[-1],
                    rollouts[i].recurrent_hidden_states[-1],
                    rollouts[i].recurrent_hidden_states_critic[-1],
                    rollouts[i].recurrent_c_states[-1],
                    rollouts[i].recurrent_c_states_critic[-1],
                    rollouts[i].masks[-1]).detach()
                next_values.append(next_value)

        for i in range(args.num_agents):
            rollouts[i].compute_returns(next_values[i], args.use_gae,
                                        args.gamma, args.gae_lambda,
                                        args.use_proper_time_limits)

        # update the network
        value_losses = []
        action_losses = []
        dist_entropies = []
        for i in range(args.num_agents):
            value_loss, action_loss, dist_entropy = agents[i].update(
                rollouts[i])
            value_losses.append(value_loss)
            action_losses.append(action_loss)
            dist_entropies.append(dist_entropy)

        if args.env_name == "StagHunt":
            for info in infos:
                if 'coop&coop_num' in info.keys():
                    coop_num.append(info['coop&coop_num'])
                if 'defect&defect_num' in info.keys():
                    defect_num.append(info['defect&defect_num'])
                if 'coop&defect_num' in info.keys():
                    coopdefect_num.append(info['coop&defect_num'])
                if 'defect&coop_num' in info.keys():
                    defectcoop_num.append(info['defect&coop_num'])

            for i in range(args.n_rollout_threads):
                logger.add_scalars(
                    'coop&coop_num_per_episode',
                    {'coop&coop_num_per_episode': coop_num[all_episode]},
                    all_episode)
                logger.add_scalars(
                    'defect&defect_num_per_episode',
                    {'defect&defect_num_per_episode': defect_num[all_episode]},
                    all_episode)
                logger.add_scalars('coop&defect_num_per_episode', {
                    'coop&defect_num_per_episode':
                    coopdefect_num[all_episode]
                }, all_episode)
                logger.add_scalars('defect&coop_num_per_episode', {
                    'defect&coop_num_per_episode':
                    defectcoop_num[all_episode]
                }, all_episode)
                all_episode += 1
        elif args.env_name == "StagHuntGW":
            for info in infos:
                if 'collective_return' in info.keys():
                    collective_return.append(info['collective_return'])
                if 'coop&coop_num' in info.keys():
                    coop_num.append(info['coop&coop_num'])
                if 'gore1_num' in info.keys():
                    gore1_num.append(info['gore1_num'])
                if 'gore2_num' in info.keys():
                    gore2_num.append(info['gore2_num'])
                if 'hare1_num' in info.keys():
                    hare1_num.append(info['hare1_num'])
                if 'hare2_num' in info.keys():
                    hare2_num.append(info['hare2_num'])
            for i in range(args.n_rollout_threads):
                logger.add_scalars(
                    'collective_return',
                    {'collective_return': collective_return[all_episode]},
                    all_episode)
                logger.add_scalars(
                    'coop&coop_num_per_episode',
                    {'coop&coop_num_per_episode': coop_num[all_episode]},
                    all_episode)
                logger.add_scalars(
                    'gore1_num_per_episode',
                    {'gore1_num_per_episode': gore1_num[all_episode]},
                    all_episode)
                logger.add_scalars(
                    'gore2_num_per_episode',
                    {'gore2_num_per_episode': gore2_num[all_episode]},
                    all_episode)
                logger.add_scalars(
                    'hare1_num_per_episode',
                    {'hare1_num_per_episode': hare1_num[all_episode]},
                    all_episode)
                logger.add_scalars(
                    'hare2_num_per_episode',
                    {'hare2_num_per_episode': hare2_num[all_episode]},
                    all_episode)
                all_episode += 1
        elif args.env_name == "EscalationGW":
            for info in infos:
                if 'collective_return' in info.keys():
                    collective_return.append(info['collective_return'])
                if 'coop&coop_num' in info.keys():
                    coop_num.append(info['coop&coop_num'])
            for i in range(args.n_rollout_threads):
                logger.add_scalars(
                    'collective_return',
                    {'collective_return': collective_return[all_episode]},
                    all_episode)
                logger.add_scalars(
                    'coop&coop_num_per_episode',
                    {'coop&coop_num_per_episode': coop_num[all_episode]},
                    all_episode)
                all_episode += 1
        elif args.env_name == "multi_StagHuntGW":
            for info in infos:
                if 'collective_return' in info.keys():
                    collective_return.append(info['collective_return'])
                if 'coop&coop_num' in info.keys():
                    coop_num.append(info['coop&coop_num'])
                if 'gore0_num' in info.keys():
                    gore1_num.append(info['gore0_num'])
                if 'gore1_num' in info.keys():
                    gore2_num.append(info['gore1_num'])
                if 'gore2_num' in info.keys():
                    gore3_num.append(info['gore2_num'])
                if 'hare0_num' in info.keys():
                    hare1_num.append(info['hare0_num'])
                if 'hare1_num' in info.keys():
                    hare2_num.append(info['hare1_num'])
                if 'hare2_num' in info.keys():
                    hare3_num.append(info['hare2_num'])
            for i in range(args.n_rollout_threads):
                logger.add_scalars(
                    'collective_return',
                    {'collective_return': collective_return[all_episode]},
                    all_episode)
                logger.add_scalars(
                    'coop&coop_num_per_episode',
                    {'coop&coop_num_per_episode': coop_num[all_episode]},
                    all_episode)
                logger.add_scalars(
                    'gore1_num_per_episode',
                    {'gore1_num_per_episode': gore1_num[all_episode]},
                    all_episode)
                logger.add_scalars(
                    'gore2_num_per_episode',
                    {'gore2_num_per_episode': gore2_num[all_episode]},
                    all_episode)
                logger.add_scalars(
                    'gore3_num_per_episode',
                    {'gore3_num_per_episode': gore3_num[all_episode]},
                    all_episode)
                logger.add_scalars(
                    'hare1_num_per_episode',
                    {'hare1_num_per_episode': hare1_num[all_episode]},
                    all_episode)
                logger.add_scalars(
                    'hare2_num_per_episode',
                    {'hare2_num_per_episode': hare2_num[all_episode]},
                    all_episode)
                logger.add_scalars(
                    'hare3_num_per_episode',
                    {'hare3_num_per_episode': hare3_num[all_episode]},
                    all_episode)
                all_episode += 1

        # clean the buffer and reset
        obs = envs.reset()
        for i in range(args.num_agents):
            rollouts[i].share_obs[0].copy_(
                torch.tensor(obs.reshape(args.n_rollout_threads, -1)))
            rollouts[i].obs[0].copy_(torch.tensor(obs[:, i, :]))
            rollouts[i].recurrent_hidden_states.zero_()
            rollouts[i].recurrent_hidden_states_critic.zero_()
            rollouts[i].recurrent_c_states.zero_()
            rollouts[i].recurrent_c_states_critic.zero_()
            rollouts[i].masks[0].copy_(torch.ones(args.n_rollout_threads, 1))
            rollouts[i].bad_masks[0].copy_(
                torch.ones(args.n_rollout_threads, 1))
            rollouts[i].to(device)

        for i in range(args.num_agents):
            # save for every interval-th episode or for the last episode
            if (episode % args.save_interval == 0 or episode == episodes - 1):
                torch.save({'model': actor_critic[i]},
                           str(save_dir) + "/agent%i_model" % i + ".pt")

        # log information
        if episode % args.log_interval == 0:
            total_num_steps = (
                episode + 1) * args.episode_length * args.n_rollout_threads
            end = time.time()
            print(
                "\n Updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\n"
                .format(episode, episodes, total_num_steps, args.num_env_steps,
                        int(total_num_steps / (end - start))))
            for i in range(args.num_agents):
                print("value loss of agent%i: " % i + str(value_losses[i]))
    logger.export_scalars_to_json(str(log_dir / 'summary.json'))
    logger.close()

    ###----------------------------------------------------------###
    ###----------------------------------------------------------###
    ###----------------------------------------------------------###
    if args.eval:
        eval_dir = run_dir / 'eval'
        log_dir = eval_dir / 'logs'
        os.makedirs(str(log_dir))
        logger = SummaryWriter(str(log_dir))

        # eval best policy
        eval_rewards = []
        # env
        if args.env_name == "StagHunt":
            assert args.num_agents == 2, (
                "only 2 agents are supported; check config.py.")
            env = MGEnv(args)
        elif args.env_name == "StagHuntGW" or args.env_name == "EscalationGW":
            assert args.num_agents == 2, (
                "only 2 agents are supported in single navigation; check config.py."
            )
            env = GridWorldEnv(args)
        elif args.env_name == "multi_StagHuntGW":
            env = multi_GridWorldEnv(args)
        else:
            print("Cannot support the " + args.env_name + " environment.")
            raise NotImplementedError

        # eval statistics buffers
        coop_num = []
        defect_num = []
        coopdefect_num = []
        defectcoop_num = []
        gore1_num = []
        gore2_num = []
        gore3_num = []
        hare1_num = []
        hare2_num = []
        hare3_num = []
        collective_return = []
        apple_consumption = []
        waste_cleared = []
        sustainability = []
        fire = []

        for episode in range(args.eval_episodes):
            print("Episode %i of %i" % (episode, args.eval_episodes))
            state = env.reset()
            state = np.array([state])

            share_obs = []
            obs = []
            recurrent_hidden_statess = []
            recurrent_hidden_statess_critic = []
            recurrent_c_statess = []
            recurrent_c_statess_critic = []
            masks = []
            policy_reward = 0

            # rollout
            for i in range(args.num_agents):
                share_obs.append(
                    (torch.tensor(state.reshape(1, -1),
                                  dtype=torch.float32)).to(device))
                obs.append((torch.tensor(state[:, i, :],
                                         dtype=torch.float32)).to(device))
                recurrent_hidden_statess.append(
                    torch.zeros(
                        1, actor_critic[i].recurrent_hidden_state_size).to(
                            device))
                recurrent_hidden_statess_critic.append(
                    torch.zeros(
                        1, actor_critic[i].recurrent_hidden_state_size).to(
                            device))
                recurrent_c_statess.append(
                    torch.zeros(
                        1, actor_critic[i].recurrent_hidden_state_size).to(
                            device))
                recurrent_c_statess_critic.append(
                    torch.zeros(
                        1, actor_critic[i].recurrent_hidden_state_size).to(
                            device))
                masks.append(torch.ones(1, 1).to(device))

            for step in range(args.episode_length):
                print("step %i of %i" % (step, args.episode_length))
                # Sample actions
                one_hot_actions = []
                for i in range(args.num_agents):
                    one_hot_action = np.zeros(env.action_space[0].n)
                    with torch.no_grad():
                        (value, action, action_log_prob,
                         recurrent_hidden_states, recurrent_hidden_states_critic,
                         recurrent_c_states, recurrent_c_states_critic) = actor_critic[i].act(
                            share_obs[i], obs[i],
                            recurrent_hidden_statess[i],
                            recurrent_hidden_statess_critic[i],
                            recurrent_c_statess[i],
                            recurrent_c_statess_critic[i], masks[i])
                    recurrent_hidden_statess[i].copy_(recurrent_hidden_states)
                    recurrent_hidden_statess_critic[i].copy_(
                        recurrent_hidden_states_critic)
                    recurrent_c_statess[i].copy_(recurrent_c_states)
                    recurrent_c_statess_critic[i].copy_(
                        recurrent_c_states_critic)
                    one_hot_action[action] = 1
                    one_hot_actions.append(one_hot_action)

                # Observe reward and next obs
                state, reward, done, infos = env.step(one_hot_actions)

                for i in range(args.num_agents):
                    print("Reward of agent%i: " % i + str(reward[i]))
                    policy_reward += reward[i]

                if all(done):
                    break

                state = np.array([state])

                for i in range(args.num_agents):
                    if len(env.observation_space[0]) == 1:
                        share_obs[i].copy_(
                            torch.tensor(state.reshape(1, -1),
                                         dtype=torch.float32))
                        obs[i].copy_(
                            torch.tensor(state[:, i, :], dtype=torch.float32))
                    elif len(env.observation_space[0]) == 3:
                        share_obs[i].copy_(
                            torch.tensor(state.reshape(
                                1, -1, env.observation_space[0][1],
                                env.observation_space[0][2]),
                                         dtype=torch.float32))
                        obs[i].copy_(
                            torch.tensor(state[:, i, :, :, :],
                                         dtype=torch.float32))

            eval_rewards.append(policy_reward)

            if args.env_name == "StagHunt":
                if 'coop&coop_num' in infos.keys():
                    coop_num.append(infos['coop&coop_num'])
                if 'defect&defect_num' in infos.keys():
                    defect_num.append(infos['defect&defect_num'])
                if 'coop&defect_num' in infos.keys():
                    coopdefect_num.append(infos['coop&defect_num'])
                if 'defect&coop_num' in infos.keys():
                    defectcoop_num.append(infos['defect&coop_num'])

                logger.add_scalars(
                    'coop&coop_num_per_episode',
                    {'coop&coop_num_per_episode': coop_num[episode]}, episode)
                logger.add_scalars(
                    'defect&defect_num_per_episode',
                    {'defect&defect_num_per_episode': defect_num[episode]},
                    episode)
                logger.add_scalars(
                    'coop&defect_num_per_episode',
                    {'coop&defect_num_per_episode': coopdefect_num[episode]},
                    episode)
                logger.add_scalars(
                    'defect&coop_num_per_episode',
                    {'defect&coop_num_per_episode': defectcoop_num[episode]},
                    episode)

            elif args.env_name == "StagHuntGW":
                if 'collective_return' in infos.keys():
                    collective_return.append(infos['collective_return'])
                    logger.add_scalars(
                        'collective_return',
                        {'collective_return': collective_return[episode]},
                        episode)
                if 'coop&coop_num' in infos.keys():
                    coop_num.append(infos['coop&coop_num'])
                    logger.add_scalars(
                        'coop&coop_num_per_episode',
                        {'coop&coop_num_per_episode': coop_num[episode]},
                        episode)
                if 'gore1_num' in infos.keys():
                    gore1_num.append(infos['gore1_num'])
                    logger.add_scalars(
                        'gore1_num_per_episode',
                        {'gore1_num_per_episode': gore1_num[episode]}, episode)
                if 'gore2_num' in infos.keys():
                    gore2_num.append(infos['gore2_num'])
                    logger.add_scalars(
                        'gore2_num_per_episode',
                        {'gore2_num_per_episode': gore2_num[episode]}, episode)
                if 'hare1_num' in infos.keys():
                    hare1_num.append(infos['hare1_num'])
                    logger.add_scalars(
                        'hare1_num_per_episode',
                        {'hare1_num_per_episode': hare1_num[episode]}, episode)
                if 'hare2_num' in infos.keys():
                    hare2_num.append(infos['hare2_num'])
                    logger.add_scalars(
                        'hare2_num_per_episode',
                        {'hare2_num_per_episode': hare2_num[episode]}, episode)
            elif args.env_name == "EscalationGW":
                if 'collective_return' in infos.keys():
                    collective_return.append(infos['collective_return'])
                    logger.add_scalars(
                        'collective_return',
                        {'collective_return': collective_return[episode]},
                        episode)
                if 'coop&coop_num' in infos.keys():
                    coop_num.append(infos['coop&coop_num'])
                    logger.add_scalars(
                        'coop&coop_num_per_episode',
                        {'coop&coop_num_per_episode': coop_num[episode]},
                        episode)
            elif args.env_name == "multi_StagHuntGW":
                if 'collective_return' in infos.keys():
                    collective_return.append(infos['collective_return'])
                    logger.add_scalars(
                        'collective_return',
                        {'collective_return': collective_return[episode]},
                        episode)
                if 'coop&coop_num' in infos.keys():
                    coop_num.append(infos['coop&coop_num'])
                    logger.add_scalars(
                        'coop&coop_num_per_episode',
                        {'coop&coop_num_per_episode': coop_num[episode]},
                        episode)
                if 'gore0_num' in infos.keys():
                    gore1_num.append(infos['gore0_num'])
                    logger.add_scalars(
                        'gore1_num_per_episode',
                        {'gore1_num_per_episode': gore1_num[episode]}, episode)
                if 'gore1_num' in infos.keys():
                    gore2_num.append(infos['gore1_num'])
                    logger.add_scalars(
                        'gore2_num_per_episode',
                        {'gore2_num_per_episode': gore2_num[episode]}, episode)
                if 'gore2_num' in infos.keys():
                    gore3_num.append(infos['gore2_num'])
                    logger.add_scalars(
                        'gore3_num_per_episode',
                        {'gore3_num_per_episode': gore3_num[episode]}, episode)
                if 'hare0_num' in infos.keys():
                    hare1_num.append(infos['hare0_num'])
                    logger.add_scalars(
                        'hare1_num_per_episode',
                        {'hare1_num_per_episode': hare1_num[episode]}, episode)
                if 'hare1_num' in infos.keys():
                    hare2_num.append(infos['hare1_num'])
                    logger.add_scalars(
                        'hare2_num_per_episode',
                        {'hare2_num_per_episode': hare2_num[episode]}, episode)
                if 'hare2_num' in infos.keys():
                    hare3_num.append(infos['hare2_num'])
                    logger.add_scalars(
                        'hare3_num_per_episode',
                        {'hare3_num_per_episode': hare3_num[episode]}, episode)
        logger.export_scalars_to_json(str(log_dir / 'summary.json'))
        logger.close()
Example #2
def main():
    args = get_config()

    # cuda
    if args.cuda and torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.set_num_threads(args.n_training_threads)
        if args.cuda_deterministic:
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
    else:
        device = torch.device("cpu")
        torch.set_num_threads(args.n_training_threads)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # path
    model_dir = Path('./results') / args.env_name / args.algorithm_name / (
        "run" + str(args.seed))
    if args.critic_full_obs:
        run_dir = model_dir / 'adaptive'
    else:
        run_dir = model_dir / 'adaptive_only'
    log_dir = run_dir / 'logs'
    save_dir = run_dir / 'models'
    os.makedirs(str(log_dir))
    os.makedirs(str(save_dir))
    logger = SummaryWriter(str(log_dir))

    print(
        "\n Now we have %i fixed policies! Training the single adaptive policy... \n"
        % args.num_policy_candidates)
    args.env_name = args.env_name + "Adaptive"
    policy_candidates = []
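    # Load the pretrained fixed-policy checkpoints onto the CPU; they are handed
    # to the parallel envs, which (presumably) use them as the pool of fixed
    # opponents for the adaptive agent trained below.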
    for i in range(args.num_policy_candidates):
        ac = torch.load(
            str(model_dir) + ("/models/Policy%i" % (i + 1)) +
            "-agent0_model.pt")['model'].cpu()
        policy_candidates.append(ac)

    # env
    envs = make_parallel_env(args, policy_candidates)

    # Policy network
    # agent 0
    actor_critic = Policy(envs.observation_space[0],
                          envs.action_space[0],
                          num_agents=args.num_agents,
                          base_kwargs={
                              'lstm': args.lstm,
                              'naive_recurrent': args.naive_recurrent_policy,
                              'recurrent': args.recurrent_policy,
                              'hidden_size': args.hidden_size
                          })

    actor_critic.to(device)
    agent0 = PPO(actor_critic,
                 0,
                 args.clip_param,
                 args.ppo_epoch,
                 args.num_mini_batch,
                 args.data_chunk_length,
                 args.value_loss_coef,
                 args.entropy_coef,
                 logger,
                 lr=args.lr,
                 eps=args.eps,
                 max_grad_norm=args.max_grad_norm,
                 use_clipped_value_loss=args.use_clipped_value_loss)
    # replay buffer
    rollout = RolloutStorage(args.num_agents, 0, args.episode_length,
                             args.n_rollout_threads, envs.observation_space[0],
                             envs.action_space[0],
                             actor_critic.recurrent_hidden_state_size)

    # reset
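    # envs.reset() also reports which fixed opponent was sampled for each rollout
    # thread; with critic_full_obs the critic receives a separate, fuller
    # observation (obs_critic) as its shared input.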
    if args.critic_full_obs:
        obs, obs_critic, select_opponent = envs.reset()
    else:
        obs, select_opponent = envs.reset()

    # rollout
    if len(envs.observation_space[0]) == 1:
        if args.critic_full_obs:
            rollout.share_obs[0].copy_(
                torch.tensor(obs_critic.reshape(args.n_rollout_threads, -1)))
        else:
            rollout.share_obs[0].copy_(
                torch.tensor(obs.reshape(args.n_rollout_threads, -1)))
        rollout.obs[0].copy_(torch.tensor(obs[:, 0, :]))
        rollout.recurrent_hidden_states.zero_()
        rollout.recurrent_hidden_states_critic.zero_()
        rollout.recurrent_c_states.zero_()
        rollout.recurrent_c_states_critic.zero_()
    else:
        raise NotImplementedError
    rollout.to(device)

    # run
    collective_return = []
    apple_consumption = []
    waste_cleared = []
    sustainability = []
    fire = []

    start = time.time()
    episodes = int(
        args.num_env_steps) // args.episode_length // args.n_rollout_threads
    all_episode = 0
    all_episode_adaptive = np.zeros(args.num_policy_candidates)
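    # all_episode_adaptive counts logged episodes separately per fixed opponent
    # and serves as the x-axis step for the per-opponent TensorBoard curves.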

    for episode in range(episodes):
        if args.use_linear_lr_decay:
            update_linear_schedule(agent0.optimizer, episode, episodes,
                                   args.lr)

        for step in range(args.episode_length):
            with torch.no_grad():
                value, action0, action_log_prob, recurrent_hidden_states, recurrent_hidden_states_critic, recurrent_c_states, recurrent_c_states_critic = actor_critic.act(
                    rollout.share_obs[step], rollout.obs[step],
                    rollout.recurrent_hidden_states[step],
                    rollout.recurrent_hidden_states_critic[step],
                    rollout.recurrent_c_states[step],
                    rollout.recurrent_c_states_critic[step],
                    rollout.masks[step])

            # rearrange action
            actions_env = []
            for i in range(args.n_rollout_threads):
                one_hot_action = np.zeros((1, envs.action_space[0].n))
                one_hot_action[0][action0[i]] = 1
                actions_env.append(one_hot_action)

            # Observe reward and next obs
            if args.critic_full_obs:
                obs, obs_critic, select_opponent, reward, done, infos = envs.step(
                    actions_env)
            else:
                obs, select_opponent, reward, done, infos = envs.step(
                    actions_env)

            # If done then clean the history of observations.
            # insert data in buffer
            masks = []
            bad_masks = []
            for i in range(args.num_agents):
                mask = []
                bad_mask = []
                for done_ in done:
                    if done_[i]:
                        mask.append([0.0])
                        bad_mask.append([1.0])
                    else:
                        mask.append([1.0])
                        bad_mask.append([1.0])
                masks.append(torch.FloatTensor(mask))
                bad_masks.append(torch.FloatTensor(bad_mask))

            if len(envs.observation_space[0]) == 1:
                if args.critic_full_obs:
                    rollout.insert(
                        torch.tensor(
                            obs_critic.reshape(args.n_rollout_threads, -1)),
                        torch.tensor(obs[:, 0, :]), recurrent_hidden_states,
                        recurrent_hidden_states_critic, recurrent_c_states,
                        recurrent_c_states_critic, action0,
                        action_log_prob, value,
                        torch.tensor(reward[:, 0].reshape(-1, 1)), masks[0],
                        bad_masks[0])
                else:
                    rollout.insert(
                        torch.tensor(obs.reshape(args.n_rollout_threads, -1)),
                        torch.tensor(obs[:, 0, :]), recurrent_hidden_states,
                        recurrent_hidden_states_critic, recurrent_c_states,
                        recurrent_c_states_critic, action0,
                        action_log_prob, value,
                        torch.tensor(reward[:, 0].reshape(-1, 1)), masks[0],
                        bad_masks[0])
            else:
                raise NotImplementedError

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollout.share_obs[-1], rollout.obs[-1],
                rollout.recurrent_hidden_states[-1],
                rollout.recurrent_hidden_states_critic[-1],
                rollout.recurrent_c_states[-1],
                rollout.recurrent_c_states_critic[-1],
                rollout.masks[-1]).detach()

        rollout.compute_returns(next_value, args.use_gae, args.gamma,
                                args.gae_lambda, args.use_proper_time_limits)

        # update the network
        value_loss, action_loss, dist_entropy = agent0.update(rollout)

        if args.env_name == "StagHuntAdaptive":
            coop_num = []
            defect_num = []
            coopdefect_num = []
            defectcoop_num = []
            rew = []
            for info in infos:
                if 'coop&coop_num' in info.keys():
                    coop_num.append(info['coop&coop_num'])
                if 'defect&defect_num' in info.keys():
                    defect_num.append(info['defect&defect_num'])
                if 'coop&defect_num' in info.keys():
                    coopdefect_num.append(info['coop&defect_num'])
                if 'defect&coop_num' in info.keys():
                    defectcoop_num.append(info['defect&coop_num'])
            for i in range(args.n_rollout_threads):
                rew.append(rollout.rewards[:, i, :].sum().cpu().numpy())

            for i in range(args.n_rollout_threads):
                logger.add_scalars(
                    'Policy-' + str(select_opponent[i] + 1) +
                    '/coop&coop_num_per_episode',
                    {'coop&coop_num_per_episode': coop_num[i]},
                    all_episode_adaptive[select_opponent[i]])
                logger.add_scalars(
                    'Policy-' + str(select_opponent[i] + 1) +
                    '/defect&defect_num_per_episode',
                    {'defect&defect_num_per_episode': defect_num[i]},
                    all_episode_adaptive[select_opponent[i]])
                logger.add_scalars(
                    'Policy-' + str(select_opponent[i] + 1) +
                    '/coop&defect_num_per_episode',
                    {'coop&defect_num_per_episode': coopdefect_num[i]},
                    all_episode_adaptive[select_opponent[i]])
                logger.add_scalars(
                    'Policy-' + str(select_opponent[i] + 1) +
                    '/defect&coop_num_per_episode',
                    {'defect&coop_num_per_episode': defectcoop_num[i]},
                    all_episode_adaptive[select_opponent[i]])
                logger.add_scalars(
                    'Policy-' + str(select_opponent[i] + 1) + '/reward',
                    {'reward': np.mean(np.array(rew[i]))},
                    all_episode_adaptive[select_opponent[i]])
                all_episode_adaptive[select_opponent[i]] += 1
        elif args.env_name == "StagHuntGWAdaptive":
            collective_return = []
            coop_num = []
            gore1_num = []
            gore2_num = []
            hare1_num = []
            hare2_num = []
            for info in infos:
                if 'collective_return' in info.keys():
                    collective_return.append(info['collective_return'])
                if 'coop&coop_num' in info.keys():
                    coop_num.append(info['coop&coop_num'])
                if 'gore1_num' in info.keys():
                    gore1_num.append(info['gore1_num'])
                if 'gore2_num' in info.keys():
                    gore2_num.append(info['gore2_num'])
                if 'hare1_num' in info.keys():
                    hare1_num.append(info['hare1_num'])
                if 'hare2_num' in info.keys():
                    hare2_num.append(info['hare2_num'])

            for i in range(args.n_rollout_threads):
                logger.add_scalars(
                    'Policy-' + str(select_opponent[i] + 1) +
                    '/collective_return',
                    {'collective_return': collective_return[i]},
                    all_episode_adaptive[select_opponent[i]])
                logger.add_scalars(
                    'Policy-' + str(select_opponent[i] + 1) +
                    '/coop&coop_num_per_episode',
                    {'coop&coop_num_per_episode': coop_num[i]},
                    all_episode_adaptive[select_opponent[i]])
                logger.add_scalars(
                    'Policy-' + str(select_opponent[i] + 1) +
                    '/gore1_num_per_episode',
                    {'gore1_num_per_episode': gore1_num[i]},
                    all_episode_adaptive[select_opponent[i]])
                logger.add_scalars(
                    'Policy-' + str(select_opponent[i] + 1) +
                    '/gore2_num_per_episode',
                    {'gore2_num_per_episode': gore2_num[i]},
                    all_episode_adaptive[select_opponent[i]])
                logger.add_scalars(
                    'Policy-' + str(select_opponent[i] + 1) +
                    '/hare1_num_per_episode',
                    {'hare1_num_per_episode': hare1_num[i]},
                    all_episode_adaptive[select_opponent[i]])
                logger.add_scalars(
                    'Policy-' + str(select_opponent[i] + 1) +
                    '/hare2_num_per_episode',
                    {'hare2_num_per_episode': hare2_num[i]},
                    all_episode_adaptive[select_opponent[i]])
                all_episode_adaptive[select_opponent[i]] += 1

        if args.critic_full_obs:
            obs, obs_critic, select_opponent = envs.reset()
        else:
            obs, select_opponent = envs.reset()

        if len(envs.observation_space[0]) == 1:
            if args.critic_full_obs:
                rollout.share_obs[0].copy_(
                    torch.tensor(obs_critic.reshape(args.n_rollout_threads,
                                                    -1)))
            else:
                rollout.share_obs[0].copy_(
                    torch.tensor(obs.reshape(args.n_rollout_threads, -1)))
            rollout.obs[0].copy_(torch.tensor(obs[:, 0, :]))
            rollout.recurrent_hidden_states.zero_()
            rollout.recurrent_hidden_states_critic.zero_()
            rollout.recurrent_c_states.zero_()
            rollout.recurrent_c_states_critic.zero_()
            rollout.masks[0].copy_(torch.ones(args.n_rollout_threads, 1))
            rollout.bad_masks[0].copy_(torch.ones(args.n_rollout_threads, 1))
        else:
            raise NotImplementedError
        rollout.to(device)

        if (episode % args.save_interval == 0 or episode == episodes - 1):
            torch.save({'model': actor_critic},
                       str(save_dir) + "/agent0_model.pt")

        # log information
        if episode % args.log_interval == 0:
            total_num_steps = (
                episode + 1) * args.episode_length * args.n_rollout_threads
            end = time.time()
            print(
                "\n Updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\n"
                .format(episode, episodes, total_num_steps, args.num_env_steps,
                        int(total_num_steps / (end - start))))
            print("value loss: agent0--" + str(value_loss))
    logger.export_scalars_to_json(str(log_dir / 'summary.json'))
    logger.close()
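Example #3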
def main():
    args = get_config()

    # seed
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # cuda
    if args.cuda and torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.set_num_threads(args.n_training_threads)
        if args.cuda_deterministic:
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
    else:
        device = torch.device("cpu")
        torch.set_num_threads(args.n_training_threads)

    # path
    model_dir = Path(
        './results') / args.env_name / args.scenario_name / args.algorithm_name
    if not model_dir.exists():
        curr_run = 'run1'
    else:
        exst_run_nums = [
            int(str(folder.name).split('run')[1])
            for folder in model_dir.iterdir()
            if str(folder.name).startswith('run')
        ]
        if len(exst_run_nums) == 0:
            curr_run = 'run1'
        else:
            curr_run = 'run%i' % (max(exst_run_nums) + 1)

    run_dir = model_dir / curr_run
    log_dir = run_dir / 'logs'
    save_dir = run_dir / 'models'
    os.makedirs(str(log_dir))
    os.makedirs(str(save_dir))
    logger = SummaryWriter(str(log_dir))

    # env
    envs = make_parallel_env(args)
    if args.eval:
        eval_env = make_eval_env(args)

    num_agents = args.num_agents
    all_action_space = []
    all_obs_space = []
    action_movement_dim = []
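    # order_obs lists the dict-observation keys that are concatenated (in this
    # order) into a flat vector; mask_order_obs names the matching visibility
    # mask for each key (None means the key is not masked).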
    if args.env_name == "BlueprintConstruction":
        order_obs = [
            'agent_qpos_qvel', 'box_obs', 'ramp_obs', 'construction_site_obs',
            'observation_self'
        ]
        mask_order_obs = [None, None, None, None, None]
    elif args.env_name == "BoxLocking":
        order_obs = [
            'agent_qpos_qvel', 'box_obs', 'ramp_obs', 'observation_self'
        ]
        mask_order_obs = ['mask_aa_obs', 'mask_ab_obs', 'mask_ar_obs', None]
    else:
        print("Cannot support the " + args.env_name + " environment.")
        raise NotImplementedError

    for agent_id in range(num_agents):
        # deal with dict action space
        action_movement = envs.action_space['action_movement'][agent_id].nvec
        action_movement_dim.append(len(action_movement))
        action_glueall = envs.action_space['action_glueall'][agent_id].n
        action_vec = np.append(action_movement, action_glueall)
        if 'action_pull' in envs.action_space.spaces.keys():
            action_pull = envs.action_space['action_pull'][agent_id].n
            action_vec = np.append(action_vec, action_pull)
        action_space = MultiDiscrete([[0, vec - 1] for vec in action_vec])
        all_action_space.append(action_space)
        # deal with dict obs space
        obs_space = []
        obs_dim = 0
        for key in order_obs:
            if key in envs.observation_space.spaces.keys():
                space = list(envs.observation_space[key].shape)
                if len(space) < 2:
                    space.insert(0, 1)
                obs_space.append(space)
                obs_dim += reduce(lambda x, y: x * y, space)
        obs_space.insert(0, obs_dim)
        all_obs_space.append(obs_space)

    if args.share_policy:
        actor_critic = Policy(all_obs_space[0],
                              all_action_space[0],
                              num_agents=num_agents,
                              gain=args.gain,
                              base_kwargs={
                                  'naive_recurrent':
                                  args.naive_recurrent_policy,
                                  'recurrent': args.recurrent_policy,
                                  'hidden_size': args.hidden_size,
                                  'recurrent_N': args.recurrent_N,
                                  'attn': args.attn,
                                  'attn_only_critic': args.attn_only_critic,
                                  'attn_size': args.attn_size,
                                  'attn_N': args.attn_N,
                                  'attn_heads': args.attn_heads,
                                  'dropout': args.dropout,
                                  'use_average_pool': args.use_average_pool,
                                  'use_common_layer': args.use_common_layer,
                                  'use_feature_normlization':
                                  args.use_feature_normlization,
                                  'use_feature_popart':
                                  args.use_feature_popart,
                                  'use_orthogonal': args.use_orthogonal,
                                  'layer_N': args.layer_N,
                                  'use_ReLU': args.use_ReLU,
                                  'use_same_dim': True
                              },
                              device=device)
        actor_critic.to(device)
        # algorithm
        agents = PPO(actor_critic,
                     args.clip_param,
                     args.ppo_epoch,
                     args.num_mini_batch,
                     args.data_chunk_length,
                     args.value_loss_coef,
                     args.entropy_coef,
                     logger,
                     lr=args.lr,
                     eps=args.eps,
                     weight_decay=args.weight_decay,
                     max_grad_norm=args.max_grad_norm,
                     use_max_grad_norm=args.use_max_grad_norm,
                     use_clipped_value_loss=args.use_clipped_value_loss,
                     use_common_layer=args.use_common_layer,
                     use_huber_loss=args.use_huber_loss,
                     huber_delta=args.huber_delta,
                     use_popart=args.use_popart,
                     use_value_high_masks=args.use_value_high_masks,
                     device=device)

        # replay buffer
        rollouts = RolloutStorage(num_agents,
                                  args.episode_length,
                                  args.n_rollout_threads,
                                  all_obs_space[0],
                                  all_action_space[0],
                                  args.hidden_size,
                                  use_same_dim=True)
    else:
        actor_critic = []
        agents = []
        for agent_id in range(num_agents):
            ac = Policy(all_obs_space[0],
                        all_action_space[0],
                        num_agents=num_agents,
                        gain=args.gain,
                        base_kwargs={
                            'naive_recurrent': args.naive_recurrent_policy,
                            'recurrent': args.recurrent_policy,
                            'hidden_size': args.hidden_size,
                            'recurrent_N': args.recurrent_N,
                            'attn': args.attn,
                            'attn_only_critic': args.attn_only_critic,
                            'attn_size': args.attn_size,
                            'attn_N': args.attn_N,
                            'attn_heads': args.attn_heads,
                            'dropout': args.dropout,
                            'use_average_pool': args.use_average_pool,
                            'use_common_layer': args.use_common_layer,
                            'use_feature_normlization':
                            args.use_feature_normlization,
                            'use_feature_popart': args.use_feature_popart,
                            'use_orthogonal': args.use_orthogonal,
                            'layer_N': args.layer_N,
                            'use_ReLU': args.use_ReLU,
                            'use_same_dim': True
                        },
                        device=device)
            ac.to(device)
            # algorithm
            agent = PPO(ac,
                        args.clip_param,
                        args.ppo_epoch,
                        args.num_mini_batch,
                        args.data_chunk_length,
                        args.value_loss_coef,
                        args.entropy_coef,
                        logger,
                        lr=args.lr,
                        eps=args.eps,
                        weight_decay=args.weight_decay,
                        max_grad_norm=args.max_grad_norm,
                        use_max_grad_norm=args.use_max_grad_norm,
                        use_clipped_value_loss=args.use_clipped_value_loss,
                        use_common_layer=args.use_common_layer,
                        use_huber_loss=args.use_huber_loss,
                        huber_delta=args.huber_delta,
                        use_popart=args.use_popart,
                        use_value_high_masks=args.use_value_high_masks,
                        device=device)

            actor_critic.append(ac)
            agents.append(agent)

        # replay buffer
        rollouts = RolloutStorage(num_agents,
                                  args.episode_length,
                                  args.n_rollout_threads,
                                  all_obs_space[0],
                                  all_action_space[0],
                                  args.hidden_size,
                                  use_same_dim=True)

    # reset env
    dict_obs = envs.reset()
    obs = []
    share_obs = []
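    # Flatten the dict observation per agent: for masked keys, entries the agent
    # cannot see are zeroed out in its own obs, while share_obs keeps the
    # unmasked values.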
    for d_o in dict_obs:
        for i, key in enumerate(order_obs):
            if key in envs.observation_space.spaces.keys():
                if mask_order_obs[i] is None:
                    temp_share_obs = d_o[key].reshape(num_agents, -1).copy()
                    temp_obs = temp_share_obs.copy()
                else:
                    temp_share_obs = d_o[key].reshape(num_agents, -1).copy()
                    temp_mask = d_o[mask_order_obs[i]].copy()
                    temp_obs = d_o[key].copy()
                    mins_temp_mask = ~temp_mask
                    temp_obs[mins_temp_mask] = np.zeros(
                        (mins_temp_mask.sum(), temp_obs.shape[2]))
                    temp_obs = temp_obs.reshape(num_agents, -1)
                if i == 0:
                    reshape_obs = temp_obs.copy()
                    reshape_share_obs = temp_share_obs.copy()
                else:
                    reshape_obs = np.concatenate((reshape_obs, temp_obs),
                                                 axis=1)
                    reshape_share_obs = np.concatenate(
                        (reshape_share_obs, temp_share_obs), axis=1)
        obs.append(reshape_obs)
        share_obs.append(reshape_share_obs)
    obs = np.array(obs)
    share_obs = np.array(share_obs)

    # replay buffer
    rollouts.share_obs[0] = share_obs.copy()
    rollouts.obs[0] = obs.copy()
    rollouts.recurrent_hidden_states = np.zeros(
        rollouts.recurrent_hidden_states.shape).astype(np.float32)
    rollouts.recurrent_hidden_states_critic = np.zeros(
        rollouts.recurrent_hidden_states_critic.shape).astype(np.float32)

    # run
    start = time.time()
    episodes = int(
        args.num_env_steps) // args.episode_length // args.n_rollout_threads
    timesteps = 0

    for episode in range(episodes):
        if args.use_linear_lr_decay:  # decrease learning rate linearly
            if args.share_policy:
                update_linear_schedule(agents.optimizer, episode, episodes,
                                       args.lr)
            else:
                for agent_id in range(num_agents):
                    update_linear_schedule(agents[agent_id].optimizer, episode,
                                           episodes, args.lr)
        # info list
        discard_episode = 0
        success = 0
        trials = 0

        for step in range(args.episode_length):
            # Sample actions
            values = []
            actions = []
            action_log_probs = []
            recurrent_hidden_statess = []
            recurrent_hidden_statess_critic = []
            with torch.no_grad():
                for agent_id in range(num_agents):
                    if args.share_policy:
                        actor_critic.eval()
                        value, action, action_log_prob, recurrent_hidden_states, recurrent_hidden_states_critic = actor_critic.act(
                            agent_id,
                            torch.tensor(rollouts.share_obs[step, :,
                                                            agent_id]),
                            torch.tensor(rollouts.obs[step, :, agent_id]),
                            torch.tensor(
                                rollouts.recurrent_hidden_states[step, :,
                                                                 agent_id]),
                            torch.tensor(
                                rollouts.recurrent_hidden_states_critic[
                                    step, :, agent_id]),
                            torch.tensor(rollouts.masks[step, :, agent_id]))
                    else:
                        actor_critic[agent_id].eval()
                        value, action, action_log_prob, recurrent_hidden_states, recurrent_hidden_states_critic = actor_critic[
                            agent_id].act(
                                agent_id,
                                torch.tensor(rollouts.share_obs[step, :,
                                                                agent_id]),
                                torch.tensor(rollouts.obs[step, :, agent_id]),
                                torch.tensor(rollouts.recurrent_hidden_states[
                                    step, :, agent_id]),
                                torch.tensor(
                                    rollouts.recurrent_hidden_states_critic[
                                        step, :, agent_id]),
                                torch.tensor(rollouts.masks[step, :,
                                                            agent_id]))

                    values.append(value.detach().cpu().numpy())
                    actions.append(action.detach().cpu().numpy())
                    action_log_probs.append(
                        action_log_prob.detach().cpu().numpy())
                    recurrent_hidden_statess.append(
                        recurrent_hidden_states.detach().cpu().numpy())
                    recurrent_hidden_statess_critic.append(
                        recurrent_hidden_states_critic.detach().cpu().numpy())

            # rearrange action
            actions_env = []
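            # Split each agent's flat action vector back into the dict action
            # the environment expects: movement components, a glue-all flag,
            # and (when the env defines it) a pull flag.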
            for n_rollout_thread in range(args.n_rollout_threads):
                action_movement = []
                action_pull = []
                action_glueall = []
                for agent_id in range(num_agents):
                    action_movement.append(actions[agent_id][n_rollout_thread]
                                           [:action_movement_dim[agent_id]])
                    action_glueall.append(
                        int(actions[agent_id][n_rollout_thread][
                            action_movement_dim[agent_id]]))
                    if 'action_pull' in envs.action_space.spaces.keys():
                        action_pull.append(
                            int(actions[agent_id][n_rollout_thread][-1]))
                action_movement = np.stack(action_movement, axis=0)
                action_glueall = np.stack(action_glueall, axis=0)
                if 'action_pull' in envs.action_space.spaces.keys():
                    action_pull = np.stack(action_pull, axis=0)
                one_env_action = {
                    'action_movement': action_movement,
                    'action_pull': action_pull,
                    'action_glueall': action_glueall
                }
                actions_env.append(one_env_action)

            # Observe reward and next obs
            dict_obs, rewards, dones, infos = envs.step(actions_env)
            if len(rewards.shape) < 3:
                rewards = rewards[:, :, np.newaxis]

            # If done then clean the history of observations.
            # insert data in buffer
            masks = []
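            # For environments that just finished, count trials/successes
            # (skipping discarded episodes), zero the agents' recurrent states,
            # and build masks (0.0 = done, 1.0 = still running) for the buffer.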
            for i, done in enumerate(dones):
                if done:
                    if "discard_episode" in infos[i].keys():
                        if infos[i]['discard_episode']:
                            discard_episode += 1
                        else:
                            trials += 1
                    else:
                        trials += 1
                    if "success" in infos[i].keys():
                        if infos[i]['success']:
                            success += 1
                mask = []
                for agent_id in range(num_agents):
                    if done:
                        recurrent_hidden_statess[agent_id][i] = np.zeros(
                            args.hidden_size).astype(np.float32)
                        recurrent_hidden_statess_critic[agent_id][
                            i] = np.zeros(args.hidden_size).astype(np.float32)
                        mask.append([0.0])
                    else:
                        mask.append([1.0])
                masks.append(mask)

            obs = []
            share_obs = []
            for d_o in dict_obs:
                for i, key in enumerate(order_obs):
                    if key in envs.observation_space.spaces.keys():
                        if mask_order_obs[i] is None:
                            temp_share_obs = d_o[key].reshape(num_agents,
                                                              -1).copy()
                            temp_obs = temp_share_obs.copy()
                        else:
                            temp_share_obs = d_o[key].reshape(num_agents,
                                                              -1).copy()
                            temp_mask = d_o[mask_order_obs[i]].copy()
                            temp_obs = d_o[key].copy()
                            mins_temp_mask = ~temp_mask
                            temp_obs[mins_temp_mask] = np.zeros(
                                (mins_temp_mask.sum(), temp_obs.shape[2]))
                            temp_obs = temp_obs.reshape(num_agents, -1)
                        if i == 0:
                            reshape_obs = temp_obs.copy()
                            reshape_share_obs = temp_share_obs.copy()
                        else:
                            reshape_obs = np.concatenate(
                                (reshape_obs, temp_obs), axis=1)
                            reshape_share_obs = np.concatenate(
                                (reshape_share_obs, temp_share_obs), axis=1)
                obs.append(reshape_obs)
                share_obs.append(reshape_share_obs)
            obs = np.array(obs)
            share_obs = np.array(share_obs)

            rollouts.insert(
                share_obs, obs,
                np.array(recurrent_hidden_statess).transpose(1, 0, 2),
                np.array(recurrent_hidden_statess_critic).transpose(1, 0, 2),
                np.array(actions).transpose(1, 0, 2),
                np.array(action_log_probs).transpose(1, 0, 2),
                np.array(values).transpose(1, 0, 2), rewards, masks)

        with torch.no_grad():
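            # Bootstrap the value of the last stored state for each agent and
            # compute returns/advantages (GAE if enabled, optionally with the
            # PopArt value normalizer).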
            for agent_id in range(num_agents):
                if args.share_policy:
                    actor_critic.eval()
                    next_value, _, _ = actor_critic.get_value(
                        agent_id,
                        torch.tensor(rollouts.share_obs[-1, :, agent_id]),
                        torch.tensor(rollouts.obs[-1, :, agent_id]),
                        torch.tensor(
                            rollouts.recurrent_hidden_states[-1, :, agent_id]),
                        torch.tensor(
                            rollouts.recurrent_hidden_states_critic[-1, :,
                                                                    agent_id]),
                        torch.tensor(rollouts.masks[-1, :, agent_id]))
                    next_value = next_value.detach().cpu().numpy()
                    rollouts.compute_returns(agent_id, next_value,
                                             args.use_gae, args.gamma,
                                             args.gae_lambda,
                                             args.use_proper_time_limits,
                                             args.use_popart,
                                             agents.value_normalizer)
                else:
                    actor_critic[agent_id].eval()
                    next_value, _, _ = actor_critic[agent_id].get_value(
                        agent_id,
                        torch.tensor(rollouts.share_obs[-1, :, agent_id]),
                        torch.tensor(rollouts.obs[-1, :, agent_id]),
                        torch.tensor(
                            rollouts.recurrent_hidden_states[-1, :, agent_id]),
                        torch.tensor(
                            rollouts.recurrent_hidden_states_critic[-1, :,
                                                                    agent_id]),
                        torch.tensor(rollouts.masks[-1, :, agent_id]))
                    next_value = next_value.detach().cpu().numpy()
                    rollouts.compute_returns(agent_id, next_value,
                                             args.use_gae, args.gamma,
                                             args.gae_lambda,
                                             args.use_proper_time_limits,
                                             args.use_popart,
                                             agents[agent_id].value_normalizer)

        # update the network
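        # With a shared policy one update is run over all agents' data
        # (update_share); otherwise each agent's PPO instance is updated on
        # its own slice of the rollout buffer.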
        if args.share_policy:
            actor_critic.train()
            value_loss, action_loss, dist_entropy = agents.update_share(
                num_agents, rollouts)

            logger.add_scalars('reward', {'reward': np.mean(rollouts.rewards)},
                               (episode + 1) * args.episode_length *
                               args.n_rollout_threads)
        else:
            value_losses = []
            action_losses = []
            dist_entropies = []

            for agent_id in range(num_agents):
                actor_critic[agent_id].train()
                value_loss, action_loss, dist_entropy = agents[
                    agent_id].update(agent_id, rollouts)
                value_losses.append(value_loss)
                action_losses.append(action_loss)
                dist_entropies.append(dist_entropy)

                logger.add_scalars(
                    'agent%i/reward' % agent_id,
                    {'reward': np.mean(rollouts.rewards[:, :, agent_id])},
                    (episode + 1) * args.episode_length *
                    args.n_rollout_threads)

        # clean the buffer and reset
        rollouts.after_update()

        total_num_steps = (episode +
                           1) * args.episode_length * args.n_rollout_threads

        if (episode % args.save_interval == 0 or episode == episodes -
                1):  # save for every interval-th episode or for the last epoch
            if args.share_policy:
                torch.save({'model': actor_critic},
                           str(save_dir) + "/agent_model.pt")
            else:
                for agent_id in range(num_agents):
                    torch.save({'model': actor_critic[agent_id]},
                               str(save_dir) + "/agent%i_model" % agent_id +
                               ".pt")

        # log information
        if episode % args.log_interval == 0:
            end = time.time()
            print(
                "\n Scenario {} Algo {} updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\n"
                .format(args.scenario_name, args.algorithm_name, episode,
                        episodes, total_num_steps, args.num_env_steps,
                        int(total_num_steps / (end - start))))
            if args.share_policy:
                print("value loss of agent: " + str(value_loss))
            else:
                for agent_id in range(num_agents):
                    print("value loss of agent%i: " % agent_id +
                          str(value_losses[agent_id]))

            logger.add_scalars('discard_episode',
                               {'discard_episode': discard_episode},
                               total_num_steps)
            if trials > 0:
                logger.add_scalars('success_rate',
                                   {'success_rate': success / trials},
                                   total_num_steps)
            else:
                logger.add_scalars('success_rate', {'success_rate': 0.0},
                                   total_num_steps)
        # eval
        if episode % args.eval_interval == 0 and args.eval:
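            # Periodic evaluation: act deterministically in a single eval
            # environment until args.eval_episodes episodes finish, then log
            # the resulting success rate.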
            eval_episode = 0
            eval_success = 0
            eval_dict_obs = eval_env.reset()

            eval_obs = []
            eval_share_obs = []
            for eval_d_o in eval_dict_obs:
                for i, key in enumerate(order_obs):
                    if key in eval_env.observation_space.spaces.keys():
                        if mask_order_obs[i] is None:
                            temp_share_obs = eval_d_o[key].reshape(
                                num_agents, -1).copy()
                            temp_obs = temp_share_obs.copy()
                        else:
                            temp_share_obs = eval_d_o[key].reshape(
                                num_agents, -1).copy()
                            temp_mask = eval_d_o[mask_order_obs[i]].copy()
                            temp_obs = eval_d_o[key].copy()
                            mins_temp_mask = ~temp_mask
                            temp_obs[mins_temp_mask] = np.zeros(
                                (mins_temp_mask.sum(), temp_obs.shape[2]))
                            temp_obs = temp_obs.reshape(num_agents, -1)
                        if i == 0:
                            reshape_obs = temp_obs.copy()
                            reshape_share_obs = temp_share_obs.copy()
                        else:
                            reshape_obs = np.concatenate(
                                (reshape_obs, temp_obs), axis=1)
                            reshape_share_obs = np.concatenate(
                                (reshape_share_obs, temp_share_obs), axis=1)
                eval_obs.append(reshape_obs)
                eval_share_obs.append(reshape_share_obs)
            eval_obs = np.array(eval_obs)
            eval_share_obs = np.array(eval_share_obs)

            eval_recurrent_hidden_states = np.zeros(
                (1, num_agents, args.hidden_size)).astype(np.float32)
            eval_recurrent_hidden_states_critic = np.zeros(
                (1, num_agents, args.hidden_size)).astype(np.float32)
            eval_masks = np.ones((1, num_agents, 1)).astype(np.float32)

            while True:
                eval_actions = []
                actor_critic.eval()
                for agent_id in range(num_agents):
                    _, action, _, recurrent_hidden_states, recurrent_hidden_states_critic = actor_critic.act(
                        agent_id,
                        torch.FloatTensor(eval_share_obs[:, agent_id]),
                        torch.FloatTensor(eval_obs[:, agent_id]),
                        torch.FloatTensor(
                            eval_recurrent_hidden_states[:, agent_id]),
                        torch.FloatTensor(
                            eval_recurrent_hidden_states_critic[:, agent_id]),
                        torch.FloatTensor(eval_masks[:, agent_id]),
                        None,
                        deterministic=True)

                    eval_actions.append(action.detach().cpu().numpy())
                    eval_recurrent_hidden_states[:,
                                                 agent_id] = recurrent_hidden_states.detach(
                                                 ).cpu().numpy()
                    eval_recurrent_hidden_states_critic[:,
                                                        agent_id] = recurrent_hidden_states_critic.detach(
                                                        ).cpu().numpy()

                # rearrange action
                eval_actions_env = []
                for n_rollout_thread in range(1):
                    action_movement = []
                    action_pull = []
                    action_glueall = []
                    for agent_id in range(num_agents):
                        action_movement.append(
                            eval_actions[agent_id][n_rollout_thread]
                            [:action_movement_dim[agent_id]])
                        action_glueall.append(
                            int(eval_actions[agent_id][n_rollout_thread][
                                action_movement_dim[agent_id]]))
                        if 'action_pull' in eval_env.action_space.spaces.keys():
                            action_pull.append(
                                int(eval_actions[agent_id][n_rollout_thread]
                                    [-1]))
                    action_movement = np.stack(action_movement, axis=0)
                    action_glueall = np.stack(action_glueall, axis=0)
                    if 'action_pull' in eval_env.action_space.spaces.keys():
                        action_pull = np.stack(action_pull, axis=0)
                    one_env_action = {
                        'action_movement': action_movement,
                        'action_pull': action_pull,
                        'action_glueall': action_glueall
                    }
                    eval_actions_env.append(one_env_action)

                # Observe reward and next obs
                eval_dict_obs, eval_rewards, eval_dones, eval_infos = eval_env.step(
                    eval_actions_env)

                eval_obs = []
                eval_share_obs = []
                for eval_d_o in eval_dict_obs:
                    for i, key in enumerate(order_obs):
                        if key in eval_env.observation_space.spaces.keys():
                            if mask_order_obs[i] is None:
                                temp_share_obs = eval_d_o[key].reshape(
                                    num_agents, -1).copy()
                                temp_obs = temp_share_obs.copy()
                            else:
                                temp_share_obs = eval_d_o[key].reshape(
                                    num_agents, -1).copy()
                                temp_mask = eval_d_o[mask_order_obs[i]].copy()
                                temp_obs = eval_d_o[key].copy()
                                mins_temp_mask = ~temp_mask
                                temp_obs[mins_temp_mask] = np.zeros(
                                    (mins_temp_mask.sum(), temp_obs.shape[2]))
                                temp_obs = temp_obs.reshape(num_agents, -1)
                            if i == 0:
                                reshape_obs = temp_obs.copy()
                                reshape_share_obs = temp_share_obs.copy()
                            else:
                                reshape_obs = np.concatenate(
                                    (reshape_obs, temp_obs), axis=1)
                                reshape_share_obs = np.concatenate(
                                    (reshape_share_obs, temp_share_obs),
                                    axis=1)
                    eval_obs.append(reshape_obs)
                    eval_share_obs.append(reshape_share_obs)
                eval_obs = np.array(eval_obs)
                eval_share_obs = np.array(eval_share_obs)

                eval_recurrent_hidden_states = np.zeros(
                    (1, num_agents, args.hidden_size)).astype(np.float32)
                eval_recurrent_hidden_states_critic = np.zeros(
                    (1, num_agents, args.hidden_size)).astype(np.float32)
                eval_masks = np.ones((1, num_agents, 1)).astype(np.float32)

                if eval_dones[0]:
                    eval_episode += 1
                    if "success" in eval_infos[0].keys():
                        if eval_infos[0]['success']:
                            eval_success += 1
                    for agent_id in range(num_agents):
                        eval_recurrent_hidden_states[0][agent_id] = np.zeros(
                            args.hidden_size).astype(np.float32)
                        eval_recurrent_hidden_states_critic[0][
                            agent_id] = np.zeros(args.hidden_size).astype(
                                np.float32)
                        eval_masks[0][agent_id] = 0.0
                else:
                    for agent_id in range(num_agents):
                        eval_masks[0][agent_id] = 1.0

                if eval_episode >= args.eval_episodes:
                    logger.add_scalars('eval_success_rate', {
                        'eval_success_rate':
                        eval_success / args.eval_episodes
                    }, total_num_steps)
                    print("eval_success_rate is " +
                          str(eval_success / args.eval_episodes))
                    break

    logger.export_scalars_to_json(str(log_dir / 'summary.json'))
    logger.close()
    envs.close()
    if args.eval:
        eval_env.close()
Example #4
0
def main():
    args = get_config()

    assert not (args.share_policy
                and args.scenario_name == 'simple_speaker_listener'), (
        "The simple_speaker_listener scenario cannot use a shared policy. "
        "Please check config.py.")

    # seed
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    # cuda
    if args.cuda and torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.set_num_threads(args.n_training_threads)
        if args.cuda_deterministic:
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
    else:
        device = torch.device("cpu")
        torch.set_num_threads(args.n_training_threads)

    # path
    model_dir = Path(
        './results') / args.env_name / args.scenario_name / args.algorithm_name
    if not model_dir.exists():
        curr_run = 'run1'
    else:
        exst_run_nums = [
            int(str(folder.name).split('run')[1])
            for folder in model_dir.iterdir()
            if str(folder.name).startswith('run')
        ]
        if len(exst_run_nums) == 0:
            curr_run = 'run1'
        else:
            curr_run = 'run%i' % (max(exst_run_nums) + 1)

    run_dir = model_dir / curr_run
    log_dir = run_dir / 'logs'
    save_dir = run_dir / 'models'
    os.makedirs(str(log_dir))
    os.makedirs(str(save_dir))
    logger = SummaryWriter(str(log_dir))

    # env
    envs = make_parallel_env(args)
    num_agents = args.num_agents

    #Policy network
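    # Shared policy: a single Policy/PPO pair and one joint RolloutStorage
    # serve all agents; otherwise each agent gets its own network, trainer,
    # and SingleRolloutStorage.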
    if args.share_policy:
        actor_critic = Policy(envs.observation_space[0],
                              envs.action_space[0],
                              num_agents=num_agents,
                              gain=args.gain,
                              base_kwargs={
                                  'naive_recurrent':
                                  args.naive_recurrent_policy,
                                  'recurrent': args.recurrent_policy,
                                  'hidden_size': args.hidden_size,
                                  'recurrent_N': args.recurrent_N,
                                  'attn': args.attn,
                                  'attn_only_critic': args.attn_only_critic,
                                  'attn_size': args.attn_size,
                                  'attn_N': args.attn_N,
                                  'attn_heads': args.attn_heads,
                                  'dropout': args.dropout,
                                  'use_average_pool': args.use_average_pool,
                                  'use_common_layer': args.use_common_layer,
                                  'use_feature_normlization':
                                  args.use_feature_normlization,
                                  'use_feature_popart':
                                  args.use_feature_popart,
                                  'use_orthogonal': args.use_orthogonal,
                                  'layer_N': args.layer_N,
                                  'use_ReLU': args.use_ReLU,
                                  'use_same_dim': args.use_same_dim
                              },
                              device=device)
        actor_critic.to(device)
        # algorithm
        agents = PPO(actor_critic,
                     args.clip_param,
                     args.ppo_epoch,
                     args.num_mini_batch,
                     args.data_chunk_length,
                     args.value_loss_coef,
                     args.entropy_coef,
                     logger,
                     lr=args.lr,
                     eps=args.eps,
                     weight_decay=args.weight_decay,
                     max_grad_norm=args.max_grad_norm,
                     use_max_grad_norm=args.use_max_grad_norm,
                     use_clipped_value_loss=args.use_clipped_value_loss,
                     use_common_layer=args.use_common_layer,
                     use_huber_loss=args.use_huber_loss,
                     huber_delta=args.huber_delta,
                     use_popart=args.use_popart,
                     use_value_high_masks=args.use_value_high_masks,
                     device=device)

        #replay buffer
        rollouts = RolloutStorage(num_agents, args.episode_length,
                                  args.n_rollout_threads,
                                  envs.observation_space[0],
                                  envs.action_space[0], args.hidden_size)
    else:
        actor_critic = []
        agents = []
        rollouts = []
        for agent_id in range(num_agents):
            ac = Policy(
                envs.observation_space,
                envs.action_space[agent_id],
                num_agents=agent_id,  # here is special
                gain=args.gain,
                base_kwargs={
                    'naive_recurrent': args.naive_recurrent_policy,
                    'recurrent': args.recurrent_policy,
                    'hidden_size': args.hidden_size,
                    'recurrent_N': args.recurrent_N,
                    'attn': args.attn,
                    'attn_only_critic': args.attn_only_critic,
                    'attn_size': args.attn_size,
                    'attn_N': args.attn_N,
                    'attn_heads': args.attn_heads,
                    'dropout': args.dropout,
                    'use_average_pool': args.use_average_pool,
                    'use_common_layer': args.use_common_layer,
                    'use_feature_normlization': args.use_feature_normlization,
                    'use_feature_popart': args.use_feature_popart,
                    'use_orthogonal': args.use_orthogonal,
                    'layer_N': args.layer_N,
                    'use_ReLU': args.use_ReLU,
                    'use_same_dim': args.use_same_dim
                },
                device=device)
            ac.to(device)
            # algorithm
            agent = PPO(ac,
                        args.clip_param,
                        args.ppo_epoch,
                        args.num_mini_batch,
                        args.data_chunk_length,
                        args.value_loss_coef,
                        args.entropy_coef,
                        logger,
                        lr=args.lr,
                        eps=args.eps,
                        weight_decay=args.weight_decay,
                        max_grad_norm=args.max_grad_norm,
                        use_max_grad_norm=args.use_max_grad_norm,
                        use_clipped_value_loss=args.use_clipped_value_loss,
                        use_common_layer=args.use_common_layer,
                        use_huber_loss=args.use_huber_loss,
                        huber_delta=args.huber_delta,
                        use_popart=args.use_popart,
                        use_value_high_masks=args.use_value_high_masks,
                        device=device)

            actor_critic.append(ac)
            agents.append(agent)

            #replay buffer
            ro = SingleRolloutStorage(agent_id, args.episode_length,
                                      args.n_rollout_threads,
                                      envs.observation_space,
                                      envs.action_space, args.hidden_size)
            rollouts.append(ro)

    # reset env
    obs, _ = envs.reset()

    # replay buffer
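    # Seed step 0 of the buffers: the centralized observation is the
    # concatenation of all agents' observations (repeated along the agent
    # axis for the shared-policy storage), and recurrent states start at zero.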
    if args.share_policy:
        share_obs = obs.reshape(args.n_rollout_threads, -1)
        share_obs = np.expand_dims(share_obs, 1).repeat(num_agents, axis=1)
        rollouts.share_obs[0] = share_obs.copy()
        rollouts.obs[0] = obs.copy()
        rollouts.recurrent_hidden_states = np.zeros(
            rollouts.recurrent_hidden_states.shape).astype(np.float32)
        rollouts.recurrent_hidden_states_critic = np.zeros(
            rollouts.recurrent_hidden_states_critic.shape).astype(np.float32)
    else:

        share_obs = []
        for o in obs:
            share_obs.append(list(itertools.chain(*o)))
        share_obs = np.array(share_obs)
        for agent_id in range(num_agents):
            rollouts[agent_id].share_obs[0] = share_obs.copy()
            rollouts[agent_id].obs[0] = np.array(list(obs[:, agent_id])).copy()
            rollouts[agent_id].recurrent_hidden_states = np.zeros(
                rollouts[agent_id].recurrent_hidden_states.shape).astype(
                    np.float32)
            rollouts[agent_id].recurrent_hidden_states_critic = np.zeros(
                rollouts[agent_id].recurrent_hidden_states_critic.shape
            ).astype(np.float32)

    # run
    start = time.time()
    episodes = int(
        args.num_env_steps) // args.episode_length // args.n_rollout_threads
    timesteps = 0

    for episode in range(episodes):
        if args.use_linear_lr_decay:  # decrease learning rate linearly
            if args.share_policy:
                update_linear_schedule(agents.optimizer, episode, episodes,
                                       args.lr)
            else:
                for agent_id in range(num_agents):
                    update_linear_schedule(agents[agent_id].optimizer, episode,
                                           episodes, args.lr)

        for step in range(args.episode_length):
            # Sample actions
            values = []
            actions = []
            action_log_probs = []
            recurrent_hidden_statess = []
            recurrent_hidden_statess_critic = []

            with torch.no_grad():
                for agent_id in range(num_agents):
                    if args.share_policy:
                        actor_critic.eval()
                        value, action, action_log_prob, recurrent_hidden_states, recurrent_hidden_states_critic = actor_critic.act(
                            agent_id,
                            torch.FloatTensor(rollouts.share_obs[step, :,
                                                                 agent_id]),
                            torch.FloatTensor(rollouts.obs[step, :, agent_id]),
                            torch.FloatTensor(
                                rollouts.recurrent_hidden_states[step, :,
                                                                 agent_id]),
                            torch.FloatTensor(
                                rollouts.recurrent_hidden_states_critic[
                                    step, :, agent_id]),
                            torch.FloatTensor(rollouts.masks[step, :,
                                                             agent_id]))
                    else:
                        actor_critic[agent_id].eval()
                        value, action, action_log_prob, recurrent_hidden_states, recurrent_hidden_states_critic = actor_critic[
                            agent_id].act(
                                agent_id,
                                torch.FloatTensor(
                                    rollouts[agent_id].share_obs[step, :]),
                                torch.FloatTensor(
                                    rollouts[agent_id].obs[step, :]),
                                torch.FloatTensor(
                                    rollouts[agent_id].recurrent_hidden_states[
                                        step, :]),
                                torch.FloatTensor(
                                    rollouts[agent_id].
                                    recurrent_hidden_states_critic[step, :]),
                                torch.FloatTensor(
                                    rollouts[agent_id].masks[step, :]))

                    values.append(value.detach().cpu().numpy())
                    actions.append(action.detach().cpu().numpy())
                    action_log_probs.append(
                        action_log_prob.detach().cpu().numpy())
                    recurrent_hidden_statess.append(
                        recurrent_hidden_states.detach().cpu().numpy())
                    recurrent_hidden_statess_critic.append(
                        recurrent_hidden_states_critic.detach().cpu().numpy())

            # rearrange action
            actions_env = []
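            # Convert the sampled action indices into the one-hot (Discrete)
            # or concatenated one-hot (MultiDiscrete) vectors expected by the
            # MPE environments.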
            for i in range(args.n_rollout_threads):
                one_hot_action_env = []
                for agent_id in range(num_agents):
                    if envs.action_space[
                            agent_id].__class__.__name__ == 'MultiDiscrete':
                        uc_action = []
                        for j in range(envs.action_space[agent_id].shape):
                            uc_one_hot_action = np.zeros(
                                envs.action_space[agent_id].high[j] + 1)
                            uc_one_hot_action[actions[agent_id][i][j]] = 1
                            uc_action.append(uc_one_hot_action)
                        uc_action = np.concatenate(uc_action)
                        one_hot_action_env.append(uc_action)

                    elif envs.action_space[
                            agent_id].__class__.__name__ == 'Discrete':
                        one_hot_action = np.zeros(
                            envs.action_space[agent_id].n)
                        one_hot_action[actions[agent_id][i]] = 1
                        one_hot_action_env.append(one_hot_action)
                    else:
                        raise NotImplementedError
                actions_env.append(one_hot_action_env)

            # Observe reward and next obs
            obs, rewards, dones, infos, _ = envs.step(actions_env)

            # If done then clean the history of observations.
            # insert data in buffer
            masks = []
            for i, done in enumerate(dones):
                mask = []
                for agent_id in range(num_agents):
                    if done[agent_id]:
                        recurrent_hidden_statess[agent_id][i] = np.zeros(
                            args.hidden_size).astype(np.float32)
                        recurrent_hidden_statess_critic[agent_id][
                            i] = np.zeros(args.hidden_size).astype(np.float32)
                        mask.append([0.0])
                    else:
                        mask.append([1.0])
                masks.append(mask)

            if args.share_policy:
                share_obs = obs.reshape(args.n_rollout_threads, -1)
                share_obs = np.expand_dims(share_obs, 1).repeat(num_agents,
                                                                axis=1)

                rollouts.insert(
                    share_obs, obs,
                    np.array(recurrent_hidden_statess).transpose(1, 0, 2),
                    np.array(recurrent_hidden_statess_critic).transpose(
                        1, 0, 2),
                    np.array(actions).transpose(1, 0, 2),
                    np.array(action_log_probs).transpose(1, 0, 2),
                    np.array(values).transpose(1, 0, 2), rewards, masks)
            else:
                share_obs = []
                for o in obs:
                    share_obs.append(list(itertools.chain(*o)))
                share_obs = np.array(share_obs)
                for agent_id in range(num_agents):
                    rollouts[agent_id].insert(
                        share_obs, np.array(list(obs[:, agent_id])),
                        np.array(recurrent_hidden_statess[agent_id]),
                        np.array(recurrent_hidden_statess_critic[agent_id]),
                        np.array(actions[agent_id]),
                        np.array(action_log_probs[agent_id]),
                        np.array(values[agent_id]), rewards[:, agent_id],
                        np.array(masks)[:, agent_id])

        with torch.no_grad():
            for agent_id in range(num_agents):
                if args.share_policy:
                    actor_critic.eval()
                    next_value, _, _ = actor_critic.get_value(
                        agent_id,
                        torch.FloatTensor(rollouts.share_obs[-1, :, agent_id]),
                        torch.FloatTensor(rollouts.obs[-1, :, agent_id]),
                        torch.FloatTensor(
                            rollouts.recurrent_hidden_states[-1, :, agent_id]),
                        torch.FloatTensor(
                            rollouts.recurrent_hidden_states_critic[-1, :,
                                                                    agent_id]),
                        torch.FloatTensor(rollouts.masks[-1, :, agent_id]))
                    next_value = next_value.detach().cpu().numpy()
                    rollouts.compute_returns(agent_id, next_value,
                                             args.use_gae, args.gamma,
                                             args.gae_lambda,
                                             args.use_proper_time_limits,
                                             args.use_popart,
                                             agents.value_normalizer)
                else:
                    actor_critic[agent_id].eval()
                    next_value, _, _ = actor_critic[agent_id].get_value(
                        agent_id,
                        torch.FloatTensor(rollouts[agent_id].share_obs[-1, :]),
                        torch.FloatTensor(rollouts[agent_id].obs[-1, :]),
                        torch.FloatTensor(
                            rollouts[agent_id].recurrent_hidden_states[-1, :]),
                        torch.FloatTensor(
                            rollouts[agent_id].recurrent_hidden_states_critic[
                                -1, :]),
                        torch.FloatTensor(rollouts[agent_id].masks[-1, :]))
                    next_value = next_value.detach().cpu().numpy()
                    rollouts[agent_id].compute_returns(
                        next_value, args.use_gae, args.gamma, args.gae_lambda,
                        args.use_proper_time_limits, args.use_popart,
                        agents[agent_id].value_normalizer)

        # update the network
        if args.share_policy:
            actor_critic.train()
            value_loss, action_loss, dist_entropy = agents.update_share(
                num_agents, rollouts)

            for agent_id in range(num_agents):
                rew = []
                for i in range(rollouts.rewards.shape[1]):
                    rew.append(np.sum(rollouts.rewards[:, i, agent_id]))
                logger.add_scalars('agent%i/average_episode_reward' % agent_id,
                                   {'average_episode_reward': np.mean(rew)},
                                   (episode + 1) * args.episode_length *
                                   args.n_rollout_threads)
            # clean the buffer and reset
            rollouts.after_update()
        else:
            value_losses = []
            action_losses = []
            dist_entropies = []

            for agent_id in range(num_agents):
                actor_critic[agent_id].train()
                value_loss, action_loss, dist_entropy = agents[
                    agent_id].update_single(agent_id, rollouts[agent_id])
                value_losses.append(value_loss)
                action_losses.append(action_loss)
                dist_entropies.append(dist_entropy)

                rew = []
                for i in range(rollouts[agent_id].rewards.shape[1]):
                    rew.append(np.sum(rollouts[agent_id].rewards[:, i]))
                logger.add_scalars('agent%i/average_episode_reward' % agent_id,
                                   {'average_episode_reward': np.mean(rew)},
                                   (episode + 1) * args.episode_length *
                                   args.n_rollout_threads)

                rollouts[agent_id].after_update()

        total_num_steps = (episode +
                           1) * args.episode_length * args.n_rollout_threads

        if (episode % args.save_interval == 0 or episode == episodes -
                1):  # save for every interval-th episode or for the last epoch
            if args.share_policy:
                torch.save({'model': actor_critic},
                           str(save_dir) + "/agent_model.pt")
            else:
                for agent_id in range(num_agents):
                    torch.save({'model': actor_critic[agent_id]},
                               str(save_dir) + "/agent%i_model" % agent_id +
                               ".pt")

        # log information
        if episode % args.log_interval == 0:
            end = time.time()
            print(
                "\n Scenario {} Algo {} updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\n"
                .format(args.scenario_name, args.algorithm_name, episode,
                        episodes, total_num_steps, args.num_env_steps,
                        int(total_num_steps / (end - start))))
            if args.share_policy:
                print("value loss of agent: " + str(value_loss))
            else:
                for agent_id in range(num_agents):
                    print("value loss of agent%i: " % agent_id +
                          str(value_losses[agent_id]))

            if args.env_name == "MPE":
                for agent_id in range(num_agents):
                    show_rewards = []
                    for info in infos:
                        if 'individual_reward' in info[agent_id].keys():
                            show_rewards.append(
                                info[agent_id]['individual_reward'])
                    logger.add_scalars(
                        'agent%i/individual_reward' % agent_id,
                        {'individual_reward': np.mean(show_rewards)},
                        total_num_steps)

    logger.export_scalars_to_json(str(log_dir / 'summary.json'))
    logger.close()
    envs.close()
Example #5
0
def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.gpu and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # arguments
    LAMBDA = [1.0, 0.0, 1.0,
              10e-5]  # for [loss_dq, loss_n_dq, loss_jeq, loss_l2]
    CUDA_VISIBLE_DEVICES = 0
    seed = args.seed
    train = args.train
    demo = args.demo
    task = args.task
    iteration = 3
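    # Hard-coded network hyperparameters: the conv stack applied to the
    # (stacked) image frames, an MLP over the non-pixel features, hidden sizes
    # for the action/value heads, and the aggregator passed to the policy base.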
    convs = [(32, 7, 3), (64, 4, 2), (64, 3, 1)]
    non_pixel_layer = [64]
    in_feature = 7 * 7 * 64
    hidden_actions = [128]
    hidden_value = [128]
    aggregator = "reduceLocalMean"
    dtype = torch.cuda.FloatTensor if torch.cuda.is_available(
    ) else torch.FloatTensor

    #if not train:
    #    args.num_env_steps = 50000

    base_kwargs = {
        'non_pixel_layer': non_pixel_layer,
        'convs': convs,
        'frame_history_len': args.frame_history_len,
        'in_feature': in_feature,
        'hidden_actions': hidden_actions,
        'hidden_value': hidden_value,
        'aggregator': aggregator
    }

    # logger
    logging.basicConfig(level=logging.INFO)

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)

    # threads and device
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.gpu else "cpu")
    print("device:", device)
    gpu = args.gpu
    if gpu:
        print("current available gpu numbers: %d" % torch.cuda.device_count())
        if torch.cuda.is_available():
            torch.cuda.set_device(CUDA_VISIBLE_DEVICES)
            print("CUDA Device: %d" % torch.cuda.current_device())

    # envs

    #envs = gym.make(task)
    #obs_space = env.observation_space
    #act_space = env.action_space
    #action_template = env.action_space.noop()
    env = gym.make(args.task)
    obs_space = env.observation_space
    act_space = env.action_space
    action_template = env.action_space.noop()

    # policy
    actor_critic = Policy(obs_space, act_space, base_kwargs=base_kwargs)
    actor_critic.to(device)

    # algorithm
    if args.algo == 'ppo':
        agent = PPO(
            actor_critic,
            args.clip_param,
            args.ppo_epoch,
            args.num_mini_batch,
            args.value_loss_coef,
            args.entropy_coef,
            lr=7e-4,
            eps=1e-5,
            max_grad_norm=args.max_grad_norm,
        )
    else:
        raise NotImplementedError

    # storage
    replay_buffer = None
    if args.frame_history_len > 1:
        _, _, non_pixel_shape = parse_obs_space(obs_space)
        add_non_pixel = non_pixel_shape > 0
        replay_buffer = ReplayBuffer(100000, args.frame_history_len,
                                     non_pixel_shape, add_non_pixel)

    rollouts = RolloutStorage(replay_buffer, args.frame_history_len,
                              args.num_steps, args.num_processes, obs_space,
                              act_space)

    obs = env.reset()
    #print("reset obs pov size: ",obs['pov'].shape)
    # obs: key: inventory.dirt...
    # (num_processes, size)

    pov, non_pixel_feature = get_obs_features(obs_space, obs)
    #pov, non_pixel_feature = multi_get_obs_features(obs)
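    # With frame stacking enabled, the ReplayBuffer is only used to assemble
    # the last frame_history_len frames into one observation; pixels are
    # scaled to [0, 1] and the non-pixel feature by 1/180 (presumably an
    # angle in degrees).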
    if args.frame_history_len > 1:
        last_stored_frame_idx = replay_buffer.store_frame(
            pov, non_pixel_feature)
        pov = replay_buffer.encode_recent_observation() / 255.0  # 12 h w
        pov = torch.from_numpy(pov.copy()).reshape(args.num_processes,
                                                   *pov.shape)
    elif args.frame_history_len == 1:
        pov = pov.transpose(2, 0, 1) / 255.0
        pov = torch.from_numpy(pov.copy()).reshape(args.num_processes,
                                                   *pov.shape)
    else:
        raise NotImplementedError

    non_pixel_feature = (torch.tensor(non_pixel_feature) / 180.0).reshape(
        args.num_processes, -1)

    rollouts.obs[0].copy_(pov)
    rollouts.non_pixel_obs[0].copy_(non_pixel_feature)
    rollouts.to(device)

    # rolling window of the 10 most recent episode rewards
    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    print("Total steps: ", args.num_env_steps)

    ep = 0
    ep_rewards = []
    #mean_episode_reward = -float('nan')
    best_mean_episode_reward = -float('inf')
    #total_rewards = [0 for i in range(args.num_processes)]
    total_rewards = 0

    for j in range(num_updates):

        for step in range(args.num_steps):
            # num_steps = 5
            # Sample actions
            with torch.no_grad():
                # actor_critic.act output size
                # actions: torch.Tensor, not list
                value, actions, action_log_probs = actor_critic.act(
                    rollouts.obs[step], rollouts.non_pixel_obs[step])

            # value size: batch x 1
            # actions size: torch.Tensor num_processes x num_branches
            # action_log_probs : torch.Tensor num_processes x num_branches
            #print(actions)
            actions_list = actions.squeeze().tolist()

            action = get_actions_continuous(actions_list, act_space,
                                            action_template)

            # step the environment
            #print(actions)
            obs, reward, done, infos = env.step(action)
            #print('.',end='')
            if args.num_env_steps <= 50000:
                env.render()

            pov, non_pixel_feature = get_obs_features(obs_space, obs)
            #pov, non_pixel_feature = multi_get_obs_features(obs)
            if args.frame_history_len > 1:
                last_stored_frame_idx = replay_buffer.store_frame(
                    pov, non_pixel_feature)
                pov = replay_buffer.encode_recent_observation(
                ) / 255.0  # 12 h w
                pov = torch.from_numpy(pov.copy()).reshape(
                    args.num_processes, *pov.shape)
            elif args.frame_history_len == 1:
                pov = pov.transpose(2, 0, 1) / 255.0
                pov = torch.from_numpy(pov.copy()).reshape(
                    args.num_processes, *pov.shape)
            else:
                raise NotImplementedError

            non_pixel_feature = (torch.tensor(non_pixel_feature) /
                                 180.0).reshape(args.num_processes, -1)

            total_rewards += reward
            #for i in range(len(reward)):
            #    total_rewards[i] += reward[i]
            reward = torch.tensor([reward]).reshape(args.num_processes,
                                                    -1).type(dtype)

            # TODO: may not need bad_masks
            masks = torch.FloatTensor([[0.0] if done else [1.0]])
            bad_masks = torch.FloatTensor([[1.0]])

            if done:
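                # Episode finished: record the accumulated reward, log running
                # statistics, reset the environment, and rebuild the stacked
                # observation for the new episode.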
                ep += 1
                ep_rewards.append(total_rewards)
                best_mean_episode_reward = log(j, args.task, ep,
                                               np.array(ep_rewards),
                                               best_mean_episode_reward)

                obs = env.reset()
                pov, non_pixel_feature = get_obs_features(obs_space, obs)
                #pov, non_pixel_feature = multi_get_obs_features(obs)
                if args.frame_history_len > 1:
                    last_stored_frame_idx = replay_buffer.store_frame(
                        pov, non_pixel_feature)
                    pov = replay_buffer.encode_recent_observation(
                    ) / 255.0  # 12 h w
                    pov = torch.from_numpy(pov.copy()).reshape(
                        args.num_processes, *pov.shape)
                elif args.frame_history_len == 1:
                    pov = pov.transpose(2, 0, 1) / 255.0
                    pov = torch.from_numpy(pov.copy()).reshape(
                        args.num_processes, *pov.shape)
                else:
                    raise NotImplementedError
                non_pixel_feature = (torch.tensor(non_pixel_feature) /
                                     180.0).reshape(args.num_processes, -1)

                total_rewards = 0
            # store the transition in the rollout storage
            rollouts.insert(pov, non_pixel_feature, actions, action_log_probs,
                            value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1],
                                                rollouts.non_pixel_obs[-1])

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        # TODO: mini-batch size is 32; 1 process x 10 steps should be larger than 32
        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_model_dir != '':
            save_path = os.path.join(args.save_model_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save(actor_critic, os.path.join(save_path,
                                                  args.task + ".pt"))

        if j % args.log_interval == 0 and len(ep_rewards) >= 0:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print("----------- Logs -------------")
            if len(ep_rewards) == 0:
                print(
                    "Updates {}, num timesteps {}, FPS {} \nThe {}th training episodes,"
                    .format(j, total_num_steps,
                            int(total_num_steps / (end - start)),
                            len(ep_rewards)))
            else:
                print(
                    "Updates {}, num timesteps {}, FPS {} \nThe {}th training episodes,\nmean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                    .format(j, total_num_steps,
                            int(total_num_steps / (end - start)),
                            len(ep_rewards), np.mean(ep_rewards),
                            np.median(ep_rewards), np.min(ep_rewards),
                            np.max(ep_rewards)))

    print("-----------------------Training ends-----------------------")
    env.close()
Example #6
0
def main():
    args = get_config()
    run = wandb.init(project='curriculum',name=str(args.algorithm_name) + "_seed" + str(args.seed))
    # run = wandb.init(project='check',name='separate_reward')
    
    assert not (args.share_policy and args.scenario_name == 'simple_speaker_listener'), ("The simple_speaker_listener scenario cannot use a shared policy. Please check config.py.")

    # seed
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    
    # cuda
    if args.cuda and torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.set_num_threads(1)
        if args.cuda_deterministic:
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
    else:
        device = torch.device("cpu")
        torch.set_num_threads(args.n_training_threads)
    
    # path
    model_dir = Path('./results') / args.env_name / args.scenario_name / args.algorithm_name
    node_dir = Path('./node') / args.env_name / args.scenario_name / args.algorithm_name
    if not model_dir.exists():
        curr_run = 'run1'
    else:
        exst_run_nums = [int(str(folder.name).split('run')[1]) for folder in model_dir.iterdir() if str(folder.name).startswith('run')]
        if len(exst_run_nums) == 0:
            curr_run = 'run1'
        else:
            curr_run = 'run%i' % (max(exst_run_nums) + 1)
    if not node_dir.exists():
        node_curr_run = 'run1'
    else:
        exst_run_nums = [int(str(folder.name).split('run')[1]) for folder in node_dir.iterdir() if str(folder.name).startswith('run')]
        if len(exst_run_nums) == 0:
            node_curr_run = 'run1'
        else:
            node_curr_run = 'run%i' % (max(exst_run_nums) + 1)

    run_dir = model_dir / curr_run
    save_node_dir = node_dir / node_curr_run
    log_dir = run_dir / 'logs'
    save_dir = run_dir / 'models'
    os.makedirs(str(log_dir))
    os.makedirs(str(save_dir))
    logger = SummaryWriter(str(log_dir)) 

    # env
    envs = make_parallel_env(args)
    num_agents = args.num_agents
    #Policy network
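    # Shared policy: attention-based actor and critic bases (ATTBase_*) are
    # wrapped in Policy3 and trained with PPO3; the non-shared branch below
    # builds a separate Policy/PPO pair per agent instead.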
    if args.share_policy:
        actor_base = ATTBase_actor_dist_add(envs.observation_space[0].shape[0], envs.action_space[0], num_agents)
        critic_base = ATTBase_critic_add(envs.observation_space[0].shape[0], num_agents)
        actor_critic = Policy3(envs.observation_space[0], 
                    envs.action_space[0],
                    num_agents = num_agents,
                    base=None,
                    actor_base=actor_base,
                    critic_base=critic_base,
                    base_kwargs={'naive_recurrent': args.naive_recurrent_policy,
                                 'recurrent': args.recurrent_policy,
                                 'hidden_size': args.hidden_size,
                                 'attn': args.attn,                                 
                                 'attn_size': args.attn_size,
                                 'attn_N': args.attn_N,
                                 'attn_heads': args.attn_heads,
                                 'dropout': args.dropout,
                                 'use_average_pool': args.use_average_pool,
                                 'use_common_layer':args.use_common_layer,
                                 'use_feature_normlization':args.use_feature_normlization,
                                 'use_feature_popart':args.use_feature_popart,
                                 'use_orthogonal':args.use_orthogonal,
                                 'layer_N':args.layer_N,
                                 'use_ReLU':args.use_ReLU
                                 },
                    device = device)
        actor_critic.to(device)
        # algorithm
        agents = PPO3(actor_critic,
                   args.clip_param,
                   args.ppo_epoch,
                   args.num_mini_batch,
                   args.data_chunk_length,
                   args.value_loss_coef,
                   args.entropy_coef,
                   logger,
                   lr=args.lr,
                   eps=args.eps,
                   weight_decay=args.weight_decay,
                   max_grad_norm=args.max_grad_norm,
                   use_max_grad_norm=args.use_max_grad_norm,
                   use_clipped_value_loss= args.use_clipped_value_loss,
                   use_common_layer=args.use_common_layer,
                   use_huber_loss=args.use_huber_loss,
                   huber_delta=args.huber_delta,
                   use_popart=args.use_popart,
                   device=device)
                   
        # #replay buffer
        # rollouts = RolloutStorage(num_agents,
        #             args.episode_length, 
        #             args.n_rollout_threads,
        #             envs.observation_space[0], 
        #             envs.action_space[0],
        #             args.hidden_size)        
    else:
        actor_critic = []
        agents = []
        rollouts = []
        for agent_id in range(num_agents):
            ac = Policy(envs.observation_space, 
                      envs.action_space[agent_id],
                      num_agents = agent_id, # here is special
                      base_kwargs={'naive_recurrent': args.naive_recurrent_policy,
                                 'recurrent': args.recurrent_policy,
                                 'hidden_size': args.hidden_size,
                                 'attn': args.attn,                                 
                                 'attn_size': args.attn_size,
                                 'attn_N': args.attn_N,
                                 'attn_heads': args.attn_heads,
                                 'dropout': args.dropout,
                                 'use_average_pool': args.use_average_pool,
                                 'use_common_layer':args.use_common_layer,
                                 'use_feature_normlization':args.use_feature_normlization,
                                 'use_feature_popart':args.use_feature_popart,
                                 'use_orthogonal':args.use_orthogonal,
                                 'layer_N':args.layer_N,
                                 'use_ReLU':args.use_ReLU
                                 },
                      device = device)
            ac.to(device)
            # algorithm
            agent = PPO(ac,
                   args.clip_param,
                   args.ppo_epoch,
                   args.num_mini_batch,
                   args.data_chunk_length,
                   args.value_loss_coef,
                   args.entropy_coef,
                   logger,
                   lr=args.lr,
                   eps=args.eps,
                   weight_decay=args.weight_decay,
                   max_grad_norm=args.max_grad_norm,
                   use_max_grad_norm=args.use_max_grad_norm,
                   use_clipped_value_loss= args.use_clipped_value_loss,
                   use_common_layer=args.use_common_layer,
                   use_huber_loss=args.use_huber_loss,
                   huber_delta=args.huber_delta,
                   use_popart=args.use_popart,
                   device=device)
                               
            actor_critic.append(ac)
            agents.append(agent) 
              
            #replay buffer
            ro = SingleRolloutStorage(agent_id,
                    args.episode_length, 
                    args.n_rollout_threads,
                    envs.observation_space, 
                    envs.action_space,
                    args.hidden_size)
            rollouts.append(ro)
    
   
    boundary = 3
    start_boundary = [-0.3,0.3,-0.3,0.3] # [x_min, x_max, y_min, y_max]: the x range and the y range respectively
    # start_boundary = [2.4,3.0,2.4,3.0]
    max_step = 0.6
    N_easy = 0
    test_flag = 0
    reproduce_flag = 0
    target_num = 4
    last_agent_num = 4
    now_agent_num = num_agents
    mean_cover_rate = 0
    eval_frequency = 2 # number of rollout rounds to run with the sampled starts held fixed
    check_frequency = 1
    save_node_frequency = 5
    save_node_flag = True
    save_90_flag = True
    historical_length = 5
    random.seed(args.seed)
    np.random.seed(args.seed)


    # init the GAN that proposes curriculum goals, and pretrain it on feasible goals sampled from start_boundary
    gan_configs['goal_range'] = boundary
    gan_configs['goal_center'] = np.zeros((num_agents + num_agents)* 2, dtype=float)
    gan_configs['goal_size'] = (num_agents + num_agents)*2
    gan = StateGAN(gan_configs = gan_configs, state_range=gan_configs['goal_range'])
    feasible_goals = generate_initial_goals(num_case = 10000, start_boundary = start_boundary, agent_num = args.num_agents)                            
    dis_loss, gen_loss = gan.pretrain(states=feasible_goals, outer_iters=gan_configs['gan_outer_iters'])
    print('discriminator_loss:',str(dis_loss.cpu()), 'generator_loss:',str(gen_loss.cpu()))
    
    # init the StateCollection, a replay buffer of previously accepted goals
    all_goals = StateCollection(distance_threshold=goal_configs['coll_eps'])

    # run
    begin = time.time()
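    # each outer episode runs eval_frequency rollout-and-update rounds, so the
    # total step budget is additionally divided by eval_frequency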
    episodes = int(args.num_env_steps) // args.episode_length // args.n_rollout_threads // eval_frequency
    curriculum_episode = 0
    current_timestep = 0
    one_length = args.n_rollout_threads
    starts_length = args.n_rollout_threads
    num_envs = args.n_rollout_threads

    for episode in range(episodes):
        if args.use_linear_lr_decay:# decrease learning rate linearly
            if args.share_policy:   
                update_linear_schedule(agents.optimizer, episode, episodes, args.lr)  
            else:     
                for agent_id in range(num_agents):
                    update_linear_schedule(agents[agent_id].optimizer, episode, episodes, args.lr)           



        raw_goals, _ = gan.sample_states_with_noise(goal_configs['num_new_goals'])
        # mix in old goals replayed from the StateCollection
        if all_goals.size > 0:
            old_goals = all_goals.sample(goal_configs['num_old_goals'])
            goals = np.vstack([raw_goals, old_goals])
        else:
            goals = raw_goals   
        if goals.shape[0] < num_envs:
            add_num = num_envs - goals.shape[0]
            goals = np.vstack([goals, goals[:add_num]]) # pad up to num_new_goals + num_old_goals
        # generate the starts
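        # each env gets a list of num_agents*2 (x, y) start positions (presumably agents plus landmarks)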
        starts = numpy_to_list(goals, list_length=num_envs, shape=(num_agents*2,2))

        for times in range(eval_frequency):
            obs = envs.new_starts_obs(starts, now_agent_num, starts_length)
            #replay buffer
            rollouts = RolloutStorage(num_agents,
                        args.episode_length, 
                        starts_length,
                        envs.observation_space[0], 
                        envs.action_space[0],
                        args.hidden_size) 
            # replay buffer init
            if args.share_policy: 
                share_obs = obs.reshape(starts_length, -1)        
                share_obs = np.expand_dims(share_obs,1).repeat(num_agents,axis=1)    
                rollouts.share_obs[0] = share_obs.copy() 
                rollouts.obs[0] = obs.copy()               
                rollouts.recurrent_hidden_states = np.zeros(rollouts.recurrent_hidden_states.shape).astype(np.float32)
                rollouts.recurrent_hidden_states_critic = np.zeros(rollouts.recurrent_hidden_states_critic.shape).astype(np.float32)
            else:
                share_obs = []
                for o in obs:
                    share_obs.append(list(itertools.chain(*o)))
                share_obs = np.array(share_obs)
                for agent_id in range(num_agents):    
                    rollouts[agent_id].share_obs[0] = share_obs.copy()
                    rollouts[agent_id].obs[0] = np.array(list(obs[:,agent_id])).copy()               
                    rollouts[agent_id].recurrent_hidden_states = np.zeros(rollouts[agent_id].recurrent_hidden_states.shape).astype(np.float32)
                    rollouts[agent_id].recurrent_hidden_states_critic = np.zeros(rollouts[agent_id].recurrent_hidden_states_critic.shape).astype(np.float32)
            step_cover_rate = np.zeros(shape=(one_length,args.episode_length))
            for step in range(args.episode_length):
                # Sample actions
                values = []
                actions= []
                action_log_probs = []
                recurrent_hidden_statess = []
                recurrent_hidden_statess_critic = []
                
                with torch.no_grad():                
                    for agent_id in range(num_agents):
                        if args.share_policy:
                            actor_critic.eval()
                            value, action, action_log_prob, recurrent_hidden_states, recurrent_hidden_states_critic = actor_critic.act(agent_id,
                                torch.FloatTensor(rollouts.share_obs[step,:,agent_id]), 
                                torch.FloatTensor(rollouts.obs[step,:,agent_id]), 
                                torch.FloatTensor(rollouts.recurrent_hidden_states[step,:,agent_id]), 
                                torch.FloatTensor(rollouts.recurrent_hidden_states_critic[step,:,agent_id]),
                                torch.FloatTensor(rollouts.masks[step,:,agent_id]))
                        else:
                            actor_critic[agent_id].eval()
                            value, action, action_log_prob, recurrent_hidden_states, recurrent_hidden_states_critic = actor_critic[agent_id].act(agent_id,
                                torch.FloatTensor(rollouts[agent_id].share_obs[step,:]), 
                                torch.FloatTensor(rollouts[agent_id].obs[step,:]), 
                                torch.FloatTensor(rollouts[agent_id].recurrent_hidden_states[step,:]), 
                                torch.FloatTensor(rollouts[agent_id].recurrent_hidden_states_critic[step,:]),
                                torch.FloatTensor(rollouts[agent_id].masks[step,:]))
                            
                        values.append(value.detach().cpu().numpy())
                        actions.append(action.detach().cpu().numpy())
                        action_log_probs.append(action_log_prob.detach().cpu().numpy())
                        recurrent_hidden_statess.append(recurrent_hidden_states.detach().cpu().numpy())
                        recurrent_hidden_statess_critic.append(recurrent_hidden_states_critic.detach().cpu().numpy())
                # rearrange action
                actions_env = []
                for i in range(starts_length):
                    one_hot_action_env = []
                    for agent_id in range(num_agents):
                        if envs.action_space[agent_id].__class__.__name__ == 'MultiDiscrete':
                            uc_action = []
                            for j in range(envs.action_space[agent_id].shape):
                                uc_one_hot_action = np.zeros(envs.action_space[agent_id].high[j]+1)
                                uc_one_hot_action[actions[agent_id][i][j]] = 1
                                uc_action.append(uc_one_hot_action)
                            uc_action = np.concatenate(uc_action)
                            one_hot_action_env.append(uc_action)
                                
                        elif envs.action_space[agent_id].__class__.__name__ == 'Discrete':    
                            one_hot_action = np.zeros(envs.action_space[agent_id].n)
                            one_hot_action[actions[agent_id][i]] = 1
                            one_hot_action_env.append(one_hot_action)
                        else:
                            raise NotImplementedError
                    actions_env.append(one_hot_action_env)
                
                # Observe reward and next obs
                obs, rewards, dones, infos, _ = envs.step(actions_env, starts_length, num_agents)
                cover_rate_list = []
                for env_id in range(one_length):
                    cover_rate_list.append(infos[env_id][0]['cover_rate'])
                step_cover_rate[:,step] = np.array(cover_rate_list)
                # step_cover_rate[:,step] = np.array(infos)[0:one_length,0]

                # If done then clean the history of observations.
                # insert data in buffer
                masks = []
                for i, done in enumerate(dones): 
                    mask = []               
                    for agent_id in range(num_agents): 
                        if done[agent_id]:    
                            recurrent_hidden_statess[agent_id][i] = np.zeros(args.hidden_size).astype(np.float32)
                            recurrent_hidden_statess_critic[agent_id][i] = np.zeros(args.hidden_size).astype(np.float32)    
                            mask.append([0.0])
                        else:
                            mask.append([1.0])
                    masks.append(mask)
                                
                if args.share_policy: 
                    share_obs = obs.reshape(starts_length, -1)        
                    share_obs = np.expand_dims(share_obs,1).repeat(num_agents,axis=1)    
                    rollouts.insert(share_obs, 
                                obs, 
                                np.array(recurrent_hidden_statess).transpose(1,0,2), 
                                np.array(recurrent_hidden_statess_critic).transpose(1,0,2), 
                                np.array(actions).transpose(1,0,2),
                                np.array(action_log_probs).transpose(1,0,2), 
                                np.array(values).transpose(1,0,2),
                                rewards, 
                                masks)
                else:
                    share_obs = []
                    for o in obs:
                        share_obs.append(list(itertools.chain(*o)))
                    share_obs = np.array(share_obs)
                    for agent_id in range(num_agents):
                        rollouts[agent_id].insert(share_obs, 
                                np.array(list(obs[:,agent_id])), 
                                np.array(recurrent_hidden_statess[agent_id]), 
                                np.array(recurrent_hidden_statess_critic[agent_id]), 
                                np.array(actions[agent_id]),
                                np.array(action_log_probs[agent_id]), 
                                np.array(values[agent_id]),
                                rewards[:,agent_id], 
                                np.array(masks)[:,agent_id])
            # logger.add_scalars('agent/training_cover_rate',{'training_cover_rate': np.mean(np.mean(step_cover_rate[:,-historical_length:],axis=1))}, current_timestep)
            wandb.log({'training_cover_rate': np.mean(np.mean(step_cover_rate[:,-historical_length:],axis=1))}, step=current_timestep)
            print('training_cover_rate: ', np.mean(np.mean(step_cover_rate[:,-historical_length:],axis=1)))
            current_timestep += args.episode_length * starts_length
            curriculum_episode += 1
            
            #region train the gan
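            # label goals of intermediate difficulty: only goals whose recent cover
            # rate lies strictly between R_min and R_max are marked positive for GAN
            # training and appended to the goal replay buffer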

            
            if times == 1:
                start_time = time.time()
                filtered_raw_goals = []
                labels = np.zeros((num_envs, 1), dtype = int)
                for i in range(num_envs):
                    R_i = np.mean(step_cover_rate[i, -goal_configs['historical_length']:])
                    if R_i < goal_configs['R_max'] and R_i > goal_configs['R_min']:
                        labels[i] = 1
                        filtered_raw_goals.append(goals[i])
                gan.train(goals, labels)
                all_goals.append(filtered_raw_goals)
                end_time = time.time()
                print("Gan training time: %.2f"%(end_time-start_time))

            with torch.no_grad():  # get value and compute returns
                for agent_id in range(num_agents):         
                    if args.share_policy: 
                        actor_critic.eval()                
                        next_value,_,_ = actor_critic.get_value(agent_id,
                                                    torch.FloatTensor(rollouts.share_obs[-1,:,agent_id]), 
                                                    torch.FloatTensor(rollouts.obs[-1,:,agent_id]), 
                                                    torch.FloatTensor(rollouts.recurrent_hidden_states[-1,:,agent_id]),
                                                    torch.FloatTensor(rollouts.recurrent_hidden_states_critic[-1,:,agent_id]),
                                                    torch.FloatTensor(rollouts.masks[-1,:,agent_id]))
                        next_value = next_value.detach().cpu().numpy()
                        rollouts.compute_returns(agent_id,
                                        next_value, 
                                        args.use_gae, 
                                        args.gamma,
                                        args.gae_lambda, 
                                        args.use_proper_time_limits,
                                        args.use_popart,
                                        agents.value_normalizer)
                    else:
                        actor_critic[agent_id].eval()
                        next_value,_,_ = actor_critic[agent_id].get_value(agent_id,
                                                                torch.FloatTensor(rollouts[agent_id].share_obs[-1,:]), 
                                                                torch.FloatTensor(rollouts[agent_id].obs[-1,:]), 
                                                                torch.FloatTensor(rollouts[agent_id].recurrent_hidden_states[-1,:]),
                                                                torch.FloatTensor(rollouts[agent_id].recurrent_hidden_states_critic[-1,:]),
                                                                torch.FloatTensor(rollouts[agent_id].masks[-1,:]))
                        next_value = next_value.detach().cpu().numpy()
                        rollouts[agent_id].compute_returns(next_value, 
                                                args.use_gae, 
                                                args.gamma,
                                                args.gae_lambda, 
                                                args.use_proper_time_limits,
                                                args.use_popart,
                                                agents[agent_id].value_normalizer)

            # update the network
            if args.share_policy:
                actor_critic.train()
                value_loss, action_loss, dist_entropy = agents.update_share_asynchronous(last_agent_num, rollouts, False, initial_optimizer=False) 
                print('value_loss: ', value_loss)
                wandb.log(
                    {'value_loss': value_loss},
                    step=current_timestep)
                rew = []
                for i in range(rollouts.rewards.shape[1]):
                    rew.append(np.sum(rollouts.rewards[:,i]))
                wandb.log(
                    {'average_episode_reward': np.mean(rew)},
                    step=current_timestep)
                # clean the buffer and reset
                rollouts.after_update()
            else:
                value_losses = []
                action_losses = []
                dist_entropies = [] 
                
                for agent_id in range(num_agents):
                    actor_critic[agent_id].train()
                    value_loss, action_loss, dist_entropy = agents[agent_id].update_single(agent_id, rollouts[agent_id])
                    value_losses.append(value_loss)
                    action_losses.append(action_loss)
                    dist_entropies.append(dist_entropy)
                        
                    rew = []
                    for i in range(rollouts[agent_id].rewards.shape[1]):
                        rew.append(np.sum(rollouts[agent_id].rewards[:,i]))

                    logger.add_scalars('agent%i/average_episode_reward'%agent_id,
                        {'average_episode_reward': np.mean(rew)},
                        (episode+1) * args.episode_length * one_length*eval_frequency)
                    
                    rollouts[agent_id].after_update()


        # test: every check_frequency episodes, evaluate the current policy from fresh environment resets
        if episode % check_frequency==0:
            obs, _ = envs.reset(num_agents)
            episode_length = 70
            #replay buffer
            rollouts = RolloutStorage(num_agents,
                        episode_length, 
                        args.n_rollout_threads,
                        envs.observation_space[0], 
                        envs.action_space[0],
                        args.hidden_size) 
            # replay buffer init
            if args.share_policy: 
                share_obs = obs.reshape(args.n_rollout_threads, -1)        
                share_obs = np.expand_dims(share_obs,1).repeat(num_agents,axis=1)    
                rollouts.share_obs[0] = share_obs.copy() 
                rollouts.obs[0] = obs.copy()               
                rollouts.recurrent_hidden_states = np.zeros(rollouts.recurrent_hidden_states.shape).astype(np.float32)
                rollouts.recurrent_hidden_states_critic = np.zeros(rollouts.recurrent_hidden_states_critic.shape).astype(np.float32)
            else:
                share_obs = []
                for o in obs:
                    share_obs.append(list(itertools.chain(*o)))
                share_obs = np.array(share_obs)
                for agent_id in range(num_agents):    
                    rollouts[agent_id].share_obs[0] = share_obs.copy()
                    rollouts[agent_id].obs[0] = np.array(list(obs[:,agent_id])).copy()               
                    rollouts[agent_id].recurrent_hidden_states = np.zeros(rollouts[agent_id].recurrent_hidden_states.shape).astype(np.float32)
                    rollouts[agent_id].recurrent_hidden_states_critic = np.zeros(rollouts[agent_id].recurrent_hidden_states_critic.shape).astype(np.float32)
            test_cover_rate = np.zeros(shape=(args.n_rollout_threads,episode_length))
            for step in range(episode_length):
                # Sample actions
                values = []
                actions= []
                action_log_probs = []
                recurrent_hidden_statess = []
                recurrent_hidden_statess_critic = []
                
                with torch.no_grad():                
                    for agent_id in range(num_agents):
                        if args.share_policy:
                            actor_critic.eval()
                            value, action, action_log_prob, recurrent_hidden_states, recurrent_hidden_states_critic = actor_critic.act(agent_id,
                                torch.FloatTensor(rollouts.share_obs[step,:,agent_id]), 
                                torch.FloatTensor(rollouts.obs[step,:,agent_id]), 
                                torch.FloatTensor(rollouts.recurrent_hidden_states[step,:,agent_id]), 
                                torch.FloatTensor(rollouts.recurrent_hidden_states_critic[step,:,agent_id]),
                                torch.FloatTensor(rollouts.masks[step,:,agent_id]),deterministic=True)
                        else:
                            actor_critic[agent_id].eval()
                            value, action, action_log_prob, recurrent_hidden_states, recurrent_hidden_states_critic = actor_critic[agent_id].act(agent_id,
                                torch.FloatTensor(rollouts[agent_id].share_obs[step,:]), 
                                torch.FloatTensor(rollouts[agent_id].obs[step,:]), 
                                torch.FloatTensor(rollouts[agent_id].recurrent_hidden_states[step,:]), 
                                torch.FloatTensor(rollouts[agent_id].recurrent_hidden_states_critic[step,:]),
                                torch.FloatTensor(rollouts[agent_id].masks[step,:]),deterministic=True)
                            
                        values.append(value.detach().cpu().numpy())
                        actions.append(action.detach().cpu().numpy())
                        action_log_probs.append(action_log_prob.detach().cpu().numpy())
                        recurrent_hidden_statess.append(recurrent_hidden_states.detach().cpu().numpy())
                        recurrent_hidden_statess_critic.append(recurrent_hidden_states_critic.detach().cpu().numpy())
                
                # rearrange action
                actions_env = []
                for i in range(args.n_rollout_threads):
                    one_hot_action_env = []
                    for agent_id in range(num_agents):
                        if envs.action_space[agent_id].__class__.__name__ == 'MultiDiscrete':
                            uc_action = []
                            for j in range(envs.action_space[agent_id].shape):
                                uc_one_hot_action = np.zeros(envs.action_space[agent_id].high[j]+1)
                                uc_one_hot_action[actions[agent_id][i][j]] = 1
                                uc_action.append(uc_one_hot_action)
                            uc_action = np.concatenate(uc_action)
                            one_hot_action_env.append(uc_action)
                                
                        elif envs.action_space[agent_id].__class__.__name__ == 'Discrete':    
                            one_hot_action = np.zeros(envs.action_space[agent_id].n)
                            one_hot_action[actions[agent_id][i]] = 1
                            one_hot_action_env.append(one_hot_action)
                        else:
                            raise NotImplementedError
                    actions_env.append(one_hot_action_env)
                
                # Observe reward and next obs
                obs, rewards, dones, infos, _ = envs.step(actions_env, args.n_rollout_threads, num_agents)
                cover_rate_list = []
                for env_id in range(args.n_rollout_threads):
                    cover_rate_list.append(infos[env_id][0]['cover_rate'])
                test_cover_rate[:,step] = np.array(cover_rate_list)
                # test_cover_rate[:,step] = np.array(infos)[:,0]

                # If done then clean the history of observations.
                # insert data in buffer
                masks = []
                for i, done in enumerate(dones): 
                    mask = []               
                    for agent_id in range(num_agents): 
                        if done[agent_id]:    
                            recurrent_hidden_statess[agent_id][i] = np.zeros(args.hidden_size).astype(np.float32)
                            recurrent_hidden_statess_critic[agent_id][i] = np.zeros(args.hidden_size).astype(np.float32)    
                            mask.append([0.0])
                        else:
                            mask.append([1.0])
                    masks.append(mask)
                                
                if args.share_policy: 
                    share_obs = obs.reshape(args.n_rollout_threads, -1)        
                    share_obs = np.expand_dims(share_obs,1).repeat(num_agents,axis=1)    
                    
                    rollouts.insert(share_obs, 
                                obs, 
                                np.array(recurrent_hidden_statess).transpose(1,0,2), 
                                np.array(recurrent_hidden_statess_critic).transpose(1,0,2), 
                                np.array(actions).transpose(1,0,2),
                                np.array(action_log_probs).transpose(1,0,2), 
                                np.array(values).transpose(1,0,2),
                                rewards, 
                                masks)
                else:
                    share_obs = []
                    for o in obs:
                        share_obs.append(list(itertools.chain(*o)))
                    share_obs = np.array(share_obs)
                    for agent_id in range(num_agents):
                        rollouts[agent_id].insert(share_obs, 
                                np.array(list(obs[:,agent_id])), 
                                np.array(recurrent_hidden_statess[agent_id]), 
                                np.array(recurrent_hidden_statess_critic[agent_id]), 
                                np.array(actions[agent_id]),
                                np.array(action_log_probs[agent_id]), 
                                np.array(values[agent_id]),
                                rewards[:,agent_id], 
                                np.array(masks)[:,agent_id])

            # logger.add_scalars('agent/cover_rate_1step',{'cover_rate_1step': np.mean(test_cover_rate[:,-1])},current_timestep)
            # logger.add_scalars('agent/cover_rate_5step',{'cover_rate_5step': np.mean(np.mean(test_cover_rate[:,-historical_length:],axis=1))}, current_timestep)
            rew = []
            for i in range(rollouts.rewards.shape[1]):
                rew.append(np.sum(rollouts.rewards[:,i]))
            wandb.log(
                {'eval_episode_reward': np.mean(rew)},
                step=current_timestep)
            wandb.log({'cover_rate_1step': np.mean(test_cover_rate[:,-1])}, step=current_timestep)
            wandb.log({'cover_rate_5step': np.mean(np.mean(test_cover_rate[:,-historical_length:],axis=1))}, step=current_timestep)
            mean_cover_rate = np.mean(np.mean(test_cover_rate[:,-historical_length:],axis=1))
            if mean_cover_rate >= 0.9 and args.algorithm_name=='ours' and save_90_flag:
                torch.save({'model': actor_critic}, str(save_dir) + "/cover09_agent_model.pt")
                save_90_flag = False
            print('test_agent_num: ', last_agent_num)
            print('test_mean_cover_rate: ', mean_cover_rate)

        total_num_steps = current_timestep

        if (episode % args.save_interval == 0 or episode == episodes - 1):# save for every interval-th episode or for the last epoch
            if args.share_policy:
                torch.save({
                            'model': actor_critic
                            }, 
                            str(save_dir) + "/agent_model.pt")
            else:
                for agent_id in range(num_agents):                                                  
                    torch.save({
                                'model': actor_critic[agent_id]
                                }, 
                                str(save_dir) + "/agent%i_model" % agent_id + ".pt")

        # log information
        if episode % args.log_interval == 0:
            end = time.time()
            print("\n Scenario {} Algo {} updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\n"
                .format(args.scenario_name,
                        args.algorithm_name,
                        episode, 
                        episodes,
                        total_num_steps,
                        args.num_env_steps,
                        int(total_num_steps / (end - begin))))
            if args.share_policy:
                print("value loss of agent: " + str(value_loss))
            else:
                for agent_id in range(num_agents):
                    print("value loss of agent%i: " % agent_id + str(value_losses[agent_id]))
                
    logger.export_scalars_to_json(str(log_dir / 'summary.json'))
    logger.close()
    envs.close()
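The curriculum logic in the example above reduces to one labeling rule: a goal is kept for GAN training and for later replay only if the policy's recent cover rate on it lies strictly between R_min and R_max. Below is a minimal, self-contained sketch of that rule; the function and variable names are illustrative and not taken from the repository.

import numpy as np

def label_goals_of_intermediate_difficulty(goals, cover_rates, r_min, r_max):
    """Return (labels, kept_goals) for GAN training and goal replay.

    goals:       array of shape (num_goals, goal_dim)
    cover_rates: per-goal cover rate averaged over the last few steps
    A goal is labeled 1 (kept) only if r_min < cover_rate < r_max, mirroring
    the filtering step inside the training loop above.
    """
    cover_rates = np.asarray(cover_rates)
    labels = ((cover_rates > r_min) & (cover_rates < r_max)).astype(int).reshape(-1, 1)
    kept_goals = [g for g, l in zip(goals, labels) if l[0] == 1]
    return labels, kept_goals

# usage: three goals, only the middle one is of intermediate difficulty
labels, kept = label_goals_of_intermediate_difficulty(
    np.zeros((3, 16)), [0.05, 0.5, 0.95], r_min=0.1, r_max=0.9)
print(labels.ravel())  # [0 1 0]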
Example #7
def main():
    args = get_config()

    # seed
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # cuda
    if args.cuda and torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.set_num_threads(args.n_training_threads)
        if args.cuda_deterministic:
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
    else:
        device = torch.device("cpu")
        torch.set_num_threads(args.n_training_threads)

    # path
    model_dir = Path(
        './results') / args.env_name / args.map_name / args.algorithm_name
    if not model_dir.exists():
        curr_run = 'run1'
    else:
        exst_run_nums = [
            int(str(folder.name).split('run')[1])
            for folder in model_dir.iterdir()
            if str(folder.name).startswith('run')
        ]
        if len(exst_run_nums) == 0:
            curr_run = 'run1'
        else:
            curr_run = 'run%i' % (max(exst_run_nums) + 1)

    run_dir = model_dir / curr_run
    log_dir = run_dir / 'logs'
    save_dir = run_dir / 'models'
    os.makedirs(str(log_dir))
    os.makedirs(str(save_dir))
    logger = SummaryWriter(str(log_dir))

    # env
    envs = make_parallel_env(args)
    if args.eval:
        eval_env = make_eval_env(args)
    num_agents = get_map_params(args.map_name)["n_agents"]
    #Policy network

    if args.share_policy:
        actor_critic = Policy(envs.observation_space[0],
                              envs.action_space[0],
                              num_agents=num_agents,
                              gain=args.gain,
                              base_kwargs={
                                  'naive_recurrent':
                                  args.naive_recurrent_policy,
                                  'recurrent': args.recurrent_policy,
                                  'hidden_size': args.hidden_size,
                                  'recurrent_N': args.recurrent_N,
                                  'attn': args.attn,
                                  'attn_only_critic': args.attn_only_critic,
                                  'attn_size': args.attn_size,
                                  'attn_N': args.attn_N,
                                  'attn_heads': args.attn_heads,
                                  'dropout': args.dropout,
                                  'use_average_pool': args.use_average_pool,
                                  'use_common_layer': args.use_common_layer,
                                  'use_feature_normlization':
                                  args.use_feature_normlization,
                                  'use_feature_popart':
                                  args.use_feature_popart,
                                  'use_orthogonal': args.use_orthogonal,
                                  'layer_N': args.layer_N,
                                  'use_ReLU': args.use_ReLU,
                                  'use_same_dim': args.use_same_dim
                              },
                              device=device)
        actor_critic.to(device)
        # algorithm
        agents = PPO(actor_critic,
                     args.clip_param,
                     args.ppo_epoch,
                     args.num_mini_batch,
                     args.data_chunk_length,
                     args.value_loss_coef,
                     args.entropy_coef,
                     logger,
                     lr=args.lr,
                     eps=args.eps,
                     weight_decay=args.weight_decay,
                     max_grad_norm=args.max_grad_norm,
                     use_max_grad_norm=args.use_max_grad_norm,
                     use_clipped_value_loss=args.use_clipped_value_loss,
                     use_common_layer=args.use_common_layer,
                     use_huber_loss=args.use_huber_loss,
                     huber_delta=args.huber_delta,
                     use_popart=args.use_popart,
                     use_value_high_masks=args.use_value_high_masks,
                     device=device)

        #replay buffer
        rollouts = RolloutStorage(num_agents, args.episode_length,
                                  args.n_rollout_threads,
                                  envs.observation_space[0],
                                  envs.action_space[0], args.hidden_size)
    else:
        actor_critic = []
        agents = []
        for agent_id in range(num_agents):
            ac = Policy(envs.observation_space[0],
                        envs.action_space[0],
                        num_agents=num_agents,
                        gain=args.gain,
                        base_kwargs={
                            'naive_recurrent': args.naive_recurrent_policy,
                            'recurrent': args.recurrent_policy,
                            'hidden_size': args.hidden_size,
                            'recurrent_N': args.recurrent_N,
                            'attn': args.attn,
                            'attn_only_critic': args.attn_only_critic,
                            'attn_size': args.attn_size,
                            'attn_N': args.attn_N,
                            'attn_heads': args.attn_heads,
                            'dropout': args.dropout,
                            'use_average_pool': args.use_average_pool,
                            'use_common_layer': args.use_common_layer,
                            'use_feature_normlization':
                            args.use_feature_normlization,
                            'use_feature_popart': args.use_feature_popart,
                            'use_orthogonal': args.use_orthogonal,
                            'layer_N': args.layer_N,
                            'use_ReLU': args.use_ReLU,
                            'use_same_dim': args.use_same_dim
                        },
                        device=device)
            ac.to(device)
            # algorithm
            agent = PPO(ac,
                        args.clip_param,
                        args.ppo_epoch,
                        args.num_mini_batch,
                        args.data_chunk_length,
                        args.value_loss_coef,
                        args.entropy_coef,
                        logger,
                        lr=args.lr,
                        eps=args.eps,
                        weight_decay=args.weight_decay,
                        max_grad_norm=args.max_grad_norm,
                        use_max_grad_norm=args.use_max_grad_norm,
                        use_clipped_value_loss=args.use_clipped_value_loss,
                        use_common_layer=args.use_common_layer,
                        use_huber_loss=args.use_huber_loss,
                        huber_delta=args.huber_delta,
                        use_popart=args.use_popart,
                        use_value_high_masks=args.use_value_high_masks,
                        device=device)

            actor_critic.append(ac)
            agents.append(agent)

        #replay buffer
        rollouts = RolloutStorage(num_agents, args.episode_length,
                                  args.n_rollout_threads,
                                  envs.observation_space[0],
                                  envs.action_space[0], args.hidden_size)

    # reset env
    obs, available_actions = envs.reset()

    # replay buffer
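    # the shared observation for the centralized critic is every agent's observation
    # concatenated, then replicated once per agent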
    if len(envs.observation_space[0]) == 3:
        share_obs = obs.reshape(args.n_rollout_threads, -1,
                                envs.observation_space[0][1],
                                envs.observation_space[0][2])
    else:
        share_obs = obs.reshape(args.n_rollout_threads, -1)

    share_obs = np.expand_dims(share_obs, 1).repeat(num_agents, axis=1)
    rollouts.share_obs[0] = share_obs.copy()
    rollouts.obs[0] = obs.copy()
    rollouts.available_actions[0] = available_actions.copy()
    rollouts.recurrent_hidden_states = np.zeros(
        rollouts.recurrent_hidden_states.shape).astype(np.float32)
    rollouts.recurrent_hidden_states_critic = np.zeros(
        rollouts.recurrent_hidden_states_critic.shape).astype(np.float32)

    # run
    start = time.time()
    episodes = int(
        args.num_env_steps) // args.episode_length // args.n_rollout_threads
    timesteps = 0
    last_battles_game = np.zeros(args.n_rollout_threads)
    last_battles_won = np.zeros(args.n_rollout_threads)

    for episode in range(episodes):
        if args.use_linear_lr_decay:  # decrease learning rate linearly
            if args.share_policy:
                update_linear_schedule(agents.optimizer, episode, episodes,
                                       args.lr)
            else:
                for agent_id in range(num_agents):
                    update_linear_schedule(agents[agent_id].optimizer, episode,
                                           episodes, args.lr)

        for step in range(args.episode_length):
            # Sample actions
            values = []
            actions = []
            action_log_probs = []
            recurrent_hidden_statess = []
            recurrent_hidden_statess_critic = []

            with torch.no_grad():
                for agent_id in range(num_agents):
                    if args.share_policy:
                        actor_critic.eval()
                        value, action, action_log_prob, recurrent_hidden_states, recurrent_hidden_states_critic = actor_critic.act(
                            agent_id,
                            torch.tensor(rollouts.share_obs[step, :,
                                                            agent_id]),
                            torch.tensor(rollouts.obs[step, :, agent_id]),
                            torch.tensor(
                                rollouts.recurrent_hidden_states[step, :,
                                                                 agent_id]),
                            torch.tensor(
                                rollouts.recurrent_hidden_states_critic[
                                    step, :, agent_id]),
                            torch.tensor(rollouts.masks[step, :, agent_id]),
                            torch.tensor(rollouts.available_actions[step, :,
                                                                    agent_id]))
                    else:
                        actor_critic[agent_id].eval()
                        value, action, action_log_prob, recurrent_hidden_states, recurrent_hidden_states_critic = actor_critic[
                            agent_id].act(
                                agent_id,
                                torch.tensor(rollouts.share_obs[step, :,
                                                                agent_id]),
                                torch.tensor(rollouts.obs[step, :, agent_id]),
                                torch.tensor(rollouts.recurrent_hidden_states[
                                    step, :, agent_id]),
                                torch.tensor(
                                    rollouts.recurrent_hidden_states_critic[
                                        step, :, agent_id]),
                                torch.tensor(rollouts.masks[step, :,
                                                            agent_id]),
                                torch.tensor(
                                    rollouts.available_actions[step, :,
                                                               agent_id]))

                    values.append(value.detach().cpu().numpy())
                    actions.append(action.detach().cpu().numpy())
                    action_log_probs.append(
                        action_log_prob.detach().cpu().numpy())
                    recurrent_hidden_statess.append(
                        recurrent_hidden_states.detach().cpu().numpy())
                    recurrent_hidden_statess_critic.append(
                        recurrent_hidden_states_critic.detach().cpu().numpy())

            # rearrange action
            actions_env = []
            for i in range(args.n_rollout_threads):
                one_hot_action_env = []
                for agent_id in range(num_agents):
                    one_hot_action = np.zeros(envs.action_space[agent_id].n)
                    one_hot_action[actions[agent_id][i]] = 1
                    one_hot_action_env.append(one_hot_action)
                actions_env.append(one_hot_action_env)

            # Observe reward and next obs
            obs, reward, dones, infos, available_actions = envs.step(
                actions_env)

            # If done then clean the history of observations.
            # insert data in buffer
            masks = []
            for i, done in enumerate(dones):
                mask = []
                for agent_id in range(num_agents):
                    if done:
                        recurrent_hidden_statess[agent_id][i] = np.zeros(
                            args.hidden_size).astype(np.float32)
                        recurrent_hidden_statess_critic[agent_id][
                            i] = np.zeros(args.hidden_size).astype(np.float32)
                        mask.append([0.0])
                    else:
                        mask.append([1.0])
                masks.append(mask)

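            # bad_masks flag transitions truncated by the environment time limit
            # ('bad_transition'); high_masks come from the env info and are presumably
            # used together with use_value_high_masks to mask the value loss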
            bad_masks = []
            high_masks = []
            for info in infos:
                bad_mask = []
                high_mask = []
                for agent_id in range(num_agents):
                    if info[agent_id]['bad_transition']:
                        bad_mask.append([0.0])
                    else:
                        bad_mask.append([1.0])

                    if info[agent_id]['high_masks']:
                        high_mask.append([1.0])
                    else:
                        high_mask.append([0.0])
                bad_masks.append(bad_mask)
                high_masks.append(high_mask)

            if len(envs.observation_space[0]) == 3:
                share_obs = obs.reshape(args.n_rollout_threads, -1,
                                        envs.observation_space[0][1],
                                        envs.observation_space[0][2])
                share_obs = np.expand_dims(share_obs, 1).repeat(num_agents,
                                                                axis=1)

                rollouts.insert(
                    share_obs, obs,
                    np.array(recurrent_hidden_statess).transpose(1, 0, 2),
                    np.array(recurrent_hidden_statess_critic).transpose(
                        1, 0, 2),
                    np.array(actions).transpose(1, 0, 2),
                    np.array(action_log_probs).transpose(1, 0, 2),
                    np.array(values).transpose(1, 0, 2), reward, masks,
                    bad_masks, high_masks, available_actions)
            else:
                share_obs = obs.reshape(args.n_rollout_threads, -1)
                share_obs = np.expand_dims(share_obs, 1).repeat(num_agents,
                                                                axis=1)

                rollouts.insert(
                    share_obs, obs,
                    np.array(recurrent_hidden_statess).transpose(1, 0, 2),
                    np.array(recurrent_hidden_statess_critic).transpose(
                        1, 0, 2),
                    np.array(actions).transpose(1, 0, 2),
                    np.array(action_log_probs).transpose(1, 0, 2),
                    np.array(values).transpose(1, 0, 2), reward, masks,
                    bad_masks, high_masks, available_actions)

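        # bootstrap the value of the final state for every agent and compute
        # returns/advantages (GAE when use_gae is set) before the PPO update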
        with torch.no_grad():
            for agent_id in range(num_agents):
                if args.share_policy:
                    actor_critic.eval()
                    next_value, _, _ = actor_critic.get_value(
                        agent_id,
                        torch.tensor(rollouts.share_obs[-1, :, agent_id]),
                        torch.tensor(rollouts.obs[-1, :, agent_id]),
                        torch.tensor(
                            rollouts.recurrent_hidden_states[-1, :, agent_id]),
                        torch.tensor(
                            rollouts.recurrent_hidden_states_critic[-1, :,
                                                                    agent_id]),
                        torch.tensor(rollouts.masks[-1, :, agent_id]))
                    next_value = next_value.detach().cpu().numpy()
                    rollouts.compute_returns(agent_id, next_value,
                                             args.use_gae, args.gamma,
                                             args.gae_lambda,
                                             args.use_proper_time_limits,
                                             args.use_popart,
                                             agents.value_normalizer)
                else:
                    actor_critic[agent_id].eval()
                    next_value, _, _ = actor_critic[agent_id].get_value(
                        agent_id,
                        torch.tensor(rollouts.share_obs[-1, :, agent_id]),
                        torch.tensor(rollouts.obs[-1, :, agent_id]),
                        torch.tensor(
                            rollouts.recurrent_hidden_states[-1, :, agent_id]),
                        torch.tensor(
                            rollouts.recurrent_hidden_states_critic[-1, :,
                                                                    agent_id]),
                        torch.tensor(rollouts.masks[-1, :, agent_id]))
                    next_value = next_value.detach().cpu().numpy()
                    rollouts.compute_returns(agent_id, next_value,
                                             args.use_gae, args.gamma,
                                             args.gae_lambda,
                                             args.use_proper_time_limits,
                                             args.use_popart,
                                             agents[agent_id].value_normalizer)

        # update the network
        if args.share_policy:
            actor_critic.train()
            value_loss, action_loss, dist_entropy = agents.update_share(
                num_agents, rollouts)

            logger.add_scalars('reward', {'reward': np.mean(rollouts.rewards)},
                               (episode + 1) * args.episode_length *
                               args.n_rollout_threads)
        else:
            value_losses = []
            action_losses = []
            dist_entropies = []

            for agent_id in range(num_agents):
                actor_critic[agent_id].train()
                value_loss, action_loss, dist_entropy = agents[
                    agent_id].update(agent_id, rollouts)
                value_losses.append(value_loss)
                action_losses.append(action_loss)
                dist_entropies.append(dist_entropy)

                logger.add_scalars(
                    'agent%i/reward' % agent_id,
                    {'reward': np.mean(rollouts.rewards[:, :, agent_id])},
                    (episode + 1) * args.episode_length *
                    args.n_rollout_threads)

        # clean the buffer and reset
        rollouts.after_update()

        total_num_steps = (episode +
                           1) * args.episode_length * args.n_rollout_threads

        if (episode % args.save_interval == 0 or episode == episodes -
                1):  # save for every interval-th episode or for the last epoch
            if args.share_policy:
                torch.save({'model': actor_critic},
                           str(save_dir) + "/agent_model.pt")
            else:
                for agent_id in range(num_agents):
                    torch.save({'model': actor_critic[agent_id]},
                               str(save_dir) + "/agent%i_model" % agent_id +
                               ".pt")

        # log information
        if episode % args.log_interval == 0:
            end = time.time()
            print(
                "\n Map {} Algo {} updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\n"
                .format(args.map_name, args.algorithm_name, episode, episodes,
                        total_num_steps, args.num_env_steps,
                        int(total_num_steps / (end - start))))
            if args.share_policy:
                print("value loss of agent: " + str(value_loss))
            else:
                for agent_id in range(num_agents):
                    print("value loss of agent%i: " % agent_id +
                          str(value_losses[agent_id]))

            if args.env_name == "StarCraft2":
                battles_won = []
                battles_game = []
                incre_battles_won = []
                incre_battles_game = []

                for i, info in enumerate(infos):
                    if 'battles_won' in info[0].keys():
                        battles_won.append(info[0]['battles_won'])
                        incre_battles_won.append(info[0]['battles_won'] -
                                                 last_battles_won[i])
                    if 'battles_game' in info[0].keys():
                        battles_game.append(info[0]['battles_game'])
                        incre_battles_game.append(info[0]['battles_game'] -
                                                  last_battles_game[i])

                if np.sum(incre_battles_game) > 0:
                    logger.add_scalars(
                        'incre_win_rate', {
                            'incre_win_rate':
                            np.sum(incre_battles_won) /
                            np.sum(incre_battles_game)
                        }, total_num_steps)
                else:
                    logger.add_scalars('incre_win_rate', {'incre_win_rate': 0},
                                       total_num_steps)
                last_battles_game = battles_game
                last_battles_won = battles_won

        if episode % args.eval_interval == 0 and args.eval:
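            # greedy evaluation: act deterministically on the single eval env until
            # eval_episodes full episodes have finished, then log the win rate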
            eval_battles_won = 0
            eval_episode = 0
            eval_obs, eval_available_actions = eval_env.reset()
            eval_share_obs = eval_obs.reshape(1, -1)
            eval_recurrent_hidden_states = np.zeros(
                (1, num_agents, args.hidden_size)).astype(np.float32)
            eval_recurrent_hidden_states_critic = np.zeros(
                (1, num_agents, args.hidden_size)).astype(np.float32)
            eval_masks = np.ones((1, num_agents, 1)).astype(np.float32)

            while True:
                eval_actions = []
                for agent_id in range(num_agents):
                    if args.share_policy:
                        actor_critic.eval()
                        _, action, _, recurrent_hidden_states, recurrent_hidden_states_critic = actor_critic.act(
                            agent_id,
                            torch.tensor(eval_share_obs),
                            torch.tensor(eval_obs[:, agent_id]),
                            torch.tensor(
                                eval_recurrent_hidden_states[:, agent_id]),
                            torch.tensor(
                                eval_recurrent_hidden_states_critic[:,
                                                                    agent_id]),
                            torch.tensor(eval_masks[:, agent_id]),
                            torch.tensor(eval_available_actions[:,
                                                                agent_id, :]),
                            deterministic=True)
                    else:
                        actor_critic[agent_id].eval()
                        _, action, _, recurrent_hidden_states, recurrent_hidden_states_critic = actor_critic[
                            agent_id].act(
                                agent_id,
                                torch.tensor(eval_share_obs),
                                torch.tensor(eval_obs[:, agent_id]),
                                torch.tensor(
                                    eval_recurrent_hidden_states[:, agent_id]),
                                torch.tensor(
                                    eval_recurrent_hidden_states_critic[:,
                                                                        agent_id]
                                ),
                                torch.tensor(eval_masks[:, agent_id]),
                                torch.tensor(
                                    eval_available_actions[:, agent_id, :]),
                                deterministic=True)

                    eval_actions.append(action.detach().cpu().numpy())
                    eval_recurrent_hidden_states[:, agent_id] = (
                        recurrent_hidden_states.detach().cpu().numpy())
                    eval_recurrent_hidden_states_critic[:, agent_id] = (
                        recurrent_hidden_states_critic.detach().cpu().numpy())

                # rearrange actions into one-hot vectors expected by the environment
                eval_actions_env = []
                for agent_id in range(num_agents):
                    one_hot_action = np.zeros(
                        eval_env.action_space[agent_id].n)
                    one_hot_action[eval_actions[agent_id][0]] = 1
                    eval_actions_env.append(one_hot_action)

                # Observe reward and next obs
                eval_obs, eval_rewards, eval_dones, eval_infos, eval_available_actions = eval_env.step(
                    [eval_actions_env])
                eval_share_obs = eval_obs.reshape(1, -1)

                if eval_dones[0]:
                    eval_episode += 1
                    if eval_infos[0][0]['won']:
                        eval_battles_won += 1
                    for agent_id in range(num_agents):
                        eval_recurrent_hidden_states[0][agent_id] = np.zeros(
                            args.hidden_size).astype(np.float32)
                        eval_recurrent_hidden_states_critic[0][
                            agent_id] = np.zeros(args.hidden_size).astype(
                                np.float32)
                        eval_masks[0][agent_id] = 0.0
                else:
                    for agent_id in range(num_agents):
                        eval_masks[0][agent_id] = 1.0

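                # Once enough eval episodes have finished, log the overall win rate and stop.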
                if eval_episode >= args.eval_episodes:
                    logger.add_scalars(
                        'eval_win_rate',
                        {'eval_win_rate': eval_battles_won / eval_episode},
                        total_num_steps)
                    break

    logger.export_scalars_to_json(str(log_dir / 'summary.json'))
    logger.close()
    envs.close()
    if args.eval:
        eval_env.close()
Example #8
def main():
    args = get_config()

    # cuda
    if args.cuda and torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.set_num_threads(args.n_training_threads)
        if args.cuda_deterministic:
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
    else:
        device = torch.device("cpu")
        torch.set_num_threads(args.n_training_threads)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    args.reward_randomization = False

    # path
    model_dir = Path('./results') / args.env_name / args.algorithm_name / (
        "run" + str(args.seed))

    run_dir = model_dir / 'finetune'
    log_dir = run_dir / 'logs'
    save_dir = run_dir / 'models'
    os.makedirs(str(log_dir))
    os.makedirs(str(save_dir))
    logger = SummaryWriter(str(log_dir))

    # env
    envs = make_parallel_env(args)
    # Policy network: load pretrained per-agent models for fine-tuning
    actor_critic = []
    if args.share_policy:
        ac = torch.load(str(model_dir / 'models') +
                        "/agent0_model.pt")['model'].to(device)
        for i in range(args.num_agents):
            actor_critic.append(ac)
    else:
        for i in range(args.num_agents):
            ac = torch.load(
                str(model_dir / 'models') + "/agent%i_model" % i +
                ".pt")['model'].to(device)
            actor_critic.append(ac)
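    # One PPO trainer and one rollout buffer per agent.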
    agents = []
    rollouts = []
    for agent_id in range(args.num_agents):
        # algorithm
        agent = PPO(actor_critic[agent_id],
                    agent_id,
                    args.clip_param,
                    args.ppo_epoch,
                    args.num_mini_batch,
                    args.data_chunk_length,
                    args.value_loss_coef,
                    args.entropy_coef,
                    logger,
                    lr=args.lr,
                    eps=args.eps,
                    max_grad_norm=args.max_grad_norm,
                    use_clipped_value_loss=args.use_clipped_value_loss)

        #replay buffer
        ro = RolloutStorage(args.num_agents, agent_id, args.episode_length,
                            args.n_rollout_threads,
                            envs.observation_space[agent_id],
                            envs.action_space[agent_id],
                            actor_critic[agent_id].recurrent_hidden_state_size)

        agents.append(agent)
        rollouts.append(ro)

    # reset env
    obs = envs.reset()
    # rollout
    for i in range(args.num_agents):
        if len(envs.observation_space[0]) == 1:
            rollouts[i].share_obs[0].copy_(
                torch.tensor(obs.reshape(args.n_rollout_threads, -1)))
            rollouts[i].obs[0].copy_(torch.tensor(obs[:, i, :]))
            rollouts[i].recurrent_hidden_states.zero_()
            rollouts[i].recurrent_hidden_states_critic.zero_()
            rollouts[i].recurrent_c_states.zero_()
            rollouts[i].recurrent_c_states_critic.zero_()
        else:
            raise NotImplementedError
        rollouts[i].to(device)

    # run: per-episode statistics accumulated for logging
    coop_num = []
    defect_num = []
    coopdefect_num = []
    defectcoop_num = []
    gore1_num = []
    gore2_num = []
    collective_return = []
    apple_consumption = []
    waste_cleared = []
    sustainability = []
    fire = []

    start = time.time()
    episodes = int(
        args.num_env_steps) // args.episode_length // args.n_rollout_threads
    all_episode = 0
    turn_on = False

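    # Main training loop: collect rollouts, compute returns, and run a PPO update per agent.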
    for episode in range(episodes):
        if episode > episodes / 2 and not turn_on:
            print("\n Turn off fixed actor...")
            # enable updates of the previously fixed actor (turn_on is passed to agents[i].update())
            turn_on = True

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            for i in range(args.num_agents):
                update_linear_schedule(agents[i].optimizer, episode, episodes,
                                       args.lr)

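        # Collect args.episode_length steps of experience from all rollout threads.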
        for step in range(args.episode_length):
            # Sample actions
            values = []
            actions = []
            action_log_probs = []
            recurrent_hidden_statess = []
            recurrent_hidden_statess_critic = []
            recurrent_c_statess = []
            recurrent_c_statess_critic = []

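            # Query each agent's policy without tracking gradients.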
            with torch.no_grad():
                for i in range(args.num_agents):
                    value, action, action_log_prob, recurrent_hidden_states, recurrent_hidden_states_critic, recurrent_c_states, recurrent_c_states_critic = actor_critic[
                        i].act(
                            rollouts[i].share_obs[step], rollouts[i].obs[step],
                            rollouts[i].recurrent_hidden_states[step],
                            rollouts[i].recurrent_hidden_states_critic[step],
                            rollouts[i].recurrent_c_states[step],
                            rollouts[i].recurrent_c_states_critic[step],
                            rollouts[i].masks[step])
                    values.append(value)
                    actions.append(action)
                    action_log_probs.append(action_log_prob)
                    recurrent_hidden_statess.append(recurrent_hidden_states)
                    recurrent_hidden_statess_critic.append(
                        recurrent_hidden_states_critic)
                    recurrent_c_statess.append(recurrent_c_states)
                    recurrent_c_statess_critic.append(
                        recurrent_c_states_critic)

            # rearrange actions into one-hot vectors expected by the environment
            actions_env = []
            for i in range(args.n_rollout_threads):
                one_hot_action_env = []
                for k in range(args.num_agents):
                    one_hot_action = np.zeros(envs.action_space[0].n)
                    one_hot_action[actions[k][i]] = 1
                    one_hot_action_env.append(one_hot_action)
                actions_env.append(one_hot_action_env)

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(actions_env)

            # If done then clean the history of observations.
            # insert data in buffer
            masks = []
            bad_masks = []
            for i in range(args.num_agents):
                mask = []
                bad_mask = []
                for done_ in done:
                    if done_[i]:
                        mask.append([0.0])
                        bad_mask.append([1.0])
                    else:
                        mask.append([1.0])
                        bad_mask.append([1.0])
                masks.append(torch.FloatTensor(mask))
                bad_masks.append(torch.FloatTensor(bad_mask))

            for i in range(args.num_agents):
                if len(envs.observation_space[0]) == 1:
                    rollouts[i].insert(
                        torch.tensor(obs.reshape(args.n_rollout_threads, -1)),
                        torch.tensor(obs[:, i, :]), recurrent_hidden_statess[i],
                        recurrent_hidden_statess_critic[i],
                        recurrent_c_statess[i], recurrent_c_statess_critic[i],
                        actions[i], action_log_probs[i], values[i],
                        torch.tensor(reward[:, i].reshape(-1, 1)), masks[i],
                        bad_masks[i])
                else:
                    raise NotImplementedError

        with torch.no_grad():
            next_values = []
            for i in range(args.num_agents):
                next_value = actor_critic[i].get_value(
                    rollouts[i].share_obs[-1], rollouts[i].obs[-1],
                    rollouts[i].recurrent_hidden_states[-1],
                    rollouts[i].recurrent_hidden_states_critic[-1],
                    rollouts[i].recurrent_c_states[-1],
                    rollouts[i].recurrent_c_states_critic[-1],
                    rollouts[i].masks[-1]).detach()
                next_values.append(next_value)

        for i in range(args.num_agents):
            rollouts[i].compute_returns(next_values[i], args.use_gae,
                                        args.gamma, args.gae_lambda,
                                        args.use_proper_time_limits)

        # update the network
        value_losses = []
        action_losses = []
        dist_entropies = []
        for i in range(args.num_agents):
            value_loss, action_loss, dist_entropy = agents[i].update(
                rollouts[i], turn_on)
            value_losses.append(value_loss)
            action_losses.append(action_loss)
            dist_entropies.append(dist_entropy)

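        # Environment-specific logging of per-episode statistics.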
        if args.env_name == "StagHunt":
            for info in infos:
                if 'coop&coop_num' in info.keys():
                    coop_num.append(info['coop&coop_num'])
                if 'defect&defect_num' in info.keys():
                    defect_num.append(info['defect&defect_num'])
                if 'coop&defect_num' in info.keys():
                    coopdefect_num.append(info['coop&defect_num'])
                if 'defect&coop_num' in info.keys():
                    defectcoop_num.append(info['defect&coop_num'])

            for i in range(args.n_rollout_threads):
                logger.add_scalars(
                    'coop&coop_num_per_episode',
                    {'coop&coop_num_per_episode': coop_num[all_episode]},
                    all_episode)
                logger.add_scalars(
                    'defect&defect_num_per_episode',
                    {'defect&defect_num_per_episode': defect_num[all_episode]},
                    all_episode)
                logger.add_scalars('coop&defect_num_per_episode', {
                    'coop&defect_num_per_episode':
                    coopdefect_num[all_episode]
                }, all_episode)
                logger.add_scalars('defect&coop_num_per_episode', {
                    'defect&coop_num_per_episode':
                    defectcoop_num[all_episode]
                }, all_episode)
                all_episode += 1
        elif args.env_name == "StagHuntGW":
            for info in infos:
                if 'collective_return' in info.keys():
                    collective_return.append(info['collective_return'])
                if 'coop&coop_num' in info.keys():
                    coop_num.append(info['coop&coop_num'])
                if 'gore1_num' in info.keys():
                    gore1_num.append(info['gore1_num'])
                if 'gore2_num' in info.keys():
                    gore2_num.append(info['gore2_num'])

            for i in range(args.n_rollout_threads):
                logger.add_scalars(
                    'collective_return',
                    {'collective_return': collective_return[all_episode]},
                    all_episode)
                logger.add_scalars(
                    'coop&coop_num_per_episode',
                    {'coop&coop_num_per_episode': coop_num[all_episode]},
                    all_episode)
                logger.add_scalars(
                    'gore1_num_per_episode',
                    {'gore1_num_per_episode': gore1_num[all_episode]},
                    all_episode)
                logger.add_scalars(
                    'gore2_num_per_episode',
                    {'gore2_num_per_episode': gore2_num[all_episode]},
                    all_episode)

                all_episode += 1

        # clean the buffer and reset
        obs = envs.reset()
        for i in range(args.num_agents):
            if len(envs.observation_space[0]) == 1:
                rollouts[i].share_obs[0].copy_(
                    torch.tensor(obs.reshape(args.n_rollout_threads, -1)))
                rollouts[i].obs[0].copy_(torch.tensor(obs[:, i, :]))
                rollouts[i].recurrent_hidden_states.zero_()
                rollouts[i].recurrent_hidden_states_critic.zero_()
                rollouts[i].recurrent_c_states.zero_()
                rollouts[i].recurrent_c_states_critic.zero_()
            else:
                raise NotImplementedError
            rollouts[i].to(device)

        for i in range(args.num_agents):
            # save for every interval-th episode or for the last episode
            if (episode % args.save_interval == 0 or episode == episodes - 1):
                torch.save({'model': actor_critic[i]},
                           str(save_dir) + "/agent%i_model" % i + ".pt")

        # log information
        if episode % args.log_interval == 0:
            total_num_steps = (
                episode + 1) * args.episode_length * args.n_rollout_threads
            end = time.time()
            print(
                "\n updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\n"
                .format(episode, episodes, total_num_steps, args.num_env_steps,
                        int(total_num_steps / (end - start))))
            for i in range(args.num_agents):
                print("value loss of agent%i: " % i + str(value_losses[i]))
    logger.export_scalars_to_json(str(log_dir / 'summary.json'))
    logger.close()
    envs.close()