Example #1
class RLlibTFA2FilterPolicy(AgentPolicy):
    def __init__(self, load_path, algorithm, policy_name, observation_space,
                 action_space):
        self._checkpoint_path = load_path
        self._algorithm = algorithm
        self._policy_name = policy_name
        self._observation_space = observation_space
        self._action_space = action_space
        self._sess = None

        if isinstance(action_space, gym.spaces.Box):
            self.is_continuous = True
        elif isinstance(action_space, gym.spaces.Discrete):
            self.is_continuous = False
        else:
            raise TypeError("Unsupport action space")

        if self._sess:
            return

        if self._algorithm == "PPO":
            from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy as LoadPolicy
        elif self._algorithm in ["A2C", "A3C"]:
            from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy as LoadPolicy
        elif self._algorithm == "PG":
            from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy as LoadPolicy
        elif self._algorithm == "DQN":
            from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy as LoadPolicy
        else:
            raise TypeError("Unsupport algorithm")

        self._prep = ModelCatalog.get_preprocessor_for_space(
            self._observation_space)
        self._sess = tf.Session(graph=tf.Graph())
        self._sess.__enter__()

        with tf.name_scope(self._policy_name):
            # The observation space needs to be flattened before being passed to PPOTFPolicy
            flat_obs_space = self._prep.observation_space
            config = ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG.copy()
            config['num_workers'] = 0
            config['model']['free_log_std'] = True
            self.policy = LoadPolicy(flat_obs_space, self._action_space,
                                     config)
            objs = pickle.load(open(self._checkpoint_path, "rb"))
            objs = pickle.loads(objs["worker"])
            state = objs["state"]
            filters = objs["filters"]
            self.filters = filters[self._policy_name]
            weights = state[self._policy_name]
            self.policy.set_weights(weights)

    def act(self, obs):

        # single infer
        obs = self._prep.transform(obs)
        obs = self.filters(obs, update=False)
        action = self.policy.compute_actions([obs], explore=False)[0][0]

        return action
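
For orientation, here is a minimal usage sketch of the class above; the checkpoint path, policy name, and space shapes are placeholders (not taken from the original snippet), and an actual RLlib PPO checkpoint with matching observation filters would be required for this to run.

import gym
import numpy as np

obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(32,), dtype=np.float32)
act_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

policy = RLlibTFA2FilterPolicy(
    load_path="checkpoints/checkpoint-100",  # hypothetical checkpoint path
    algorithm="PPO",
    policy_name="default_policy",
    observation_space=obs_space,
    action_space=act_space,
)
action = policy.act(obs_space.sample())  # deterministic single-step inference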
Example #2
    def __init__(
        self, load_path, algorithm, policy_name, observation_space, action_space
    ):
        load_path = str(load_path)
        if algorithm == "PPO":
            from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy as LoadPolicy
        elif algorithm in ["A2C", "A3C"]:
            from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy as LoadPolicy
        elif algorithm == "PG":
            from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy as LoadPolicy
        elif algorithm == "DQN":
            from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy as LoadPolicy
        else:
            raise ValueError(f"Unsupported algorithm: {algorithm}")

        self._prep = ModelCatalog.get_preprocessor_for_space(observation_space)
        self._sess = tf.compat.v1.Session(graph=tf.Graph())

        with tf.compat.v1.name_scope(policy_name):
            # The observation space needs to be flattened before being passed to PPOTFPolicy
            flat_obs_space = self._prep.observation_space
            policy = LoadPolicy(flat_obs_space, action_space, {})
            objs = pickle.load(open(load_path, "rb"))
            objs = pickle.loads(objs["worker"])
            state = objs["state"]
            weights = state[policy_name]
            policy.set_weights(weights)

        # These tensor names were found by inspecting the trained model
        if algorithm == "PPO":
            # CRUCIAL FOR SAFETY:
            #   We use Tensor("split") instead of Tensor("add") to force
            #   PPO to be deterministic.
            self._input_node = self._sess.graph.get_tensor_by_name(
                f"{policy_name}/observation:0"
            )
            self._output_node = self._sess.graph.get_tensor_by_name(
                f"{policy_name}/split:0"
            )
        elif self._algorithm == "DQN":
            self._input_node = self._sess.graph.get_tensor_by_name(
                f"{policy_name}/observations:0"
            )
            self._output_node = tf.argmax(
                input=self._sess.graph.get_tensor_by_name(
                    f"{policy_name}/value_out/BiasAdd:0"
                ),
                axis=1,
            )
        else:
            self._input_node = self._sess.graph.get_tensor_by_name(
                f"{policy_name}/observations:0"
            )
            self._output_node = tf.argmax(
                input=self._sess.graph.get_tensor_by_name(
                    f"{policy_name}/fc_out/BiasAdd:0"
                ),
                axis=1,
            )
Example #3
class RLlibTFCheckpointPolicy(AgentPolicy):
    def __init__(self, load_path, algorithm, policy_name, observation_space,
                 action_space):
        self._load_path = load_path
        self._algorithm = algorithm
        self._policy_name = policy_name
        self._observation_space = observation_space
        self._action_space = action_space
        self._sess = None

        if isinstance(action_space, gym.spaces.Box):
            self.is_continuous = True
        elif isinstance(action_space, gym.spaces.Discrete):
            self.is_continuous = False
        else:
            raise TypeError("Unsupport action space")

    def setup(self):
        if self._sess:
            return

        if self._algorithm == "PPO":
            from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy as LoadPolicy
        elif self._algorithm in ["A2C", "A3C"]:
            from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy as LoadPolicy
        elif self._algorithm == "PG":
            from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy as LoadPolicy
        elif self._algorithm == "DQN":
            from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy as LoadPolicy
        else:
            raise TypeError("Unsupport algorithm")

        self._prep = ModelCatalog.get_preprocessor_for_space(
            self._observation_space)
        self._sess = tf.Session(graph=tf.Graph())
        self._sess.__enter__()

        with tf.name_scope(self._policy_name):
            # The observation space needs to be flattened before being passed to PPOTFPolicy
            flat_obs_space = self._prep.observation_space
            self.policy = LoadPolicy(flat_obs_space, self._action_space, {})
            objs = pickle.load(open(self._load_path, "rb"))
            objs = pickle.loads(objs["worker"])
            state = objs["state"]
            weights = state[self._policy_name]
            self.policy.set_weights(weights)

    def teardown(self):
        # TODO: actually teardown the TF session
        pass

    def act(self, obs):
        obs = self._prep.transform(obs)
        action = self.policy.compute_actions([obs], explore=False)[0][0]

        return action
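
The class above defers all TensorFlow work to setup(), so a caller is expected to invoke setup() before act(). A minimal lifecycle sketch follows; the checkpoint path, policy name, and space shapes are placeholders rather than values from the original snippet.

import gym
import numpy as np

obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(32,), dtype=np.float32)
act_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

policy = RLlibTFCheckpointPolicy(
    load_path="checkpoints/checkpoint-100",  # hypothetical RLlib checkpoint
    algorithm="PPO",
    policy_name="default_policy",
    observation_space=obs_space,
    action_space=act_space,
)
policy.setup()                            # builds the TF session and restores the weights
action = policy.act(obs_space.sample())   # single observation in, action out
policy.teardown()                         # currently a no-op (see the TODO above)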
Example #4
    def __init__(self, load_path, policy_name, observation_space,
                 action_space):
        self._checkpoint_path = load_path
        self._policy_name = policy_name
        self._observation_space = observation_space
        self._action_space = action_space
        self._sess = None

        if isinstance(action_space, gym.spaces.Box):
            self.is_continuous = True
        elif isinstance(action_space, gym.spaces.Discrete):
            self.is_continuous = False
        else:
            raise TypeError("Unsupport action space")

        if self._sess:
            return

        self._prep = ModelCatalog.get_preprocessor_for_space(
            self._observation_space)
        self._sess = tf.compat.v1.Session(graph=tf.Graph())
        self._sess.__enter__()

        with tf.name_scope(self._policy_name):
            # The observation space needs to be flattened before being passed to PPOTFPolicy
            flat_obs_space = self._prep.observation_space
            self.policy = LoadPolicy(flat_obs_space, self._action_space, {})
            objs = pickle.load(open(self._checkpoint_path, "rb"))
            objs = pickle.loads(objs["worker"])
            state = objs["state"]
            weights = state[self._policy_name]
            self.policy.set_weights(weights)
Example #5
    def setup(self):
        if self._sess:
            return

        if self._algorithm == "PPO":
            from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy as LoadPolicy
        elif self._algorithm in ["A2C", "A3C"]:
            from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy as LoadPolicy
        elif self._algorithm == "PG":
            from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy as LoadPolicy
        elif self._algorithm == "DQN":
            from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy as LoadPolicy
        else:
            raise TypeError("Unsupport algorithm")

        self._prep = ModelCatalog.get_preprocessor_for_space(
            self._observation_space)
        self._sess = tf.Session(graph=tf.Graph())
        self._sess.__enter__()

        with tf.name_scope(self._policy_name):
            # The observation space needs to be flattened before being passed to PPOTFPolicy
            flat_obs_space = self._prep.observation_space
            self.policy = LoadPolicy(flat_obs_space, self._action_space, {})
            objs = pickle.load(open(self._load_path, "rb"))
            objs = pickle.loads(objs["worker"])
            state = objs["state"]
            weights = state[self._policy_name]
            self.policy.set_weights(weights)
Example #6
class RLAgent(Agent):
    def __init__(self, load_path, policy_name, observation_space,
                 action_space):
        self._checkpoint_path = load_path
        self._policy_name = policy_name
        self._observation_space = observation_space
        self._action_space = action_space
        self._sess = None

        if isinstance(action_space, gym.spaces.Box):
            self.is_continuous = True
        elif isinstance(action_space, gym.spaces.Discrete):
            self.is_continuous = False
        else:
            raise TypeError("Unsupport action space")

        if self._sess:
            return

        self._prep = ModelCatalog.get_preprocessor_for_space(
            self._observation_space)
        self._sess = tf.compat.v1.Session(graph=tf.Graph())
        self._sess.__enter__()

        with tf.compat.v1.name_scope(self._policy_name):
            # The observation space needs to be flattened before being passed to PPOTFPolicy
            flat_obs_space = self._prep.observation_space
            self.policy = LoadPolicy(flat_obs_space, self._action_space, {})
            objs = pickle.load(open(self._checkpoint_path, "rb"))
            objs = pickle.loads(objs["worker"])
            state = objs["state"]
            weights = state[self._policy_name]
            self.policy.set_weights(weights)

    def act(self, obs):
        if isinstance(obs, list):
            # batch infer
            obs = [self._prep.transform(o) for o in obs]
            action = self.policy.compute_actions(obs, explore=False)[0]
        else:
            # single infer
            obs = self._prep.transform(obs)
            action = self.policy.compute_actions([obs], explore=False)[0][0]

        return action
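
The act() method above accepts either a single observation or a list of observations. A hypothetical usage sketch follows; the spaces, checkpoint path, and policy name are placeholders, and LoadPolicy is assumed to be imported at module level as in the full source.

import gym
import numpy as np

obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(16,), dtype=np.float32)
act_space = gym.spaces.Discrete(4)

agent = RLAgent(
    load_path="checkpoints/checkpoint-50",  # hypothetical checkpoint path
    policy_name="default_policy",
    observation_space=obs_space,
    action_space=act_space,
)
batch_actions = agent.act([obs_space.sample() for _ in range(8)])  # batch inference
single_action = agent.act(obs_space.sample())                      # single inference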
Example #7
    def __init__(self, load_path, algorithm, policy_name, observation_space,
                 action_space):
        self._checkpoint_path = load_path
        self._algorithm = algorithm
        self._policy_name = policy_name
        self._observation_space = observation_space
        self._action_space = action_space
        self._sess = None

        if isinstance(action_space, gym.spaces.Box):
            self.is_continuous = True
        elif isinstance(action_space, gym.spaces.Discrete):
            self.is_continuous = False
        else:
            raise TypeError("Unsupport action space")

        if self._sess:
            return

        if self._algorithm == "PPO":
            from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy as LoadPolicy
        elif self._algorithm in ["A2C", "A3C"]:
            from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy as LoadPolicy
        elif self._algorithm == "PG":
            from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy as LoadPolicy
        elif self._algorithm == "DQN":
            from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy as LoadPolicy
        else:
            raise TypeError("Unsupport algorithm")

        self._prep = ModelCatalog.get_preprocessor_for_space(
            self._observation_space)
        self._sess = tf.Session(graph=tf.Graph())
        self._sess.__enter__()

        import ray.rllib.agents.ppo as ppo
        config = ppo.DEFAULT_CONFIG.copy()
        config['num_workers'] = 0
        config["model"]["use_lstm"] = True

        with tf.name_scope(self._policy_name):
            # The observation space needs to be flattened before being passed to PPOTFPolicy
            flat_obs_space = self._prep.observation_space
            self.policy = LoadPolicy(flat_obs_space, self._action_space,
                                     config)
            objs = pickle.load(open(self._checkpoint_path, "rb"))
            objs = pickle.loads(objs["worker"])
            state = objs["state"]
            filters = objs["filters"]
            self.filters = filters[self._policy_name]
            weights = state[self._policy_name]
            self.policy.set_weights(weights)

            self.model = self.policy.model
            # print(self.model.summary())
            self.rnn_state = self.model.get_initial_state()
            self.rnn_state = [[self.rnn_state[0]], [self.rnn_state[1]]]
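
The snippet above restores an LSTM policy and seeds self.rnn_state with the model's initial state, but does not show inference. A hypothetical act() method (not part of the original snippet) that threads the recurrent state through compute_actions() might look like this:

    def act(self, obs):
        # Preprocess and filter the observation the same way as during training.
        obs = self._prep.transform(obs)
        obs = self.filters(obs, update=False)
        # compute_actions returns (actions, state_outs, extra_fetches);
        # carry the returned RNN state into the next call.
        actions, state_outs, _ = self.policy.compute_actions(
            [obs], state_batches=self.rnn_state, explore=False
        )
        self.rnn_state = state_outs
        return actions[0]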
Example #8
    def __init__(self, load_path, algorithm, policy_name, yaml_path):
        load_path = str(load_path)
        if algorithm == "ppo":
            from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy as LoadPolicy
        elif algorithm in ("a2c", "a3c"):
            from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy as LoadPolicy
            from ray.rllib.agents.a3c import DEFAULT_CONFIG
        elif algorithm == "pg":
            from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy as LoadPolicy
        elif algorithm == "dqn":
            from ray.rllib.agents.dqn import DQNTFPolicy as LoadPolicy
        elif algorithm == "maac":
            from benchmark.agents.maac.tf_policy import CA2CTFPolicy as LoadPolicy
            from benchmark.agents.maac.tf_policy import DEFAULT_CONFIG
        elif algorithm == "maddpg":
            from benchmark.agents.maddpg.tf_policy import MADDPG2TFPolicy as LoadPolicy
            from benchmark.agents.maddpg.tf_policy import DEFAULT_CONFIG
        elif algorithm == "mfac":
            from benchmark.agents.mfac.tf_policy import MFACTFPolicy as LoadPolicy
            from benchmark.agents.mfac.tf_policy import DEFAULT_CONFIG
        elif algorithm == "networked_pg":
            from benchmark.agents.networked_pg.tf_policy import (
                NetworkedPG as LoadPolicy,
            )
            from benchmark.agents.networked_pg.tf_policy import (
                PG_DEFAULT_CONFIG as DEFAULT_CONFIG,
            )
        else:
            raise ValueError(f"Unsupported algorithm: {algorithm}")

        yaml_path = BASE_DIR / yaml_path
        load_path = BASE_DIR / f"log/results/run/{load_path}"

        config = load_config(yaml_path)
        observation_space = config["policy"][1]
        action_space = config["policy"][2]
        pconfig = DEFAULT_CONFIG

        pconfig["model"].update(config["policy"][-1].get("model", {}))
        pconfig["agent_id"] = policy_name

        self._prep = ModelCatalog.get_preprocessor_for_space(observation_space)
        self._sess = tf.Session(graph=tf.get_default_graph())

        with tf.name_scope(policy_name):
            # The observation space needs to be flattened before being passed to the policy
            flat_obs_space = self._prep.observation_space
            policy = LoadPolicy(flat_obs_space, action_space, pconfig)
            self._sess.run(tf.global_variables_initializer())
            objs = pickle.load(open(load_path, "rb"))
            objs = pickle.loads(objs["worker"])
            state = objs["state"]
            weights = state[policy_name]
            policy.set_weights(weights)

        # for op in tf.get_default_graph().get_operations():
        #     print(str(op.name))

        # These tensor names were found by inspecting the trained model
        if algorithm == "ppo":
            # CRUCIAL FOR SAFETY:
            #   We use Tensor("split") instead of Tensor("add") to force
            #   PPO to be deterministic.
            self._input_node = self._sess.graph.get_tensor_by_name(
                f"{policy_name}/observation:0"
            )
            self._output_node = self._sess.graph.get_tensor_by_name(
                f"{policy_name}/split:0"
            )
        elif algorithm == "dqn":
            self._input_node = self._sess.graph.get_tensor_by_name(
                f"{policy_name}/observations:0"
            )
            self._output_node = tf.argmax(
                self._sess.graph.get_tensor_by_name(
                    f"{policy_name}/value_out/BiasAdd:0"
                ),
                axis=1,
            )
        elif algorithm == "maac":
            self._input_node = self._sess.graph.get_tensor_by_name(
                f"{policy_name}/policy-inputs:0"
            )
            self._output_node = tf.argmax(
                self._sess.graph.get_tensor_by_name(
                    f"{policy_name}/logits_out/BiasAdd:0"
                ),
                axis=1,
            )
        elif algorithm == "maddpg":
            self._input_node = self._sess.graph.get_tensor_by_name(
                f"{policy_name}/obs_2:0"
            )
            self._output_node = tf.argmax(
                self._sess.graph.get_tensor_by_name(
                    f"{policy_name}/actor/AGENT_2_actor_RelaxedOneHotCategorical_1/sample/AGENT_2_actor_exp/forward/Exp:0"
                )
            )
        else:
            self._input_node = self._sess.graph.get_tensor_by_name(
                f"{policy_name}/observations:0"
            )
            self._output_node = tf.argmax(
                self._sess.graph.get_tensor_by_name(f"{policy_name}/fc_out/BiasAdd:0"),
                axis=1,
            )
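
The input/output tensors selected above can be fed to the session directly for deterministic inference. Below is a standalone sketch, under the assumption that sess, prep, input_node, and output_node are the session, preprocessor, and tensors built in the snippet above.

import numpy as np

def run_policy(sess, prep, input_node, output_node, obs):
    # Flatten the raw observation with the RLlib preprocessor, then fetch
    # the deterministic output node for a batch of one.
    flat_obs = prep.transform(obs)
    outputs = sess.run(output_node, feed_dict={input_node: np.asarray([flat_obs])})
    return outputs[0]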
Example #9

def central_vf_stats(policy, train_batch, grads):
    # Report the explained variance of the central value function.
    return {
        "vf_explained_var":
        explained_variance(train_batch[Postprocessing.VALUE_TARGETS],
                           policy.central_value_out),
    }


CCPPO = PPOTFPolicy.with_updates(
    name="CCPPO",
    postprocess_fn=centralized_critic_postprocessing,
    loss_fn=loss_with_central_critic,
    before_loss_init=setup_mixins,
    grad_stats_fn=central_vf_stats,
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        CentralizedValueMixin
    ])

CCTrainer = PPOTrainer.with_updates(name="CCPPOTrainer", default_policy=CCPPO)

if __name__ == "__main__":
    args = parser.parse_args()
    ModelCatalog.register_custom_model("cc_model", CentralizedCriticModel)
    tune.run(CCTrainer,
             stop={
                 "timesteps_total": args.stop,
                 "episode_reward_mean": 7.99,
             },
Example #10
fc_with_mask_model_config = {
    "model": {
        "custom_model": "fc_with_mask",
        "custom_options": {}
    }
}

ppo_agent_default_config_with_mask = merge_dicts(DEFAULT_CONFIG,
                                                 fc_with_mask_model_config)

PPOTFPolicyWithMask = PPOTFPolicy.with_updates(
    name="PPOTFPolicyWithMask",
    get_default_config=lambda: ppo_agent_default_config_with_mask,
    extra_action_fetches_fn=vf_preds_and_logits_fetches_new,
    before_loss_init=setup_mixins,
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        ValueNetworkMixin, AddMaskInfoMixinForPolicy
    ])


class AddMaskInfoMixin(object):
    def get_mask_info(self):
        return self.get_mask()

    def get_mask(self):
        return self.get_policy().get_mask()

    def set_mask(self, mask_dict):
        # Check the input is correct.
Example #11
def setup_mixins_dice(policy, obs_space, action_space, config):
    setup_mixins(policy, obs_space, action_space, config)
    DiversityValueNetworkMixin.__init__(policy, obs_space, action_space,
                                        config)
    discrete = isinstance(action_space, gym.spaces.Discrete)
    ComputeDiversityMixin.__init__(policy, discrete)


def setup_late_mixins(policy, obs_space, action_space, config):
    if config[DELAY_UPDATE]:
        TargetNetworkMixin.__init__(policy, obs_space, action_space, config)


DiCEPolicy = PPOTFPolicy.with_updates(
    name="DiCEPolicy",
    get_default_config=lambda: dice_default_config,
    postprocess_fn=postprocess_dice,
    loss_fn=dice_loss,
    stats_fn=kl_and_loss_stats_modified,
    gradients_fn=dice_gradient,
    grad_stats_fn=grad_stats_fn,
    extra_action_fetches_fn=additional_fetches,
    before_loss_init=setup_mixins_dice,
    after_init=setup_late_mixins,
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        ValueNetworkMixin, DiversityValueNetworkMixin, ComputeDiversityMixin,
        TargetNetworkMixin
    ])
Example #12
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy, postprocess_ppo_gae


def my_postprocess_ppo_gae(policy, sample_batch, *args, **kwargs):
    if sample_batch.get('infos') is not None:
        idx = [i for i, x in enumerate(sample_batch['infos']) if x['done']]
        if idx:
            idx.append(sample_batch.count)
            sbatch = sample_batch.slice(0, idx[0] + 1)
            sbatch['dones'][-1] = True
            batch = postprocess_ppo_gae(policy, sbatch, *args, **kwargs)
            for s, t in zip(idx[:-1], idx[1:]):
                sbatch = sample_batch.slice(s, t + 1)
                sbatch['dones'][-1] = True
                batch.concat(
                    postprocess_ppo_gae(policy, sbatch, *args, **kwargs))
            return batch
    return postprocess_ppo_gae(policy, sample_batch, *args, **kwargs)


MyPpoPolicy = PPOTFPolicy.with_updates(name="MyPpoTFPolicy",
                                       postprocess_fn=my_postprocess_ppo_gae)

MyPpoTrainer = PPOTrainer.with_updates(name="MyPpoTrainer",
                                       default_policy=MyPpoPolicy)
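
A hypothetical way to exercise the customized trainer above, assuming an older Ray release where Trainer classes are constructed directly from a config dict; the environment name and worker count are placeholders.

import ray

ray.init()
trainer = MyPpoTrainer(config={"env": "CartPole-v0", "num_workers": 1})
for _ in range(3):
    result = trainer.train()
    print(result["episode_reward_mean"])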
Example #13
        sample_batch_size=config["sample_batch_size"],
        num_envs_per_worker=config["num_envs_per_worker"],
        train_batch_size=config["train_batch_size"],
        standardize_fields=["advantages"],
        shuffle_sequences=config["shuffle_sequences"])


def setup_mixins_modified(policy, obs_space, action_space, config):
    AddLossMixin.__init__(policy, config)
    setup_mixins(policy, obs_space, action_space, config)


ExtraLossPPOTFPolicy = PPOTFPolicy.with_updates(
    name="ExtraLossPPOTFPolicy",
    get_default_config=lambda: extra_loss_ppo_default_config,
    postprocess_fn=postprocess_ppo_gae_modified,
    stats_fn=kl_and_loss_stats_modified,
    loss_fn=extra_loss_ppo_loss,
    before_loss_init=setup_mixins_modified,
    mixins=mixin_list + [AddLossMixin])

ExtraLossPPOTrainer = PPOTrainer.with_updates(
    name="ExtraLossPPO",
    default_config=extra_loss_ppo_default_config,
    validate_config=validate_config_modified,
    default_policy=ExtraLossPPOTFPolicy,
    make_policy_optimizer=choose_policy_optimizer)

if __name__ == '__main__':
    from toolbox.marl.test_extra_loss import test_extra_loss_ppo_trainer1

    test_extra_loss_ppo_trainer1(True)
Example #14
                self._alpha_val = 0.5
        else:
            if running_mean > 1.5 * self._novelty_target:
                self._alpha_val *= (1 - self._alpha_coefficient)
            elif running_mean < 0.5 * self._novelty_target:
                self._alpha_val = min(
                    (1 + self._alpha_coefficient) * self._alpha_val, 0.5)

        self._alpha.load(self._alpha_val, session=self.get_session())
        return self._alpha_val


DECEPolicy = PPOTFPolicy.with_updates(
    name="DECEPolicy",
    get_default_config=lambda: dece_default_config,
    postprocess_fn=postprocess_dece,
    loss_fn=loss_dece,
    stats_fn=kl_and_loss_stats_modified,
    gradients_fn=tnb_gradients,
    grad_stats_fn=grad_stats_fn,
    extra_action_fetches_fn=additional_fetches,
    before_loss_init=setup_mixins_dece,
    after_init=setup_late_mixins,
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        ValueNetworkMixin, NoveltyValueNetworkMixin, ComputeNoveltyMixin,
        TargetNetworkMixin, ConstrainNoveltyMixin
    ],
    get_batch_divisibility_req=get_batch_divisibility_req,
)
Example #15
def setup_mixins(policy, obs_space, action_space, config):
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
    KLCoeffMixin.__init__(policy, config)
    EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
                                  config["entropy_coeff_schedule"])
    warmup_steps = config["model"]["custom_options"].get(
        "warmup_steps", 100000)
    TransformerLearningRateSchedule.__init__(
        policy, config["model"]["custom_options"]["transformer"]["num_heads"],
        warmup_steps)


TTFPPOPolicy = PPOTFPolicy.with_updates(name="TTFPPOPolicy",
                                        before_loss_init=setup_mixins,
                                        mixins=[
                                            TransformerLearningRateSchedule,
                                            EntropyCoeffSchedule, KLCoeffMixin,
                                            ValueNetworkMixin
                                        ])

TTFPPOPolicyInfer = PPOTFPolicy.with_updates(name="TTFPPOPolicyInfer",
                                             before_loss_init=setup_mixins,
                                             mixins=[
                                                 LearningRateSchedule,
                                                 EntropyCoeffSchedule,
                                                 KLCoeffMixin,
                                                 ValueNetworkMixin
                                             ])

register_trainable(
    "TTFPPO",
Example #16
            return self.get_session().run(
                fim_embedding,
                feed_dict={self._input_dict[SampleBatch.CUR_OBS]: ob})

        self.get_fim_embedding = get_fim_embedding


def before_loss_init(policy, obs_space, action_space, config):
    setup_mixins(policy, obs_space, action_space, config)
    FIMEmbeddingMixin.__init__(policy)


PPOFIMTFPolicy = PPOTFPolicy.with_updates(name="PPOFIMTFPolicy",
                                          before_loss_init=before_loss_init,
                                          mixins=[
                                              LearningRateSchedule,
                                              EntropyCoeffSchedule,
                                              KLCoeffMixin, ValueNetworkMixin,
                                              FIMEmbeddingMixin
                                          ])


def get_policy_class(config):
    if config.get("use_pytorch") is True:
        raise NotImplementedError()
    else:
        return PPOFIMTFPolicy


PPOFIMTrainer = PPOTrainer.with_updates(
    name="PPOFIM",
    default_policy=PPOFIMTFPolicy,