class CustomTFRPGModel(TFModelV2):
    """Example model showing how repeated observations can be inspected.

    The actual computation is delegated to an inner ``TFFCNet``; ``forward``
    only adds diagnostic printing of the observation structure it is given.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Create the wrapped ``TFFCNet`` and register its variables.

        The five constructor arguments are passed on unchanged, first to the
        base class and then to the wrapped network.
        """
        ctor_args = (obs_space, action_space, num_outputs, model_config, name)
        super().__init__(*ctor_args)
        self.model = TFFCNet(*ctor_args)
        self.register_variables(self.model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Print the observation structure, then run the wrapped network.

        Returns whatever ``self.model.forward`` returns for the same
        arguments.
        """
        # Layout of the unpacked observation, with M=MAX_PLAYERS and
        # N=MAX_ITEMS:
        #   'items'    -> <tf.Tensor shape=(?, M, N, 5)>
        #   'location' -> <tf.Tensor shape=(?, M, 2)>
        #   'status'   -> <tf.Tensor shape=(?, M, 10)>
        obs = input_dict["obs"]
        print("The unpacked input tensors:", obs, end="\n\n")
        print("Unbatched repeat dim", obs.unbatch_repeat_dim(), end="\n\n")
        # The fully unbatched view is only printed when running eagerly.
        if tf.executing_eagerly():
            print("Fully unbatched", obs.unbatch_all(), end="\n\n")
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        """Return the wrapped network's value-function output."""
        return self.model.value_function()
# Example #2
0
class CentralizedCriticModel(TFModelV2):
    """Multi-agent model that implements a centralized VF.

    The policy itself is a wrapped ``FullyConnectedNetwork``. In addition, a
    small Keras network (``self.central_vf``) predicts a value from the
    agent's own observation, the opponent's observation and the opponent's
    one-hot encoded action.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Build the wrapped policy network and the centralized value net.

        All five arguments are forwarded unchanged to the base class and to
        the wrapped ``FullyConnectedNetwork``.
        """
        # Zero-argument super() binds to this class itself instead of looking
        # the name ``CentralizedCriticModel`` up in module globals at call
        # time, so it keeps working even if that name is rebound later.
        super().__init__(
            obs_space, action_space, num_outputs, model_config, name)
        # Base of the model
        self.model = FullyConnectedNetwork(obs_space, action_space,
                                           num_outputs, model_config, name)
        self.register_variables(self.model.variables())

        # Central VF maps (obs, opp_obs, opp_act) -> vf_pred.
        # Input widths are fixed: 6 per observation, 2 for the one-hot action.
        obs = tf.keras.layers.Input(shape=(6, ), name="obs")
        opp_obs = tf.keras.layers.Input(shape=(6, ), name="opp_obs")
        opp_act = tf.keras.layers.Input(shape=(2, ), name="opp_act")
        concat_obs = tf.keras.layers.Concatenate(axis=1)(
            [obs, opp_obs, opp_act])
        central_vf_dense = tf.keras.layers.Dense(
            16, activation=tf.nn.tanh, name="c_vf_dense")(concat_obs)
        central_vf_out = tf.keras.layers.Dense(
            1, activation=None, name="c_vf_out")(central_vf_dense)
        self.central_vf = tf.keras.Model(
            inputs=[obs, opp_obs, opp_act], outputs=central_vf_out)
        self.register_variables(self.central_vf.variables)

    def forward(self, input_dict, state, seq_lens):
        """Delegate the forward pass to the wrapped network."""
        return self.model.forward(input_dict, state, seq_lens)

    def central_value_function(self, obs, opponent_obs, opponent_actions):
        """Return the centralized value prediction flattened to one axis.

        ``opponent_actions`` is one-hot encoded with depth 2 before being fed
        to the value network (presumably integer action indices -- confirm
        against the caller). The network output is reshaped to ``[-1]``.
        """
        return tf.reshape(
            self.central_vf(
                [obs, opponent_obs,
                 tf.one_hot(opponent_actions, 2)]), [-1])

    def value_function(self):
        """Return the wrapped network's own (non-centralized) value output.

        Delegates to the wrapped network in the same way as the other models
        in this file, so that callers of ``value_function()`` reach it.
        """
        return self.model.value_function()
# Example #3
0
class CustomModel(TFModelV2):
    """Custom model example that forwards everything to a fc-net."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Create the wrapped ``FullyConnectedNetwork`` and register its
        variables.

        The same five arguments are handed to the base class and to the
        wrapped network.
        """
        ctor_args = (obs_space, action_space, num_outputs, model_config, name)
        super().__init__(*ctor_args)
        self.model = FullyConnectedNetwork(*ctor_args)
        self.register_variables(self.model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Return the wrapped network's forward-pass result unchanged."""
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        """Return the wrapped network's value-function output unchanged."""
        return self.model.value_function()
# Example #4
0
class FCModel(TFModelV2):
    """Fully connected model backed by a ``FullyConnectedNetwork``."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Create the inner network and register its variables.

        Every argument is forwarded as-is to the base class and to the inner
        ``FullyConnectedNetwork``.
        """
        ctor_args = (obs_space, action_space, num_outputs, model_config, name)
        super().__init__(*ctor_args)
        self.model = FullyConnectedNetwork(*ctor_args)
        self.register_variables(self.model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Run the inner network's forward pass and return its result."""
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        """Return the inner network's value-function output."""
        return self.model.value_function()
class CentralizedCriticModel(TFModelV2):
    """Multi-agent model with a centralized value function.

    The policy is a wrapped ``FullyConnectedNetwork``. A separate Keras
    network (``self.central_vf``) predicts a value from one flat input of
    width ``obs_space.shape[0] * max_num_agents``.
    """

    # TODO(@evinitsky) make this work with more than boxes

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Build the wrapped policy network and the centralized value net.

        Reads ``max_num_agents`` and ``central_vf_size`` from
        ``model_config['custom_options']``.
        """
        ctor_args = (obs_space, action_space, num_outputs, model_config, name)
        super().__init__(*ctor_args)

        # Policy network that the forward pass is delegated to.
        self.model = FullyConnectedNetwork(*ctor_args)
        self.register_variables(self.model.variables())

        # Centralized value network: flat input -> tanh hidden layer -> scalar.
        custom_opts = model_config['custom_options']
        self.max_num_agents = custom_opts['max_num_agents']
        self.obs_space_shape = obs_space.shape[0]
        flat_width = self.obs_space_shape * self.max_num_agents
        other_obs = tf.keras.layers.Input(shape=(flat_width, ),
                                          name="opp_obs")
        hidden_layer = tf.keras.layers.Dense(
            custom_opts['central_vf_size'],
            activation=tf.nn.tanh,
            name="c_vf_dense")
        hidden = hidden_layer(other_obs)
        output_layer = tf.keras.layers.Dense(
            1, activation=None, name="c_vf_out")
        vf_pred = output_layer(hidden)
        self.central_vf = tf.keras.Model(inputs=[other_obs], outputs=vf_pred)
        self.register_variables(self.central_vf.variables)

    def forward(self, input_dict, state, seq_lens):
        """Delegate the forward pass to the wrapped policy network."""
        return self.model.forward(input_dict, state, seq_lens)

    def central_value_function(self, obs, opponent_obs):
        """Return the centralized value prediction flattened to one axis.

        Only ``opponent_obs`` is fed to the value network; ``obs`` is accepted
        but not read by this method.
        """
        vf_pred = self.central_vf([opponent_obs])
        return tf.reshape(vf_pred, [-1])

    def value_function(self):
        """Return the wrapped policy network's value output (not used)."""
        return self.model.value_function()  # not used