class CustomTFRPGModel(TFModelV2):
    """Example of interpreting repeated observations.

    All computation is delegated to an inner ``TFFCNet`` built from the same
    constructor arguments. ``forward`` additionally prints the unpacked
    observation structure so its layout can be inspected.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(CustomTFRPGModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)
        self.model = TFFCNet(obs_space, action_space, num_outputs,
                             model_config, name)
        self.register_variables(self.model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Print the unpacked observations, then run the inner network."""
        # Layout of the unpacked input tensors, M=MAX_PLAYERS, N=MAX_ITEMS:
        #   'items'    -> <tf.Tensor shape=(?, M, N, 5)>
        #   'location' -> <tf.Tensor shape=(?, M, 2)>
        #   'status'   -> <tf.Tensor shape=(?, M, 10)>
        unpacked = input_dict["obs"]
        print("The unpacked input tensors:", unpacked)
        print()
        print("Unbatched repeat dim", unpacked.unbatch_repeat_dim())
        print()
        # The fully-unbatched view is only printed when TF runs eagerly.
        if tf.executing_eagerly():
            print("Fully unbatched", unpacked.unbatch_all())
            print()
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        """Return the inner network's value output."""
        return self.model.value_function()
class CentralizedCriticModel(TFModelV2):
    """Multi-agent model that implements a centralized VF.

    The policy is a ``FullyConnectedNetwork``. A separate Keras model (the
    central value function) maps the agent's observation, the opponent's
    observation and the one-hot opponent action to a scalar value estimate.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(CentralizedCriticModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)
        # Base of the model
        self.model = FullyConnectedNetwork(obs_space, action_space,
                                           num_outputs, model_config, name)
        self.register_variables(self.model.variables())

        # Central VF maps (obs, opp_obs, opp_act) -> vf_pred.
        # Input widths (6, 6, 2) are hard-coded for this example's env.
        obs = tf.keras.layers.Input(shape=(6, ), name="obs")
        opp_obs = tf.keras.layers.Input(shape=(6, ), name="opp_obs")
        opp_act = tf.keras.layers.Input(shape=(2, ), name="opp_act")
        concat_obs = tf.keras.layers.Concatenate(axis=1)(
            [obs, opp_obs, opp_act])
        central_vf_dense = tf.keras.layers.Dense(
            16, activation=tf.nn.tanh, name="c_vf_dense")(concat_obs)
        central_vf_out = tf.keras.layers.Dense(
            1, activation=None, name="c_vf_out")(central_vf_dense)
        self.central_vf = tf.keras.Model(
            inputs=[obs, opp_obs, opp_act], outputs=central_vf_out)
        self.register_variables(self.central_vf.variables)

    def forward(self, input_dict, state, seq_lens):
        """Run the base (decentralized) policy network."""
        return self.model.forward(input_dict, state, seq_lens)

    def central_value_function(self, obs, opponent_obs, opponent_actions):
        """Evaluate the central critic.

        Args:
            obs: this agent's observations (6 features per row).
            opponent_obs: the opponent's observations (6 features per row).
            opponent_actions: opponent discrete actions in {0, 1}.

        Returns:
            A 1-D tensor with one value estimate per row.
        """
        # tf.one_hot only accepts integer indices. The action batch may
        # arrive with another dtype (e.g. float from a numpy sample batch),
        # so cast explicitly before encoding.
        opponent_actions = tf.cast(opponent_actions, tf.int32)
        return tf.reshape(
            self.central_vf(
                [obs, opponent_obs, tf.one_hot(opponent_actions, 2)]), [-1])

    def value_function(self):
        """Return the base network's value output.

        This is required by the ModelV2 interface (the base implementation
        raises NotImplementedError). The centralized value estimate is
        obtained from ``central_value_function`` instead.
        """
        return self.model.value_function()
class CustomModel(TFModelV2):
    """Example of a custom model that just delegates to a fc-net.

    A ``FullyConnectedNetwork`` is built from the same constructor arguments,
    and both ``forward`` and ``value_function`` forward to it unchanged.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        fc_net = FullyConnectedNetwork(obs_space, action_space, num_outputs,
                                       model_config, name)
        self.model = fc_net
        # Make the inner network's weights visible as this model's variables.
        self.register_variables(fc_net.variables())

    def forward(self, input_dict, state, seq_lens):
        """Delegate the forward pass to the wrapped network."""
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        """Delegate the value estimate to the wrapped network."""
        return self.model.value_function()
class FCModel(TFModelV2):
    """Fully Connected Model.

    A thin wrapper around a ``FullyConnectedNetwork``. Policy logits and
    value estimates both come directly from the wrapped network.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        network_args = (obs_space, action_space, num_outputs, model_config,
                        name)
        self.model = FullyConnectedNetwork(*network_args)
        self.register_variables(self.model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Return the wrapped network's forward output."""
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        """Return the wrapped network's value output."""
        return self.model.value_function()
class CentralizedCriticModel(TFModelV2):
    """Multi-agent model that implements a centralized VF.

    The policy is a ``FullyConnectedNetwork``. The central critic is a small
    Keras MLP whose only input is a flat vector of size
    ``obs_space.shape[0] * max_num_agents`` (the ``opponent_obs`` argument of
    ``central_value_function``). Its hidden width comes from
    ``model_config['custom_options']['central_vf_size']``.
    """
    # TODO(@evinitsky) make this work with more than boxes

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        # Base of the model
        self.model = FullyConnectedNetwork(obs_space, action_space,
                                           num_outputs, model_config, name)
        self.register_variables(self.model.variables())

        options = model_config['custom_options']
        self.max_num_agents = options['max_num_agents']
        self.obs_space_shape = obs_space.shape[0]

        # Central VF: flat vector of per-agent observations -> vf_pred.
        flat_width = self.obs_space_shape * self.max_num_agents
        vf_input = tf.keras.layers.Input(shape=(flat_width, ),
                                         name="opp_obs")
        vf_hidden = tf.keras.layers.Dense(
            options['central_vf_size'],
            activation=tf.nn.tanh,
            name="c_vf_dense",
        )(vf_input)
        vf_pred = tf.keras.layers.Dense(
            1, activation=None, name="c_vf_out")(vf_hidden)
        self.central_vf = tf.keras.Model(inputs=[vf_input], outputs=vf_pred)
        self.register_variables(self.central_vf.variables)

    def forward(self, input_dict, state, seq_lens):
        """Run the base (decentralized) policy network."""
        return self.model.forward(input_dict, state, seq_lens)

    def central_value_function(self, obs, opponent_obs):
        """Return the central critic's estimate, flattened to 1-D.

        ``obs`` is accepted for interface compatibility but is not used:
        the critic only sees ``opponent_obs``.
        """
        critic_out = self.central_vf([opponent_obs])
        return tf.reshape(critic_out, [-1])

    def value_function(self):
        """Return the base network's value output (not used)."""
        return self.model.value_function()  # not used