Example 1
def ml_reset(port):
    global client_mc
    host_ip = "http://localhost:" + port
    print("Connected to policy server: ", host_ip)
    client_mc = PolicyClient(host_ip,
                             inference_mode='local',
                             update_interval=None)
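# Hypothetical usage sketch (assumes a policy server is already listening on
# the given local port):
#   ml_reset("27800")
#   eid = client_mc.start_episode(training_enabled=True)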
Example 2
    def __init__(self, agent_shared_state,
                 agent_model_path, interfaces, domain):
        super().__init__(agent_shared_state, agent_model_path, interfaces, domain)

        self.interfaces = interfaces

        if "rl_config" not in agent_shared_state:
            if agent_model_path:
                with open(agent_model_path, "r") as f:
                    agent_shared_state["rl_config"] = json.load(f)
            else:
                agent_shared_state["rl_config"] = DEFAULT_CONFIG
            agent_shared_state["action_space_partitions"] = action_space_partitions(
                    domain, agent_shared_state["rl_config"]["state_space"]["max_variables"], True)
            custom_model = agent_shared_state["rl_config"].get("custom_model", None)
            if custom_model:
                register_custom_model(custom_model)
            assert domain.name == agent_shared_state["rl_config"]["domain"], \
                "Domain mismatch %s != %s" % (domain.name, agent_shared_state["rl_config"]["domain"])

        config = agent_shared_state["rl_config"]
        self.state = DialogState(
                config["state_space"]["max_steps"],
                config["state_space"]["max_utterances"],
                config["state_space"]["max_variables"],
                config["state_space"]["embed_dim"],
                config["state_space"]["max_command_length"],
                config["state_space"]["max_conversation_length"],
                # get the number of actions for the user
                len(action_space_partitions(
                        domain,
                        agent_shared_state["rl_config"]["state_space"]["max_variables"], False)),
                len(agent_shared_state["action_space_partitions"]))

        self.policy_client = PolicyClient(
                "http://%s:%d" % (config["policy_server_host"],
                                  int(config["policy_server_port"])),
                inference_mode=config["inference_mode"])
Example 3
parser.add_argument(
    "--inference-mode", type=str, required=True, choices=["local", "remote"])
parser.add_argument(
    "--off-policy",
    action="store_true",
    help="Whether to take random instead of on-policy actions.")
parser.add_argument(
    "--stop-reward",
    type=int,
    default=9999,
    help="Stop once the specified reward is reached.")

if __name__ == "__main__":
    args = parser.parse_args()
    env = gym.make("CartPole-v0")
    client = PolicyClient(
        "http://localhost:9900", inference_mode=args.inference_mode)

    eid = client.start_episode(training_enabled=not args.no_train)
    obs = env.reset()
    rewards = 0

    while True:
        if args.off_policy:
            action = env.action_space.sample()
            client.log_action(eid, obs, action)
        else:
            action = client.get_action(eid, obs)
        obs, reward, done, info = env.step(action)
        rewards += reward
        client.log_returns(eid, reward, info=info)
        if done:
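            # NOTE: the excerpt ends here; a hypothetical continuation, following
            # the episode-reset pattern shown in Example 5, might look like this:
            print("Total reward:", rewards)
            if rewards >= args.stop_reward:
                print("Target reward achieved, exiting.")
                exit(0)
            rewards = 0
            client.end_episode(eid, obs)
            obs = env.reset()
            eid = client.start_episode(training_enabled=not args.no_train)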
Example 4
)
parser.add_argument(
    "--stop-reward",
    type=float,
    default=9999,
    help="Stop once the specified reward is reached.",
)

if __name__ == "__main__":
    args = parser.parse_args()

    # Start the client for sending environment information (e.g. observations,
    # actions) to a policy server (listening on port 9900).
    client = PolicyClient(
        "http://" + args.server + ":" + str(args.port),
        inference_mode=args.inference_mode,
        update_interval=args.update_interval_local_mode,
    )

    # Start and reset the actual Unity3DEnv (either already running Unity3D
    # editor or a binary (game) to be started automatically).
    env = Unity3DEnv(file_name=args.game, episode_horizon=args.horizon)
    obs = env.reset()
    eid = client.start_episode(training_enabled=not args.no_train)

    # Keep track of the total reward per episode.
    total_rewards_this_episode = 0.0

    # Loop infinitely through the env.
    while True:
        # Get actions from the Policy server given our current obs.
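        # NOTE: hypothetical continuation; Unity3DEnv is a multi-agent env, so
        # observations, rewards, dones and infos are dicts keyed by agent id.
        actions = client.get_action(eid, obs)
        # Apply the computed actions and log the rewards back to the server.
        obs, rewards, dones, infos = env.step(actions)
        total_rewards_this_episode += sum(rewards.values())
        client.log_returns(eid, rewards, infos)
        # Once all agents are done, end the episode and start a new one.
        if dones["__all__"]:
            print("Episode done: reward =", total_rewards_this_episode)
            if total_rewards_this_episode >= args.stop_reward:
                exit(0)
            total_rewards_this_episode = 0.0
            client.end_episode(eid, obs)
            obs = env.reset()
            eid = client.start_episode(training_enabled=not args.no_train)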
Example 5
def env_boost(ind=0):
    env = create_env()
    not_ready_flag = True
    while not_ready_flag:
        try:
            client = PolicyClient("http://{}:{}".format(
                SERVER_ADDRESS, SERVER_PORT + ind),
                                  inference_mode=FLAGS.inference_mode)

            eid = client.start_episode(training_enabled=True)
        except ConnectionError:
            print("Server not ready...")
        else:
            not_ready_flag = False
    obs = env.reset()
    rewards = 0

    while True:
        if str2bool(FLAGS.off_policy):
            action = env.action_space.sample()
            client.log_action(eid, obs, action)
        else:
            st = time.time()
            action = client.get_action(eid, obs)
            print("get action: ", eid, action)
            print('proc_time: {}'.format(time.time() - st))
        obs, reward, done, info = env.step(action)
        rewards += reward
        client.log_returns(eid, reward, info=info)
        print("log returns: ", eid, reward)
        if done:
            print("Total reward:", rewards)
            rewards = 0
            client.end_episode(eid, obs)
            obs = env.reset()
            eid = client.start_episode(training_enabled=not FLAGS.no_train)
Example 6
class RLAgent(GenericAgent):
    @overrides
    def __init__(self, agent_shared_state,
                 agent_model_path, interfaces, domain):
        super().__init__(agent_shared_state, agent_model_path, interfaces, domain)

        self.interfaces = interfaces

        if "rl_config" not in agent_shared_state:
            if agent_model_path:
                with open(agent_model_path, "r") as f:
                    agent_shared_state["rl_config"] = json.load(f)
            else:
                agent_shared_state["rl_config"] = DEFAULT_CONFIG
            agent_shared_state["action_space_partitions"] = action_space_partitions(
                    domain, agent_shared_state["rl_config"]["state_space"]["max_variables"], True)
            custom_model = agent_shared_state["rl_config"].get("custom_model", None)
            if custom_model:
                register_custom_model(custom_model)
            assert domain.name == agent_shared_state["rl_config"]["domain"], \
                "Domain mismatch %s != %s" % (domain.name, agent_shared_state["rl_config"]["domain"])

        config = agent_shared_state["rl_config"]
        self.state = DialogState(
                config["state_space"]["max_steps"],
                config["state_space"]["max_utterances"],
                config["state_space"]["max_variables"],
                config["state_space"]["embed_dim"],
                config["state_space"]["max_command_length"],
                config["state_space"]["max_conversation_length"],
                # get the number of actions for the user
                len(action_space_partitions(
                        domain,
                        agent_shared_state["rl_config"]["state_space"]["max_variables"], False)),
                len(agent_shared_state["action_space_partitions"]))

        self.policy_client = PolicyClient(
                "http://%s:%d" % (config["policy_server_host"],
                                  int(config["policy_server_port"])),
                inference_mode=config["inference_mode"])

    def send_reward_to_policy_server(self, reward):
        assert isinstance(reward, dict), "reward should be a dict but it's this: %s" % reward
        if not self.agent_shared_state["rl_config"]["multiagent"]:
            reward = reward["agent"]
        self.policy_client.log_returns(self.current_episode_id, reward)

    def make_observation(self):
        obs = self.state.make_driver_observation(self._get_valid_actions_mask())
        if not self.agent_shared_state["rl_config"]["state_space"]["include_steps_in_obs_space"] \
                and "steps" in obs:
            del obs["steps"]

        if self.agent_shared_state["rl_config"]["multiagent"]:
            obs = {"agent": obs}
        return obs

    def get_action(self):
        ''' Returns the action index for the driver '''
        action = self.policy_client.get_action(self.current_episode_id, self.make_observation())
        if self.agent_shared_state["rl_config"]["multiagent"]:
            assert len(action) == 1 and "agent" in action, "Bad action_dict: %s" % str(action)
            return action["agent"]
        else:
            return action

    @overrides
    def on_dialog_finished(self, completed, correct):
        '''
        Called at the end of a dialog. 
        If the dialog is restarted before completion, 'completed' is set to False.
        '''
        reward_function = self.agent_shared_state["rl_config"]["rewards"]
        if self.state.is_done(): # early termination
            agent_reward = reward_function["driver_max_steps"]
        else:
            agent_reward = reward_function["driver_correct_destination"] if correct \
                           else reward_function["driver_incorrect_destination"]
        
        user_reward = None
        if not isinstance(self.room.user, HumanUser):
            if self.state.is_done(): # early termination
                user_reward = reward_function["passenger_max_steps"]
            else:
                user_reward = reward_function["passenger_correct_destination"] if correct\
                             else reward_function["passenger_incorrect_destination"]

        # print("DEBUG: Dialog was finished with reward", agent_reward + user_reward)
        if user_reward is not None:
            joint_reward = {"agent": agent_reward, "user": user_reward}
        else:
            joint_reward = {"agent": agent_reward}
        self.send_reward_to_policy_server(joint_reward)
            
        if self.agent_shared_state["rl_config"]["multiagent"]:
            final_obs = {
                **self.make_observation(),
                **(self.room.user.make_observation() if not isinstance(self.room.user, HumanUser) else {})
            }
        else:
            final_obs = self.make_observation()
        self.policy_client.end_episode(
            self.current_episode_id, final_obs)

    @overrides
    def reset(self, initial_variables):
        ''' Actions to perform when a dialog starts. '''
        super().reset(initial_variables)
        self.state.reset()

        self.current_episode_id = self.policy_client.start_episode(
            training_enabled=self.agent_shared_state["rl_config"]["policy_client_training_enabled"])
        # print("DEBUG started episode", self.current_episode_id)

    def _get_valid_actions_mask(self):
        '''
        Return indicator mask of valid actions for driver

        TODO: use action partition indices instead of hard-coded ordering of the different action types.
        '''
        mask = [0] * len(self.agent_shared_state["action_space_partitions"])
        assert len(mask) == self.state.driver_max_actions, \
                "%d != %d" % (len(mask) == len(self.state.driver_max_actions))
        empty_command = len(self.state.driver_partial_command) == 0
        for index, action_type_and_sub_index in self.agent_shared_state["action_space_partitions"].items():
            action_type, sub_index = action_type_and_sub_index
            if empty_command and action_type in \
                    (ActionType.API, ActionType.END_DIALOG_API, ActionType.TEMPLATE): # TODO add ActionType.NO_ACTION
                mask[index] = 1
            elif not empty_command and action_type == ActionType.VARIABLE and \
                    sub_index < len(self.state.driver_variables):
                variable = self.state.driver_variables[sub_index]
                variable_name = variable.get('full_name', '')
                if self.state.driver_partial_command[0] in ('/maps/find_place', '/maps/places_nearby'):
                    if (len(self.state.driver_partial_command) == 1 # Query parameter, must click a user utterance
                            and variable_name.startswith('u')) or \
                       (len(self.state.driver_partial_command) == 2 and (variable_name.startswith('u') or 'latitude' in variable_name)) or \
                       (len(self.state.driver_partial_command) == 3 and 'longitude' in variable_name):
                        mask[index] = 1
                elif self.state.driver_partial_command[0] == '/maps/start_driving_no_map':
                    if (len(self.state.driver_partial_command) == 1 and 'latitude' in variable_name) or \
                       (len(self.state.driver_partial_command) == 2 and 'longitude' in variable_name):
                        mask[index] = 1
                else:
                    # Template variable
                    mask[index] = 1
        return mask

    def _step(self, action_index):
        '''
        Steps the agent one-click forward and executes apis/templates if a full command results from the click.
        Args: 
            action_index - index corresponding to the item-to-click (api/template/variable)
        Returns:
            (reward, message, ended_dialog) tuple.  Message is the agent's message to respond with; it may be None.
        '''
        raise NotImplementedError

    def on_message(self, message: str, events: dict) -> dict:
        agent_message = None
        while not agent_message:
            if self.state.is_done():
                self.room.manager.end_dialog(False, None, 'give_up', {})
                break
            cur_state = self.room.state()
            self.state.update_state(cur_state)
            action_index = self.get_action()
            reward, agent_message, ended_dialog = self._step(action_index)
            self.send_reward_to_policy_server(reward)
            if ended_dialog:
                break
        return agent_message
Example 7
apt_list = None
rwy_list = None
valid_hdg = None
cone_centre = None
lat_cone_list = None
lon_cone_list = None
# landed_count = 0
# crashed_count = 0
# semi_landed_count = 0
lat_eham = 52.18
lon_eham = 4.45
crash_count = None

host_ip = "http://localhost:27800"
print("Connected to policy server: ", host_ip)
client_mc = PolicyClient(host_ip, inference_mode='local', update_interval=None)

### Initialization function of your plugin. Do not change the name of this
### function, as it is the way BlueSky recognises this file as a plugin.

# Additional initialisation code


def init_plugin():
    # global client_mc
    # client_mc = PolicyClient("http://localhost:27802")

    # Configuration parameters
    config = {
        # The name of your plugin
        'plugin_name': 'MLCONTROLC',
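        # NOTE: hypothetical continuation based on BlueSky's standard plugin
        # template; the interval value and the update/preupdate/reset callbacks
        # are assumptions and would need to be defined elsewhere in this module.
        'plugin_type': 'sim',
        'update_interval': 1.0,
        'update': update,
        'preupdate': preupdate,
        'reset': reset,
    }

    # Older BlueSky plugin APIs also expect a (possibly empty) dict of stack
    # commands to be returned alongside the config.
    stackfunctions = {}
    return config, stackfunctions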
Example 8
                    default=9900,
                    help="The port to use (on localhost).")
parser.add_argument("--dummy-arg", type=str, default="")

if __name__ == "__main__":
    args = parser.parse_args()

    ray.init()

    # Use a CartPole-v0 env so this plays nicely with our cartpole server script.
    env = gym.make("CartPole-v0")

    # Note that the RolloutWorker that is generated inside the client (in case
    # of local inference) will contain only a RandomEnv dummy env to step through.
    # The actual env we care about is the above generated CartPole one.
    client = PolicyClient(f"http://localhost:{args.port}",
                          inference_mode=args.inference_mode)

    # Get a dummy obs
    dummy_obs = env.reset()
    dummy_reward = 1.3

    # Start an episode to only compute actions (do NOT record this episode's
    # trajectories in any returned SampleBatches sent to the server for learning).
    action_eid = client.start_episode(training_enabled=False)
    print(f"Starting action episode: {action_eid}.")
    # Get some actions using the action episode
    dummy_action = client.get_action(action_eid, dummy_obs)
    print(f"Computing action 1 in action episode: {dummy_action}.")
    dummy_action = client.get_action(action_eid, dummy_obs)
    print(f"Computing action 2 in action episode: {dummy_action}.")