# Example #1
def cim_dqn_actor():
    """Launch a DQN actor worker for the CIM scenario.

    Builds the simulation environment, wraps one DQN agent per agent id,
    and registers the resulting actor as a distributed worker under the
    configured group name.
    """
    environment = Env(**training_config["env"])
    agents_by_name = {
        agent_name: get_dqn_agent() for agent_name in environment.agent_idx_list
    }
    multi_agent = MultiAgentWrapper(agents_by_name)
    dqn_actor = Actor(
        environment,
        multi_agent,
        CIMTrajectoryForDQN,
        trajectory_kwargs=common_config,
    )
    # Block here serving rollout requests from the learner.
    dqn_actor.as_worker(training_config["group"], log_dir=log_dir)
# Example #2
def cim_dqn_learner():
    """Launch the DQN learner for the CIM scenario.

    Creates the environment and per-agent DQN agents, sets up an
    exploration-parameter scheduler and a proxy to the remote actor
    workers, then runs off-policy learning to completion.
    """
    environment = Env(**training_config["env"])
    multi_agent = MultiAgentWrapper(
        {agent_name: get_dqn_agent() for agent_name in environment.agent_idx_list}
    )
    exploration_scheduler = TwoPhaseLinearParameterScheduler(
        training_config["max_episode"],
        **training_config["exploration"],
    )
    # Proxy that fans rollout requests out to the remote actors and
    # triggers a learner update once enough of them have reported back.
    actor_proxy = ActorProxy(
        training_config["group"],
        training_config["num_actors"],
        update_trigger=training_config["learner_update_trigger"],
    )
    off_policy_learner = OffPolicyLearner(
        actor_proxy,
        exploration_scheduler,
        multi_agent,
        **training_config["training"],
    )
    off_policy_learner.run()
# Example #3
            agent_id = list(state.keys())[0]
            data = training_data.setdefault(agent_id,
                                            {"args": [[] for _ in range(4)]})
            data["args"][0].append(state[agent_id])  # state
            data["args"][1].append(action[agent_id][0])  # action
            data["args"][2].append(action[agent_id][1])  # log_p
            data["args"][3].append(self.get_offline_reward(event))  # reward

        for agent_id in training_data:
            training_data[agent_id]["args"] = [
                np.asarray(vals, dtype=np.float32 if i == 3 else None)
                for i, vals in enumerate(training_data[agent_id]["args"])
            ]

        return training_data


# Single-threaded launcher: everything (env, agents, actor, learner) runs
# in one process — no distributed worker group is involved.
if __name__ == "__main__":
    # Fix all RNG seeds so runs are reproducible.
    set_seeds(1024)
    environment = Env(**training_config["env"])
    multi_agent = MultiAgentWrapper(
        {agent_name: get_ac_agent() for agent_name in environment.agent_idx_list}
    )
    # Local in-process actor collecting actor-critic trajectories.
    local_actor = Actor(
        environment,
        multi_agent,
        CIMTrajectoryForAC,
        trajectory_kwargs=common_config,
    )
    OnPolicyLearner(local_actor, training_config["max_episode"]).run()