Code Example #1
    # Constructor of MultiPolicy (the full class is shown in Code Example #4).
    def __init__(self, state_size, action_size, n_agents, env):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = []
        self.loss = 0
        self.deadlock_avoidance_policy = DeadLockAvoidanceAgent(
            env, action_size, False)
        self.ppo_policy = PPOPolicy(state_size + action_size, action_size)
Code Example #2
    def __init__(self, env: RailEnv, state_size, action_size, learning_agent):
        print(">> DeadLockAvoidanceWithDecisionAgent")
        super(DeadLockAvoidanceWithDecisionAgent, self).__init__()
        self.env = env
        self.state_size = state_size
        self.action_size = action_size
        self.learning_agent = learning_agent
        self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(
            self.env, action_size, False)
        self.policy_selector = PPOPolicy(state_size, 2)

        self.memory = self.learning_agent.memory
        self.loss = self.learning_agent.loss
Code Example #3
    def __init__(self, state_size, action_size, in_parameters=None):
        print(">> MultiDecisionAgent")
        super(MultiDecisionAgent, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.in_parameters = in_parameters
        self.memory = DummyMemory()
        self.loss = 0

        self.ppo_policy = PPOPolicy(state_size,
                                    action_size,
                                    use_replay_buffer=False,
                                    in_parameters=in_parameters)
        self.dddqn_policy = DDDQNPolicy(state_size, action_size, in_parameters)
        self.policy_selector = PPOPolicy(state_size, 2)
Code Example #4
import numpy as np
from flatland.envs.rail_env import RailEnv

# Note: Policy, PPOPolicy and DeadLockAvoidanceAgent are project-local classes; their
# import paths depend on the surrounding repository layout and are omitted here.


class MultiPolicy(Policy):
    def __init__(self, state_size, action_size, n_agents, env):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = []
        self.loss = 0
        self.deadlock_avoidance_policy = DeadLockAvoidanceAgent(
            env, action_size, False)
        self.ppo_policy = PPOPolicy(state_size + action_size, action_size)

    def load(self, filename):
        self.ppo_policy.load(filename)
        self.deadlock_avoidance_policy.load(filename)

    def save(self, filename):
        self.ppo_policy.save(filename)
        self.deadlock_avoidance_policy.save(filename)

    def step(self, handle, state, action, reward, next_state, done):
        action_extra_state = self.deadlock_avoidance_policy.act(
            handle, state, 0.0)
        action_extra_next_state = self.deadlock_avoidance_policy.act(
            handle, next_state, 0.0)

        extended_state = np.copy(state)
        for action_itr in np.arange(self.action_size):
            extended_state = np.append(extended_state,
                                       [int(action_extra_state == action_itr)])
        extended_next_state = np.copy(next_state)
        for action_itr in np.arange(self.action_size):
            extended_next_state = np.append(
                extended_next_state,
                [int(action_extra_next_state == action_itr)])

        self.deadlock_avoidance_policy.step(handle, state, action, reward,
                                            next_state, done)
        self.ppo_policy.step(handle, extended_state, action, reward,
                             extended_next_state, done)

    def act(self, handle, state, eps=0.):
        action_extra_state = self.deadlock_avoidance_policy.act(
            handle, state, 0.0)
        extended_state = np.copy(state)
        for action_itr in np.arange(self.action_size):
            extended_state = np.append(extended_state,
                                       [int(action_extra_state == action_itr)])
        action_ppo = self.ppo_policy.act(handle, extended_state, eps)
        self.loss = self.ppo_policy.loss
        return action_ppo

    def reset(self, env: RailEnv):
        self.ppo_policy.reset(env)
        self.deadlock_avoidance_policy.reset(env)

    def test(self):
        self.ppo_policy.test()
        self.deadlock_avoidance_policy.test()

    def start_step(self, train):
        self.deadlock_avoidance_policy.start_step(train)
        self.ppo_policy.start_step(train)

    def end_step(self, train):
        self.deadlock_avoidance_policy.end_step(train)
        self.ppo_policy.end_step(train)
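
MultiPolicy's key trick is that it feeds PPO an observation extended with a one-hot encoding of the action suggested by the deadlock-avoidance heuristic, which is why its PPO policy is built with state_size + action_size inputs. The following standalone sketch shows the same extension without the element-wise append loop; it is illustrative only, and the helper name extend_state is not part of the repository.

import numpy as np

def extend_state(state, suggested_action, action_size):
    # Append a one-hot encoding of the heuristic's suggested action to the observation,
    # equivalent to the append loop in MultiPolicy.step() and MultiPolicy.act().
    one_hot = (np.arange(action_size) == suggested_action).astype(float)
    return np.append(np.copy(state), one_hot)

extended = extend_state(np.arange(5, dtype=float), suggested_action=2, action_size=5)
print(extended.shape)  # (10,) == state_size + action_size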
Code Example #5
def train_agent(train_params, train_env_params, eval_env_params, obs_params):
    # Environment parameters
    n_agents = train_env_params.n_agents
    x_dim = train_env_params.x_dim
    y_dim = train_env_params.y_dim
    n_cities = train_env_params.n_cities
    max_rails_between_cities = train_env_params.max_rails_between_cities
    max_rails_in_city = train_env_params.max_rails_in_city
    seed = train_env_params.seed

    # Unique ID for this training
    now = datetime.now()
    training_id = now.strftime('%y%m%d%H%M%S')

    # Observation parameters
    observation_tree_depth = obs_params.observation_tree_depth
    observation_radius = obs_params.observation_radius
    observation_max_path_depth = obs_params.observation_max_path_depth

    # Training parameters
    eps_start = train_params.eps_start
    eps_end = train_params.eps_end
    eps_decay = train_params.eps_decay
    n_episodes = train_params.n_episodes
    checkpoint_interval = train_params.checkpoint_interval
    n_eval_episodes = train_params.n_evaluation_episodes
    restore_replay_buffer = train_params.restore_replay_buffer
    save_replay_buffer = train_params.save_replay_buffer

    # Set the seeds
    random.seed(seed)
    np.random.seed(seed)

    # Observation builder
    predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
    if not train_params.use_fast_tree_observation:
        print("\nUsing standard TreeObs")

        def check_is_observation_valid(observation):
            return observation

        def get_normalized_observation(observation,
                                       tree_depth: int,
                                       observation_radius=0):
            return normalize_observation(observation, tree_depth,
                                         observation_radius)

        tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth,
                                             predictor=predictor)
        tree_observation.check_is_observation_valid = check_is_observation_valid
        tree_observation.get_normalized_observation = get_normalized_observation
    else:
        print("\nUsing FastTreeObs")

        def check_is_observation_valid(observation):
            return True

        def get_normalized_observation(observation,
                                       tree_depth: int,
                                       observation_radius=0):
            return observation

        tree_observation = FastTreeObs(max_depth=observation_tree_depth)
        tree_observation.check_is_observation_valid = check_is_observation_valid
        tree_observation.get_normalized_observation = get_normalized_observation

    # Setup the environments
    train_env = create_rail_env(train_env_params, tree_observation)
    train_env.reset(regenerate_schedule=True, regenerate_rail=True)
    eval_env = create_rail_env(eval_env_params, tree_observation)
    eval_env.reset(regenerate_schedule=True, regenerate_rail=True)

    if not train_params.use_fast_tree_observation:
        # Calculate the state size given the depth of the tree observation and the number of features
        n_features_per_node = train_env.obs_builder.observation_dim
        n_nodes = sum(
            [np.power(4, i) for i in range(observation_tree_depth + 1)])
        state_size = n_features_per_node * n_nodes
    else:
        # FastTreeObs exposes its observation size directly
        state_size = tree_observation.observation_dim

    action_count = [0] * get_flatland_full_action_size()
    action_dict = dict()
    agent_obs = [None] * n_agents
    agent_prev_obs = [None] * n_agents
    agent_prev_action = [2] * n_agents
    update_values = [False] * n_agents

    # Smoothed values used as target for hyperparameter tuning
    smoothed_eval_normalized_score = -1.0
    smoothed_eval_completion = 0.0

    scores_window = deque(
        maxlen=checkpoint_interval)  # todo smooth when rendering instead
    completion_window = deque(maxlen=checkpoint_interval)

    if train_params.action_size == "reduced":
        set_action_size_reduced()
    else:
        set_action_size_full()

    # Double Dueling DQN policy
    if train_params.policy == "DDDQN":
        policy = DDDQNPolicy(state_size, get_action_size(), train_params)
    elif train_params.policy == "PPO":
        policy = PPOPolicy(state_size,
                           get_action_size(),
                           use_replay_buffer=False,
                           in_parameters=train_params)
    elif train_params.policy == "DeadLockAvoidance":
        policy = DeadLockAvoidanceAgent(train_env,
                                        get_action_size(),
                                        enable_eps=False)
    elif train_params.policy == "DeadLockAvoidanceWithDecision":
        # inter_policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
        inter_policy = DDDQNPolicy(state_size, get_action_size(), train_params)
        policy = DeadLockAvoidanceWithDecisionAgent(train_env, state_size,
                                                    get_action_size(),
                                                    inter_policy)
    elif train_params.policy == "MultiDecision":
        policy = MultiDecisionAgent(state_size, get_action_size(),
                                    train_params)
    else:
        policy = PPOPolicy(state_size,
                           get_action_size(),
                           use_replay_buffer=False,
                           in_parameters=train_params)

    # make sure that at least one policy is set
    if policy is None:
        policy = DDDQNPolicy(state_size, get_action_size(), train_params)

    # Load existing policy
    if train_params.load_policy != "":
        policy.load(train_params.load_policy)

    # Loads existing replay buffer
    if restore_replay_buffer:
        try:
            policy.load_replay_buffer(restore_replay_buffer)
            policy.test()
        except RuntimeError as e:
            print(
                "\n🛑 Could't load replay buffer, were the experiences generated using the same tree depth?"
            )
            print(e)
            exit(1)

    print("\n💾 Replay buffer status: {}/{} experiences".format(
        len(policy.memory.memory), train_params.buffer_size))

    hdd = psutil.disk_usage('/')
    if save_replay_buffer and (hdd.free / (2**30)) < 500.0:
        print(
            "⚠️  Careful! Saving replay buffers will quickly consume a lot of disk space. You have {:.2f}gb left."
            .format(hdd.free / (2**30)))

    # TensorBoard writer
    writer = SummaryWriter(comment="_" + train_params.policy + "_" +
                           train_params.action_size)

    training_timer = Timer()
    training_timer.start()

    print(
        "\n🚉 Training {} trains on {}x{} grid for {} episodes, evaluating on {} episodes every {} episodes. Training id '{}'.\n"
        .format(train_env.get_num_agents(), x_dim, y_dim, n_episodes,
                n_eval_episodes, checkpoint_interval, training_id))

    for episode_idx in range(n_episodes + 1):
        step_timer = Timer()
        reset_timer = Timer()
        learn_timer = Timer()
        preproc_timer = Timer()
        inference_timer = Timer()

        # Reset environment
        reset_timer.start()
        if train_params.n_agent_fixed:
            number_of_agents = n_agents
            train_env_params.n_agents = n_agents
        else:
            number_of_agents = int(
                min(n_agents, 1 + np.floor(episode_idx / 200)))
            train_env_params.n_agents = episode_idx % number_of_agents + 1

        train_env = create_rail_env(train_env_params, tree_observation)
        obs, info = train_env.reset(regenerate_rail=True,
                                    regenerate_schedule=True)
        policy.reset(train_env)
        reset_timer.end()

        if train_params.render:
            # Setup renderer
            env_renderer = RenderTool(train_env, gl="PGL")
            env_renderer.set_new_rail()

        score = 0
        nb_steps = 0
        actions_taken = []

        # Build initial agent-specific observations
        for agent_handle in train_env.get_agent_handles():
            if tree_observation.check_is_observation_valid(obs[agent_handle]):
                agent_obs[
                    agent_handle] = tree_observation.get_normalized_observation(
                        obs[agent_handle],
                        observation_tree_depth,
                        observation_radius=observation_radius)
                agent_prev_obs[agent_handle] = agent_obs[agent_handle].copy()

        # Max number of steps per episode
        # This is the official formula used during evaluations
        # See details in flatland.envs.schedule_generators.sparse_schedule_generator
        # max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities)))
        max_steps = train_env._max_episode_steps

        # Run episode
        policy.start_episode(train=True)
        for step in range(max_steps - 1):
            inference_timer.start()
            policy.start_step(train=True)
            for agent_handle in train_env.get_agent_handles():
                agent = train_env.agents[agent_handle]
                if info['action_required'][agent_handle]:
                    update_values[agent_handle] = True
                    action = policy.act(agent_handle,
                                        agent_obs[agent_handle],
                                        eps=eps_start)
                    action_count[map_action(action)] += 1
                    actions_taken.append(map_action(action))
                else:
                    # An action is not required if the train hasn't joined the railway network,
                    # if it already reached its target, or if it is currently malfunctioning.
                    update_values[agent_handle] = False
                    action = 0
                action_dict.update({agent_handle: action})
            policy.end_step(train=True)
            inference_timer.end()

            # Environment step
            step_timer.start()
            next_obs, all_rewards, done, info = train_env.step(
                map_actions(action_dict))
            step_timer.end()

            # Render an episode at some interval
            if train_params.render:
                env_renderer.render_env(show=True,
                                        frames=False,
                                        show_observations=False,
                                        show_predictions=False)

            # Update replay buffer and train agent
            for agent_handle in train_env.get_agent_handles():
                if update_values[agent_handle] or done['__all__']:
                    # Only learn from timesteps where something happened
                    learn_timer.start()
                    policy.step(
                        agent_handle, agent_prev_obs[agent_handle],
                        map_action_policy(agent_prev_action[agent_handle]),
                        all_rewards[agent_handle], agent_obs[agent_handle],
                        done[agent_handle])
                    learn_timer.end()

                    agent_prev_obs[agent_handle] = agent_obs[
                        agent_handle].copy()
                    agent_prev_action[agent_handle] = action_dict[agent_handle]

                # Preprocess the new observations
                if tree_observation.check_is_observation_valid(
                        next_obs[agent_handle]):
                    preproc_timer.start()
                    agent_obs[
                        agent_handle] = tree_observation.get_normalized_observation(
                            next_obs[agent_handle],
                            observation_tree_depth,
                            observation_radius=observation_radius)
                    preproc_timer.end()

                score += all_rewards[agent_handle]

            nb_steps = step

            if done['__all__']:
                break

        policy.end_episode(train=True)
        # Epsilon decay
        eps_start = max(eps_end, eps_decay * eps_start)

        # Collect information about training
        tasks_finished = sum(done[idx]
                             for idx in train_env.get_agent_handles())
        completion = tasks_finished / max(1, train_env.get_num_agents())
        normalized_score = score / (max_steps * train_env.get_num_agents())
        action_probs = action_count / max(1, np.sum(action_count))

        scores_window.append(normalized_score)
        completion_window.append(completion)
        smoothed_normalized_score = np.mean(scores_window)
        smoothed_completion = np.mean(completion_window)

        if train_params.render:
            env_renderer.close_window()

        # Print logs
        if episode_idx % checkpoint_interval == 0 and episode_idx > 0:
            policy.save('./checkpoints/' + training_id + '-' +
                        str(episode_idx) + '.pth')

            if save_replay_buffer:
                policy.save_replay_buffer('./replay_buffers/' + training_id +
                                          '-' + str(episode_idx) + '.pkl')

            # reset action count
            action_count = [0] * get_flatland_full_action_size()

        print('\r🚂 Episode {}'
              '\t 🚉 nAgents {:2}/{:2}'
              ' 🏆 Score: {:7.3f}'
              ' Avg: {:7.3f}'
              '\t 💯 Done: {:6.2f}%'
              ' Avg: {:6.2f}%'
              '\t 🎲 Epsilon: {:.3f} '
              '\t 🔀 Action Probs: {}'.format(
                  episode_idx, train_env_params.n_agents, number_of_agents,
                  normalized_score, smoothed_normalized_score,
                  100 * completion, 100 * smoothed_completion, eps_start,
                  format_action_prob(action_probs)),
              end=" ")

        # Evaluate policy and log results at some interval
        if episode_idx % checkpoint_interval == 0 and n_eval_episodes > 0:
            scores, completions, nb_steps_eval = eval_policy(
                eval_env, tree_observation, policy, train_params, obs_params)

            writer.add_scalar("evaluation/scores_min", np.min(scores),
                              episode_idx)
            writer.add_scalar("evaluation/scores_max", np.max(scores),
                              episode_idx)
            writer.add_scalar("evaluation/scores_mean", np.mean(scores),
                              episode_idx)
            writer.add_scalar("evaluation/scores_std", np.std(scores),
                              episode_idx)
            writer.add_histogram("evaluation/scores", np.array(scores),
                                 episode_idx)
            writer.add_scalar("evaluation/completions_min",
                              np.min(completions), episode_idx)
            writer.add_scalar("evaluation/completions_max",
                              np.max(completions), episode_idx)
            writer.add_scalar("evaluation/completions_mean",
                              np.mean(completions), episode_idx)
            writer.add_scalar("evaluation/completions_std",
                              np.std(completions), episode_idx)
            writer.add_histogram("evaluation/completions",
                                 np.array(completions), episode_idx)
            writer.add_scalar("evaluation/nb_steps_min", np.min(nb_steps_eval),
                              episode_idx)
            writer.add_scalar("evaluation/nb_steps_max", np.max(nb_steps_eval),
                              episode_idx)
            writer.add_scalar("evaluation/nb_steps_mean",
                              np.mean(nb_steps_eval), episode_idx)
            writer.add_scalar("evaluation/nb_steps_std", np.std(nb_steps_eval),
                              episode_idx)
            writer.add_histogram("evaluation/nb_steps",
                                 np.array(nb_steps_eval), episode_idx)

            smoothing = 0.9
            smoothed_eval_normalized_score = smoothed_eval_normalized_score * smoothing + np.mean(
                scores) * (1.0 - smoothing)
            smoothed_eval_completion = smoothed_eval_completion * smoothing + np.mean(
                completions) * (1.0 - smoothing)
            writer.add_scalar("evaluation/smoothed_score",
                              smoothed_eval_normalized_score, episode_idx)
            writer.add_scalar("evaluation/smoothed_completion",
                              smoothed_eval_completion, episode_idx)

        # Save logs to tensorboard
        writer.add_scalar("training/score", normalized_score, episode_idx)
        writer.add_scalar("training/smoothed_score", smoothed_normalized_score,
                          episode_idx)
        writer.add_scalar("training/completion", np.mean(completion),
                          episode_idx)
        writer.add_scalar("training/smoothed_completion",
                          np.mean(smoothed_completion), episode_idx)
        writer.add_scalar("training/nb_steps", nb_steps, episode_idx)
        writer.add_scalar("training/n_agents", train_env_params.n_agents,
                          episode_idx)
        writer.add_histogram("actions/distribution", np.array(actions_taken),
                             episode_idx)
        writer.add_scalar("actions/nothing",
                          action_probs[RailEnvActions.DO_NOTHING], episode_idx)
        writer.add_scalar("actions/left",
                          action_probs[RailEnvActions.MOVE_LEFT], episode_idx)
        writer.add_scalar("actions/forward",
                          action_probs[RailEnvActions.MOVE_FORWARD],
                          episode_idx)
        writer.add_scalar("actions/right",
                          action_probs[RailEnvActions.MOVE_RIGHT], episode_idx)
        writer.add_scalar("actions/stop",
                          action_probs[RailEnvActions.STOP_MOVING],
                          episode_idx)
        writer.add_scalar("training/epsilon", eps_start, episode_idx)
        writer.add_scalar("training/buffer_size", len(policy.memory),
                          episode_idx)
        writer.add_scalar("training/loss", policy.loss, episode_idx)
        writer.add_scalar("timer/reset", reset_timer.get(), episode_idx)
        writer.add_scalar("timer/step", step_timer.get(), episode_idx)
        writer.add_scalar("timer/learn", learn_timer.get(), episode_idx)
        writer.add_scalar("timer/preproc", preproc_timer.get(), episode_idx)
        writer.add_scalar("timer/total", training_timer.get_current(),
                          episode_idx)
        writer.flush()
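
One detail of the training loop that is easy to miss is the agent-count curriculum: unless train_params.n_agent_fixed is set, the cap on active agents grows by one every 200 episodes, and the per-episode agent count cycles from 1 up to that cap. A small standalone sketch of the schedule (illustrative only; scheduled_n_agents is not a function in the repository):

import numpy as np

def scheduled_n_agents(episode_idx, n_agents_max):
    # Mirrors the curriculum in train_agent(): the cap grows by one every 200 episodes,
    # and the episode's agent count cycles from 1 up to the current cap.
    cap = int(min(n_agents_max, 1 + np.floor(episode_idx / 200)))
    return episode_idx % cap + 1

for episode_idx in (0, 199, 200, 401, 1000):
    print(episode_idx, scheduled_n_agents(episode_idx, n_agents_max=10))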
Code Example #6
from collections import deque

import gym
import numpy as np
from torch.utils.tensorboard import SummaryWriter

# Note: PPOPolicy, DDDQNPolicy and the dddqn_param hyperparameters are project-local
# and assumed to be defined or imported elsewhere in the same script.


def cartpole(use_dddqn=False):
    eps = 1.0
    eps_decay = 0.99
    min_eps = 0.01
    training_mode = True

    env = gym.make("CartPole-v1")
    observation_space = env.observation_space.shape[0]
    action_space = env.action_space.n
    if not use_dddqn:
        policy = PPOPolicy(observation_space, action_space, False)
    else:
        policy = DDDQNPolicy(observation_space, action_space, dddqn_param)
    episode = 0
    checkpoint_interval = 20
    scores_window = deque(maxlen=100)

    writer = SummaryWriter()

    while True:
        episode += 1
        state = env.reset()
        policy.reset(env)
        handle = 0
        tot_reward = 0

        policy.start_episode(train=training_mode)
        while True:
            # env.render()
            policy.start_step(train=training_mode)
            action = policy.act(handle, state, eps)
            state_next, reward, terminal, info = env.step(action)
            policy.end_step(train=training_mode)
            tot_reward += reward
            # reward = reward if not terminal else -reward
            reward = 0 if not terminal else -1
            policy.step(handle, state, action, reward, state_next, terminal)
            state = np.copy(state_next)
            if terminal:
                break

        policy.end_episode(train=training_mode)
        eps = max(min_eps, eps * eps_decay)
        scores_window.append(tot_reward)
        if episode % checkpoint_interval == 0:
            print(
                '\rEpisode: {:5}\treward: {:7.3f}\t avg: {:7.3f}\t eps: {:5.3f}\t replay buffer: {}'
                .format(episode, tot_reward, np.mean(scores_window), eps,
                        len(policy.memory)))
        else:
            print(
                '\rEpisode: {:5}\treward: {:7.3f}\t avg: {:7.3f}\t eps: {:5.3f}\t replay buffer: {}'
                .format(episode, tot_reward, np.mean(scores_window), eps,
                        len(policy.memory)),
                end=" ")

        writer.add_scalar("CartPole/value", tot_reward, episode)
        writer.add_scalar("CartPole/smoothed_value", np.mean(scores_window),
                          episode)
        writer.flush()
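
The CartPole loop above is a small sanity check for the policies outside of Flatland. Assuming the excerpt lives in a script where the policy classes and dddqn_param are available, a minimal entry point would be:

if __name__ == "__main__":
    # Train PPO on CartPole-v1; pass use_dddqn=True to exercise the DDDQN policy instead
    # (which additionally needs the dddqn_param hyperparameters).
    cartpole(use_dddqn=False)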
Code Example #7
class DeadLockAvoidanceWithDecisionAgent(HybridPolicy):
    def __init__(self, env: RailEnv, state_size, action_size, learning_agent):
        print(">> DeadLockAvoidanceWithDecisionAgent")
        super(DeadLockAvoidanceWithDecisionAgent, self).__init__()
        self.env = env
        self.state_size = state_size
        self.action_size = action_size
        self.learning_agent = learning_agent
        self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(
            self.env, action_size, False)
        self.policy_selector = PPOPolicy(state_size, 2)

        self.memory = self.learning_agent.memory
        self.loss = self.learning_agent.loss

    def step(self, handle, state, action, reward, next_state, done):
        select = self.policy_selector.act(handle, state, 0.0)
        self.policy_selector.step(handle, state, select, reward, next_state,
                                  done)
        self.dead_lock_avoidance_agent.step(handle, state, action, reward,
                                            next_state, done)
        self.learning_agent.step(handle, state, action, reward, next_state,
                                 done)
        self.loss = self.learning_agent.loss

    def act(self, handle, state, eps=0.):
        select = self.policy_selector.act(handle, state, eps)
        if select == 0:
            return self.learning_agent.act(handle, state, eps)
        return self.dead_lock_avoidance_agent.act(handle, state, -1.0)

    def save(self, filename):
        self.dead_lock_avoidance_agent.save(filename)
        self.learning_agent.save(filename)
        self.policy_selector.save(filename + '.selector')

    def load(self, filename):
        self.dead_lock_avoidance_agent.load(filename)
        self.learning_agent.load(filename)
        self.policy_selector.load(filename + '.selector')

    def start_step(self, train):
        self.dead_lock_avoidance_agent.start_step(train)
        self.learning_agent.start_step(train)
        self.policy_selector.start_step(train)

    def end_step(self, train):
        self.dead_lock_avoidance_agent.end_step(train)
        self.learning_agent.end_step(train)
        self.policy_selector.end_step(train)

    def start_episode(self, train):
        self.dead_lock_avoidance_agent.start_episode(train)
        self.learning_agent.start_episode(train)
        self.policy_selector.start_episode(train)

    def end_episode(self, train):
        self.dead_lock_avoidance_agent.end_episode(train)
        self.learning_agent.end_episode(train)
        self.policy_selector.end_episode(train)

    def load_replay_buffer(self, filename):
        self.dead_lock_avoidance_agent.load_replay_buffer(filename)
        self.learning_agent.load_replay_buffer(filename)
        self.policy_selector.load_replay_buffer(filename + ".selector")

    def test(self):
        self.dead_lock_avoidance_agent.test()
        self.learning_agent.test()
        self.policy_selector.test()

    def reset(self, env: RailEnv):
        self.env = env
        self.dead_lock_avoidance_agent.reset(env)
        self.learning_agent.reset(env)
        self.policy_selector.reset(env)

    def clone(self):
        return self
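
DeadLockAvoidanceWithDecisionAgent is a hybrid: a binary PPO selector decides, per agent and per step, whether to follow the learning agent (select == 0) or the deadlock-avoidance heuristic, and clone() returns self rather than a copy. The sketch below is a hedged construction example that mirrors the "DeadLockAvoidanceWithDecision" branch of train_agent() in Code Example #5; state_size, train_params, train_env and agent_obs are assumed to be prepared as in that excerpt.

inter_policy = DDDQNPolicy(state_size, get_action_size(), train_params)
policy = DeadLockAvoidanceWithDecisionAgent(train_env, state_size,
                                            get_action_size(), inter_policy)
policy.reset(train_env)
action = policy.act(handle=0, state=agent_obs[0], eps=0.1)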
Code Example #8
    local_env = remote_client.env
    nb_agents = len(local_env.agents)
    max_nb_steps = local_env._max_episode_steps

    tree_observation.set_env(local_env)
    tree_observation.reset()

    # Creates the policy. No GPU on evaluation server.
    if load_policy == "DDDQN":
        policy = DDDQNPolicy(state_size,
                             get_action_size(),
                             Namespace(**{'use_gpu': False}),
                             evaluation_mode=True)
    elif load_policy == "PPO":
        policy = PPOPolicy(state_size, get_action_size())
    elif load_policy == "DeadLockAvoidance":
        policy = DeadLockAvoidanceAgent(local_env,
                                        get_action_size(),
                                        enable_eps=False)
    elif load_policy == "DeadLockAvoidanceWithDecision":
        # inter_policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
        inter_policy = DDDQNPolicy(state_size,
                                   get_action_size(),
                                   Namespace(**{'use_gpu': False}),
                                   evaluation_mode=True)
        policy = DeadLockAvoidanceWithDecisionAgent(local_env, state_size,
                                                    get_action_size(),
                                                    inter_policy)
    elif load_policy == "MultiDecision":
        policy = MultiDecisionAgent(state_size, get_action_size(),
Code Example #9
class MultiDecisionAgent(LearningPolicy):
    def __init__(self, state_size, action_size, in_parameters=None):
        print(">> MultiDecisionAgent")
        super(MultiDecisionAgent, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.in_parameters = in_parameters
        self.memory = DummyMemory()
        self.loss = 0

        self.ppo_policy = PPOPolicy(state_size,
                                    action_size,
                                    use_replay_buffer=False,
                                    in_parameters=in_parameters)
        self.dddqn_policy = DDDQNPolicy(state_size, action_size, in_parameters)
        self.policy_selector = PPOPolicy(state_size, 2)

    def step(self, handle, state, action, reward, next_state, done):
        self.ppo_policy.step(handle, state, action, reward, next_state, done)
        self.dddqn_policy.step(handle, state, action, reward, next_state, done)
        select = self.policy_selector.act(handle, state, 0.0)
        self.policy_selector.step(handle, state, select, reward, next_state,
                                  done)

    def act(self, handle, state, eps=0.):
        select = self.policy_selector.act(handle, state, eps)
        if select == 0:
            return self.dddqn_policy.act(handle, state, eps)
        # A non-zero selection delegates to the PPO policy; re-querying the binary
        # selector here would return 0/1 instead of a rail action.
        return self.ppo_policy.act(handle, state, eps)

    def save(self, filename):
        self.ppo_policy.save(filename)
        self.dddqn_policy.save(filename)
        self.policy_selector.save(filename)

    def load(self, filename):
        self.ppo_policy.load(filename)
        self.dddqn_policy.load(filename)
        self.policy_selector.load(filename)

    def start_step(self, train):
        self.ppo_policy.start_step(train)
        self.dddqn_policy.start_step(train)
        self.policy_selector.start_step(train)

    def end_step(self, train):
        self.ppo_policy.end_step(train)
        self.dddqn_policy.end_step(train)
        self.policy_selector.end_step(train)

    def start_episode(self, train):
        self.ppo_policy.start_episode(train)
        self.dddqn_policy.start_episode(train)
        self.policy_selector.start_episode(train)

    def end_episode(self, train):
        self.ppo_policy.end_episode(train)
        self.dddqn_policy.end_episode(train)
        self.policy_selector.end_episode(train)

    def load_replay_buffer(self, filename):
        self.ppo_policy.load_replay_buffer(filename)
        self.dddqn_policy.load_replay_buffer(filename)
        self.policy_selector.load_replay_buffer(filename)

    def test(self):
        self.ppo_policy.test()
        self.dddqn_policy.test()
        self.policy_selector.test()

    def reset(self, env: RailEnv):
        self.ppo_policy.reset(env)
        self.dddqn_policy.reset(env)
        self.policy_selector.reset(env)

    def clone(self):
        multi_decision_agent = MultiDecisionAgent(self.state_size,
                                                  self.action_size,
                                                  self.in_parameters)
        multi_decision_agent.ppo_policy = self.ppo_policy.clone()
        multi_decision_agent.dddqn_policy = self.dddqn_policy.clone()
        multi_decision_agent.policy_selector = self.policy_selector.clone()
        return multi_decision_agent
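
MultiDecisionAgent applies the same idea one level up: a binary PPO selector arbitrates between a DDDQN policy (select == 0) and a PPO policy, and clone() returns a new agent with cloned sub-policies. A hedged construction sketch following the "MultiDecision" branch of train_agent() in Code Example #5 (state_size, train_params, train_env and agent_obs as prepared in that excerpt):

policy = MultiDecisionAgent(state_size, get_action_size(), train_params)
policy.reset(train_env)
action = policy.act(handle=0, state=agent_obs[0], eps=0.1)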