class RLlibTFA2Policy(AgentPolicy):
    def __init__(
        self, load_path, algorithm, policy_name, observation_space, action_space
    ):
        self._checkpoint_path = load_path
        self._algorithm = algorithm
        self._policy_name = policy_name
        self._observation_space = observation_space
        self._action_space = action_space
        self._sess = None

        if isinstance(action_space, gym.spaces.Box):
            self.is_continuous = True
        elif isinstance(action_space, gym.spaces.Discrete):
            self.is_continuous = False
        else:
            raise TypeError("Unsupported action space")

        if self._sess:
            return

        if self._algorithm == "PPO":
            from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy as LoadPolicy
        elif self._algorithm in ["A2C", "A3C"]:
            from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy as LoadPolicy
        elif self._algorithm == "PG":
            from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy as LoadPolicy
        elif self._algorithm == "DQN":
            from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy as LoadPolicy
        else:
            raise TypeError("Unsupported algorithm")

        self._prep = ModelCatalog.get_preprocessor_for_space(self._observation_space)
        self._sess = tf.Session(graph=tf.Graph())
        self._sess.__enter__()

        with tf.name_scope(self._policy_name):
            # The observation space needs to be flattened before it is passed
            # to the policy (e.g. PPOTFPolicy).
            flat_obs_space = self._prep.observation_space
            self.policy = LoadPolicy(flat_obs_space, self._action_space, {})
            with open(self._checkpoint_path, "rb") as f:
                objs = pickle.load(f)
            # The checkpoint stores the worker state as a nested pickle.
            objs = pickle.loads(objs["worker"])
            state = objs["state"]
            weights = state[self._policy_name]
            self.policy.set_weights(weights)

    def act(self, obs):
        if isinstance(obs, list):
            # Batched inference.
            obs = [self._prep.transform(o) for o in obs]
            action = self.policy.compute_actions(obs, explore=False)[0]
        else:
            # Single-observation inference.
            obs = self._prep.transform(obs)
            action = self.policy.compute_actions([obs], explore=False)[0][0]
        return action
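# Usage sketch (illustrative, not from the original source): shows how
# RLlibTFA2Policy above might be constructed from a pickled RLlib checkpoint
# and used for single and batched inference. The spaces, checkpoint path, and
# policy name "default_policy" are hypothetical placeholders; substitute the
# values used during training.
def _example_tfa2_usage():
    example_obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(8,))
    example_act_space = gym.spaces.Discrete(4)
    policy = RLlibTFA2Policy(
        load_path="path/to/checkpoint/checkpoint-100",  # hypothetical path
        algorithm="PPO",
        policy_name="default_policy",  # hypothetical policy id
        observation_space=example_obs_space,
        action_space=example_act_space,
    )
    single_action = policy.act(example_obs_space.sample())
    batch_actions = policy.act([example_obs_space.sample() for _ in range(4)])
    return single_action, batch_actions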
class RLAgent(Agent):
    def __init__(self, load_path, policy_name, observation_space, action_space):
        self._checkpoint_path = load_path
        self._policy_name = policy_name
        self._observation_space = observation_space
        self._action_space = action_space
        self._sess = None

        if isinstance(action_space, gym.spaces.Box):
            self.is_continuous = True
        elif isinstance(action_space, gym.spaces.Discrete):
            self.is_continuous = False
        else:
            raise TypeError("Unsupported action space")

        if self._sess:
            return

        self._prep = ModelCatalog.get_preprocessor_for_space(self._observation_space)
        self._sess = tf.compat.v1.Session(graph=tf.Graph())
        self._sess.__enter__()

        with tf.name_scope(self._policy_name):
            # The observation space needs to be flattened before it is passed
            # to the policy.
            flat_obs_space = self._prep.observation_space
            # `LoadPolicy` is assumed to be imported at module level (e.g.
            # PPOTFPolicy from ray.rllib.agents.ppo.ppo_tf_policy); unlike
            # RLlibTFA2Policy, this class does not select it per algorithm.
            self.policy = LoadPolicy(flat_obs_space, self._action_space, {})
            with open(self._checkpoint_path, "rb") as f:
                objs = pickle.load(f)
            objs = pickle.loads(objs["worker"])
            state = objs["state"]
            weights = state[self._policy_name]
            self.policy.set_weights(weights)

    def act(self, obs):
        if isinstance(obs, list):
            # Batched inference.
            obs = [self._prep.transform(o) for o in obs]
            action = self.policy.compute_actions(obs, explore=False)[0]
        else:
            # Single-observation inference.
            obs = self._prep.transform(obs)
            action = self.policy.compute_actions([obs], explore=False)[0][0]
        return action
class RLlibTFCheckpointPolicy(AgentPolicy):
    def __init__(
        self, load_path, algorithm, policy_name, observation_space, action_space
    ):
        self._checkpoint_path = load_path
        self._algorithm = algorithm
        self._policy_name = policy_name
        self._observation_space = observation_space
        self._action_space = action_space
        self._sess = None

        if isinstance(action_space, gym.spaces.Box):
            self.is_continuous = True
        elif isinstance(action_space, gym.spaces.Discrete):
            self.is_continuous = False
        else:
            raise TypeError("Unsupported action space")

        if self._sess:
            return

        if self._algorithm == "PPO":
            from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy as LoadPolicy
        elif self._algorithm in ["A2C", "A3C"]:
            from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy as LoadPolicy
        elif self._algorithm == "PG":
            from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy as LoadPolicy
        elif self._algorithm == "DQN":
            from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy as LoadPolicy
        else:
            raise TypeError("Unsupported algorithm")

        self._prep = ModelCatalog.get_preprocessor_for_space(self._observation_space)
        self._sess = tf.Session(graph=tf.Graph())
        self._sess.__enter__()

        import ray.rllib.agents.ppo as ppo

        config = ppo.DEFAULT_CONFIG.copy()
        config["num_workers"] = 0
        config["model"]["use_lstm"] = True

        with tf.name_scope(self._policy_name):
            # The observation space needs to be flattened before it is passed
            # to the policy (e.g. PPOTFPolicy).
            flat_obs_space = self._prep.observation_space
            self.policy = LoadPolicy(flat_obs_space, self._action_space, config)
            with open(self._checkpoint_path, "rb") as f:
                objs = pickle.load(f)
            objs = pickle.loads(objs["worker"])
            state = objs["state"]
            filters = objs["filters"]
            self.filters = filters[self._policy_name]
            weights = state[self._policy_name]
            self.policy.set_weights(weights)
            self.model = self.policy.model
            # print(self.model.summary())

            # Initialize the LSTM hidden state and wrap it into the
            # [[h], [c]] batch shape expected by compute_actions.
            self.rnn_state = self.model.get_initial_state()
            self.rnn_state = [[self.rnn_state[0]], [self.rnn_state[1]]]

    def act(self, obs):
        if isinstance(obs, list):
            # Batched inference.
            obs = [self._prep.transform(o) for o in obs]
            obs = self.filters(obs, update=False)
            action = self.policy.compute_actions(obs, explore=False)[0]
        else:
            # Single-observation inference; the recurrent state is carried
            # across calls.
            obs = self._prep.transform(obs)
            obs = self.filters(obs, update=False)
            action, self.rnn_state, _ = self.policy.compute_actions(
                [obs], self.rnn_state, explore=False
            )
            action = action[0]
        return action
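# Usage sketch (illustrative, not from the original source): because
# RLlibTFCheckpointPolicy above is built with config["model"]["use_lstm"] = True,
# it carries its recurrent state in self.rnn_state across successive act()
# calls, so observations should be fed one step at a time in episode order.
# The spaces, checkpoint path, and policy name are hypothetical placeholders.
def _example_lstm_checkpoint_usage():
    example_obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(8,))
    example_act_space = gym.spaces.Discrete(4)
    policy = RLlibTFCheckpointPolicy(
        load_path="path/to/checkpoint/checkpoint-100",  # hypothetical path
        algorithm="PPO",
        policy_name="default_policy",  # hypothetical policy id
        observation_space=example_obs_space,
        action_space=example_act_space,
    )
    # Step through a short rollout; each call updates policy.rnn_state internally.
    actions = [policy.act(example_obs_space.sample()) for _ in range(5)]
    return actions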