class RLlibTorchGRUMultiPolicy(AgentPolicy):
    def __init__(self, load_path, algorithm, policy_name, observation_space, action_space):
        self._checkpoint_path = load_path
        self._policy_name = policy_name
        self._observation_space = observation_space
        self._action_space = action_space

        # Flatten observations the same way they were flattened during training.
        self._prep = ModelCatalog.get_preprocessor_for_space(self._observation_space)
        flat_obs_space = self._prep.observation_space

        ray.init(ignore_reinit_error=True, local_mode=True)

        from utils.ppo_policy import PPOTorchPolicy as LoadPolicy
        from utils.rnn_model import RNNMultiModel

        ModelCatalog.register_custom_model("my_rnn", RNNMultiModel)

        config = ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG.copy()
        config["vf_share_layers"] = True
        config["num_workers"] = 0
        config["model"]["custom_model"] = "my_rnn"
        config["model"]["free_log_std"] = False

        self.policy = LoadPolicy(flat_obs_space, self._action_space, config)

        # Restore the trained weights and observation filter from the RLlib checkpoint.
        with open(self._checkpoint_path, "rb") as f:
            objs = pickle.load(f)
        objs = pickle.loads(objs["worker"])
        state = objs["state"]
        filters = objs["filters"]
        self.filters = filters[self._policy_name]
        weights = state[self._policy_name]
        weights.pop("_optimizer_variables")
        self.policy.set_weights(weights)

        self.model = self.policy.model
        # Initial GRU hidden state, reshaped to a batch of one.
        self.rnn_state = self.model.get_initial_state()
        self.rnn_state = [torch.reshape(self.rnn_state[0], shape=(1, -1))]

    def act(self, obs):
        # Single-observation inference: preprocess, filter, then query the policy,
        # threading the recurrent state through successive calls.
        obs = self._prep.transform(obs)
        obs = self.filters(obs, update=False)
        action, self.rnn_state, _ = self.policy.compute_actions(
            [obs], self.rnn_state, explore=False
        )
        return action[0]
class RLlibTorchFCDyskipPolicy(AgentPolicy):
    def __init__(self, load_path, algorithm, policy_name, observation_space, action_space):
        self._checkpoint_path = load_path
        self._policy_name = policy_name
        self._observation_space = observation_space
        self._action_space = action_space

        # Flatten observations the same way they were flattened during training.
        self._prep = ModelCatalog.get_preprocessor_for_space(self._observation_space)
        flat_obs_space = self._prep.observation_space

        ray.init(ignore_reinit_error=True, local_mode=True)

        from utils.ppo_policy import PPOTorchPolicy as LoadPolicy
        from utils.fc_model import FullyConnectedNetwork

        ModelCatalog.register_custom_model("my_fc", FullyConnectedNetwork)
        ModelCatalog.register_custom_action_dist("my_dist", TorchDyDistribution)

        config = ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG.copy()
        config["num_workers"] = 0
        config["model"]["custom_model"] = "my_fc"
        config["model"]["free_log_std"] = False
        config["model"]["custom_action_dist"] = "my_dist"

        self.policy = LoadPolicy(flat_obs_space, self._action_space, config)

        # Restore the trained weights and observation filter from the RLlib checkpoint.
        with open(self._checkpoint_path, "rb") as f:
            objs = pickle.load(f)
        objs = pickle.loads(objs["worker"])
        state = objs["state"]
        filters = objs["filters"]
        self.filters = filters[self._policy_name]
        weights = state[self._policy_name]
        weights.pop("_optimizer_variables")
        self.policy.set_weights(weights)

        self.model = self.policy.model

    def act(self, obs):
        # Single-observation inference: preprocess, filter, then query the
        # (stateless) feed-forward policy.
        obs = self._prep.transform(obs)
        obs = self.filters(obs, update=False)
        action, _, _ = self.policy.compute_actions([obs], explore=False)
        return action[0]
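
# --- Hypothetical usage sketch (not part of the original file) -----------------
# The checkpoint path, gym spaces, and policy id below are placeholders and must
# match the training run that produced the checkpoint. RLlibTorchFCDyskipPolicy
# is driven the same way, minus the recurrent-state handling.
import gym


def rollout_episode(env, policy):
    # Re-initialise the GRU hidden state so each episode starts from the model's
    # initial recurrent state (mirrors what __init__ does).
    state = policy.model.get_initial_state()
    policy.rnn_state = [torch.reshape(state[0], shape=(1, -1))]

    obs, done, total_reward = env.reset(), False, 0.0
    while not done:
        action = policy.act(obs)  # deterministic (explore=False) action
        obs, reward, done, _ = env.step(action)
        total_reward += reward
    return total_reward


if __name__ == "__main__":
    observation_space = gym.spaces.Box(low=-1e3, high=1e3, shape=(38,))  # assumed shape
    action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(3,))        # assumed shape

    policy = RLlibTorchGRUMultiPolicy(
        load_path="checkpoints/checkpoint-100",  # placeholder checkpoint
        algorithm="PPO",
        policy_name="default_policy",            # placeholder policy id
        observation_space=observation_space,
        action_space=action_space,
    )
    # `make_env` stands in for however the evaluation environment is constructed:
    # rollout_episode(make_env(), policy)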