def __init__(self, observation_space, action_space, num_outputs, model_config,
             name):
    """Build an FC network whose policy output is routed through eager code.

    The policy output of an inner ``FullyConnectedNetwork`` ("fc1") is passed
    through ``tf.py_function(self.forward_eager, ...)`` inside a Keras
    ``Lambda`` layer. ``self.base_model`` then maps observations to
    ``[hooked_output, value_output]``.
    """
    super().__init__(observation_space, action_space, num_outputs,
                     model_config, name)

    obs_input = tf.keras.layers.Input(shape=observation_space.shape)
    self.fcnet = FullyConnectedNetwork(
        obs_space=self.obs_space,
        action_space=self.action_space,
        num_outputs=self.num_outputs,
        model_config=self.model_config,
        name="fc1",
    )
    policy_out, value_out = self.fcnet.base_model(obs_input)

    def _through_eager(features):
        hooked = tf.py_function(self.forward_eager, [features], tf.float32)
        # py_function loses the static shape; restore it from the input.
        with tf1.control_dependencies([hooked]):
            hooked.set_shape(features.shape)
        return hooked

    policy_out = tf.keras.layers.Lambda(_through_eager)(policy_out)
    self.base_model = tf.keras.models.Model(obs_input, [policy_out, value_out])
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
    """Create the wrapped fully connected network and register its variables.

    Every constructor argument is forwarded unchanged to both the parent
    model and the inner ``FullyConnectedNetwork``.
    """
    super(CustomModel, self).__init__(
        obs_space, action_space, num_outputs, model_config, name)
    wrapped = FullyConnectedNetwork(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=model_config,
        name=name,
    )
    self.model = wrapped
    self.register_variables(wrapped.variables())
class FCMaskedActionsModelTF(DistributionalQTFModel, TFModelV2):
    """Q-model whose outputs for disallowed actions are set to float32 min.

    ``input_dict["obs"]`` is read as a dict with two entries:
    ``"real_obs"`` is fed to the inner network, and ``"action_mask"`` marks
    allowed actions. Positions where the mask equals 1 keep the network
    output; all other positions become ``tf.float32.min``.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kw):
        super(FCMaskedActionsModelTF, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw)
        # NOTE(review): obs_space appears to be the flattened observation that
        # presumably includes `action_space.n` mask entries — confirm.
        real_obs_size = obs_space.shape[0] - action_space.n
        self.action_embed_model = FullyConnectedNetwork(
            obs_space=gym.spaces.MultiBinary(n=real_obs_size),
            action_space=action_space,
            num_outputs=action_space.n,
            model_config=model_config,
            name=name + "action_model",
        )
        self.register_variables(self.action_embed_model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Return the inner network's outputs with disallowed actions masked."""
        obs = input_dict["obs"]
        action_mask = obs["action_mask"]
        raw_actions, _ = self.action_embed_model({"obs": obs["real_obs"]})
        is_allowed = tf.math.equal(action_mask, 1)
        logits = tf.where(is_allowed, raw_actions, tf.float32.min)
        return logits, state

    def value_function(self):
        """Delegate to the inner network's value output."""
        return self.action_embed_model.value_function()
class ParametricActionsModel(DistributionalQTFModel):
    """Q-model that masks an inner network's outputs with an action mask.

    ``input_dict["obs"]`` is read as a dict with two entries:
    ``"state"`` is fed to the inner network, which is built for
    ``FLAT_OBSERVATION_SPACE``, and ``"action_mask"`` is added as a log-mask.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kw):
        print("{} : [INFO] ParametricActionsModel {}, {}, {}, {}, {}".format(
            datetime.now(), action_space, obs_space, num_outputs, name,
            model_config))
        super(ParametricActionsModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw)
        self.action_param_model = FullyConnectedNetwork(
            obs_space=FLAT_OBSERVATION_SPACE,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            name=name + "_action_param",
        )
        self.register_variables(self.action_param_model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Return the inner network's outputs plus a clamped log action mask."""
        obs = input_dict["obs"]
        action_mask = obs["action_mask"]
        action_param, _ = self.action_param_model({"obs": obs["state"]})
        # For a 0/1 mask, log yields 0 or -inf; clamping to float32 min keeps
        # the sum finite while still suppressing masked actions.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        return action_param + inf_mask, state

    def value_function(self):
        """Delegate to the inner network's value output."""
        return self.action_param_model.value_function()
class CustomTFRPGModel(TFModelV2):
    """Example of interpreting repeated observations.

    Prints several views of the unpacked observation, then delegates all
    computation to an inner ``TFFCNet`` built with the same arguments.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        inner = TFFCNet(obs_space, action_space, num_outputs, model_config,
                        name)
        self.model = inner
        self.register_variables(inner.variables())

    def forward(self, input_dict, state, seq_lens):
        # Expected unpacked layout (M=MAX_PLAYERS, N=MAX_ITEMS):
        #   'items':    <tf.Tensor shape=(?, M, N, 5)>
        #   'location': <tf.Tensor shape=(?, M, 2)>
        #   'status':   <tf.Tensor shape=(?, M, 10)>
        unpacked = input_dict["obs"]
        print("The unpacked input tensors:", unpacked)
        print()
        print("Unbatched repeat dim", unpacked.unbatch_repeat_dim())
        print()
        if tf.executing_eagerly():
            # Full unbatching is only attempted when running eagerly.
            print("Fully unbatched", unpacked.unbatch_all())
            print()
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        return self.model.value_function()
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
    """Build the inner fully connected network, stored as ``self.fcnet``."""
    super().__init__(obs_space, action_space, num_outputs, model_config, name)
    self.fcnet = FullyConnectedNetwork(
        obs_space=self.obs_space,
        action_space=self.action_space,
        num_outputs=num_outputs,
        model_config=model_config,
        name="fcnet",
    )
def __init__(self, obs_space, action_space, num_outputs, model_config, name,
             true_obs_shape=(4, ), action_embed_size=2, **kw):
    """Build a Keras model that learns one embedding per action id.

    Logits are dot products between two embeddings:
      * a predicted "intent" embedding computed from the ``obs_cart`` input;
      * learned embeddings of the currently valid actions.
    Invalid actions are pushed down to ``tf.float32.min``.

    Args:
        true_obs_shape: Shape of the ``obs_cart`` input fed to the intent
            network.
        action_embed_size: Size of the intent vector and of each learned
            action embedding.
        **kw: Forwarded to the parent constructor.
    """
    super(ParametricActionsModelThatLearnsEmbeddings, self).__init__(
        obs_space, action_space, num_outputs, model_config, name, **kw)

    # Ids start at 1 so that masked-out slots (id 0 after multiplying by the
    # mask) hit a dedicated dummy embedding row.
    action_ids_shifted = tf.constant(
        list(range(1, num_outputs + 1)), dtype=tf.float32)

    obs_cart = tf.keras.layers.Input(shape=true_obs_shape, name="obs_cart")
    # `shape` is documented as a tuple; `(num_outputs)` would be a bare int,
    # so use an explicit 1-tuple.
    valid_avail_actions_mask = tf.keras.layers.Input(
        shape=(num_outputs, ), name="valid_avail_actions_mask")

    self.pred_action_embed_model = FullyConnectedNetwork(
        Box(-1, 1, shape=true_obs_shape), action_space, action_embed_size,
        model_config, name + "_pred_action_embed")

    # Compute the predicted action embedding.
    pred_action_embed, _ = self.pred_action_embed_model({"obs": obs_cart})
    _value_out = self.pred_action_embed_model.value_function()

    # Expand the model output to [BATCH, 1, EMBED_SIZE]. The embedded avail
    # actions tensor below is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].
    intent_vector = tf.expand_dims(pred_action_embed, 1)

    valid_avail_actions = action_ids_shifted * valid_avail_actions_mask
    # Embedding for valid available actions, learned during training.
    # Row 0 is the invalid ("dummy") embedding.
    valid_avail_actions_embed = tf.keras.layers.Embedding(
        input_dim=num_outputs + 1,
        output_dim=action_embed_size,
        name="action_embed_matrix")(valid_avail_actions)

    # Batch dot product => shape of logits is [BATCH, MAX_ACTIONS].
    action_logits = tf.reduce_sum(
        valid_avail_actions_embed * intent_vector, axis=2)

    # Mask out invalid actions (use tf.float32.min for stability).
    inf_mask = tf.maximum(
        tf.math.log(valid_avail_actions_mask), tf.float32.min)
    action_logits = action_logits + inf_mask

    self.param_actions_model = tf.keras.Model(
        inputs=[obs_cart, valid_avail_actions_mask],
        outputs=[action_logits, _value_out])
    self.param_actions_model.summary()
def __init__(self, obs_space, action_space, num_outputs, model_config, name,
             **kwargs):
    """Build the policy network over a Box taken from the custom model config.

    ``model_config["custom_model_config"]`` must provide ``obs_space_low``
    and ``obs_space_high``. They become the bounds of a 1-D ``spaces.Box``
    whose length is ``len(obs_space_low)``.
    """
    super().__init__(obs_space, action_space, num_outputs, model_config, name,
                     **kwargs)
    custom_cfg = model_config["custom_model_config"]
    lower = np.asarray(custom_cfg["obs_space_low"])
    upper = np.asarray(custom_cfg["obs_space_high"])
    policy_obs_space = spaces.Box(lower, upper, shape=(len(lower),))
    self.policy = FullyConnectedNetwork(
        policy_obs_space,
        action_space,
        num_outputs,
        model_config,
        "policy_network",
    )
def __init__(self, obs_space, action_space, num_outputs, model_config, name,
             **kw):
    """Build the inner network over the non-mask part of the observation.

    The inner ``FullyConnectedNetwork`` sees a ``MultiBinary`` space whose
    size is ``obs_space.shape[0] - action_space.n``. It emits
    ``action_space.n`` outputs.
    """
    super(FCMaskedActionsModelTF, self).__init__(
        obs_space, action_space, num_outputs, model_config, name, **kw)
    # NOTE(review): presumably obs_space is flattened and includes
    # `action_space.n` mask entries — confirm against the env.
    real_obs_size = obs_space.shape[0] - action_space.n
    self.action_embed_model = FullyConnectedNetwork(
        obs_space=gym.spaces.MultiBinary(n=real_obs_size),
        action_space=action_space,
        num_outputs=action_space.n,
        model_config=model_config,
        name=name + "action_model",
    )
    self.register_variables(self.action_embed_model.variables())
def __init__(self, obs_space, action_space, num_outputs, model_config, name,
             true_obs_shape=(4, ), action_embed_size=2, **kw):
    """Build the network that predicts an action embedding.

    Args:
        true_obs_shape: Shape of the Box(-1, 1) space fed to the embedding
            network.
        action_embed_size: Number of outputs of the embedding network.
        **kw: Forwarded to the parent constructor.
    """
    super(ParametricActionsModel, self).__init__(
        obs_space, action_space, num_outputs, model_config, name, **kw)
    embed_input_space = Box(-1, 1, shape=true_obs_shape)
    self.action_embed_model = FullyConnectedNetwork(
        obs_space=embed_input_space,
        action_space=action_space,
        num_outputs=action_embed_size,
        model_config=model_config,
        name=name + "_action_embed",
    )
class CustomModel(TFModelV2):
    """Example of a keras custom model that just delegates to an fc-net."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(CustomModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)
        self.model = FullyConnectedNetwork(
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            name=name,
        )

    def forward(self, input_dict, state, seq_lens):
        """Forward straight to the wrapped network."""
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        """Return the wrapped network's value output."""
        return self.model.value_function()
class CustomModel(TFModelV2):
    """Thin wrapper that delegates everything to a FullyConnectedNetwork."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        wrapped = FullyConnectedNetwork(
            obs_space, action_space, num_outputs, model_config, name)
        self.model = wrapped
        self.register_variables(wrapped.variables())

    def forward(self, input_dict, state, seq_lens):
        """Forward straight to the wrapped network."""
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        """Return the wrapped network's value output."""
        return self.model.value_function()
def __init__(self, obs_space, action_space, num_outputs, model_config, name,
             **kwargs):
    """Validate the Dict observation space and build the internal network.

    The observation space (its ``original_space`` when present) must be a
    ``Dict`` with ``"action_mask"`` and ``"observations"`` entries. The
    internal network is built for ``"observations"`` only. ``**kwargs`` is
    accepted but not forwarded.
    """
    # Prefer the unflattened space when obs_space carries one.
    orig_space = getattr(obs_space, "original_space", obs_space)
    assert isinstance(orig_space, Dict)
    assert "action_mask" in orig_space.spaces
    assert "observations" in orig_space.spaces

    super().__init__(obs_space, action_space, num_outputs, model_config, name)
    self.internal_model = FullyConnectedNetwork(
        orig_space["observations"],
        action_space,
        num_outputs,
        model_config,
        name + "_internal",
    )
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
    """Build the decentralized action model and the centralized value model."""
    super(YetAnotherCentralizedCriticModel, self).__init__(
        obs_space, action_space, num_outputs, model_config, name)
    # The action model sees only the agent's own observation
    # (one-hot encoded Discrete(6)).
    own_obs_space = Box(low=0, high=1, shape=(6, ))
    self.action_model = FullyConnectedNetwork(
        own_obs_space, action_space, num_outputs, model_config,
        name + "_action")
    # The value model sees the full observation space and emits one value.
    self.value_model = FullyConnectedNetwork(
        obs_space, action_space, 1, model_config, name + "_vf")
class ParametricActionsModel(DistributionalQTFModel):
    """Parametric action model that handles the dot product and masking.

    This assumes the outputs are logits for a single Categorical action
    distribution. Extending it to more complex outputs (e.g. a tuple action
    space) is possible but not handled here.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, true_obs_shape=(4, ), action_embed_size=2, **kw):
        super(ParametricActionsModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw)
        embed_input_space = Box(-1, 1, shape=true_obs_shape)
        self.action_embed_model = FullyConnectedNetwork(
            obs_space=embed_input_space,
            action_space=action_space,
            num_outputs=action_embed_size,
            model_config=model_config,
            name=name + "_action_embed",
        )
        self.register_variables(self.action_embed_model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Score each available action by its dot product with the intent."""
        obs = input_dict["obs"]
        avail_actions = obs["avail_actions"]
        action_mask = obs["action_mask"]

        action_embed, _ = self.action_embed_model({"obs": obs["cart"]})

        # [BATCH, EMBED_SIZE] -> [BATCH, 1, EMBED_SIZE]; avail_actions is
        # [BATCH, MAX_ACTIONS, EMBED_SIZE].
        intent_vector = tf.expand_dims(action_embed, 1)
        # Batch dot product => [BATCH, MAX_ACTIONS].
        action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2)

        # Clamped log-mask: invalid actions get float32 min instead of -inf.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        return action_logits + inf_mask, state

    def value_function(self):
        """Delegate to the embedding network's value output."""
        return self.action_embed_model.value_function()
class ParametricActionsModel(DistributionalQTFModel):
    """Q-model that adds a clamped log action mask to per-action values.

    ``input_dict["obs"]`` is read as a dict with two entries:
    ``"actual_obs"`` is fed to the inner network, and ``"action_mask"`` is
    added as ``max(log(mask), float32.min)``.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, true_obs_shape=(2, ), **kw):
        super(ParametricActionsModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw)
        self.action_value_model = FullyConnectedNetwork(
            Box(-1, 1, shape=true_obs_shape),
            action_space,
            num_outputs,
            model_config,
            name + "_action_values",
        )
        self.register_variables(self.action_value_model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Return per-action values with masked actions pushed to float32 min."""
        action_mask = input_dict["obs"]["action_mask"]
        action_values, _ = self.action_value_model(
            {"obs": input_dict["obs"]["actual_obs"]})
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        return action_values + inf_mask, state

    def value_function(self):
        """Delegate to the inner network's value output.

        Previously missing, so calling it fell through to the base-class
        implementation. This matches the other masked ParametricActionsModel
        variants.
        """
        return self.action_value_model.value_function()
def __init__(self, obs_space, action_space, num_outputs, model_config, name,
             **kw):
    """Build the action-embedding network from the custom model config.

    Reads ``true_obs_shape`` and ``action_embed_size`` from
    ``model_config["custom_model_config"]``. ``true_obs_shape`` may be
    either of two forms:
      * a gym Space, used as-is;
      * a bare shape (int, tuple or list), wrapped in
        ``Box(-1, 0, shape=...)`` because FullyConnectedNetwork needs a Space.
    """
    super(OwnershipActionMaskingModel, self).__init__(
        obs_space, action_space, num_outputs, model_config, name, **kw)
    custom_cfg = model_config['custom_model_config']
    self.true_obs_shape = custom_cfg['true_obs_shape']
    self.action_embed_size = custom_cfg['action_embed_size']

    embed_obs_space = self.true_obs_shape
    if isinstance(embed_obs_space, (int, tuple, list)):
        # A plain shape cannot be used as an observation space (it has no
        # `.shape`); build the Box the original comment referred to.
        from gym.spaces import Box
        shape = ((embed_obs_space, ) if isinstance(embed_obs_space, int)
                 else tuple(embed_obs_space))
        embed_obs_space = Box(-1, 0, shape=shape)

    self.action_embed_model = FullyConnectedNetwork(
        embed_obs_space, action_space, self.action_embed_size, model_config,
        name + "_action_embed")
    self.register_variables(self.action_embed_model.variables())
class ActionMaskModel(TFModelV2):
    """Masks a policy network's logits using ``obs["action_mask"]``.

    The policy network is built over a 1-D Box whose bounds are
    ``obs_space_low`` / ``obs_space_high`` from the custom model config.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kwargs):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name, **kwargs)
        custom_cfg = model_config["custom_model_config"]
        lower = np.asarray(custom_cfg["obs_space_low"])
        upper = np.asarray(custom_cfg["obs_space_high"])
        self.policy = FullyConnectedNetwork(
            spaces.Box(lower, upper, shape=(len(lower),)),
            action_space,
            num_outputs,
            model_config,
            "policy_network",
        )

    def forward(self, input_dict, state, seq_lens):
        """Return masked logits (unmasked when there is a single output)."""
        observations = input_dict["obs"]
        real_obs = observations["real_obs"]
        action_mask = observations["action_mask"]
        action_logits, _ = self.policy({"obs": real_obs}, state, seq_lens)
        if self.num_outputs == 1:
            return action_logits, state
        # Clamped log-mask: float32 min instead of -inf for stability.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        return action_logits + inf_mask, state

    def value_function(self):
        """Delegate to the policy network's value output."""
        return self.policy.value_function()
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
    """Wrap a TFFCNet built with the same arguments and register its variables."""
    super().__init__(obs_space, action_space, num_outputs, model_config, name)
    inner = TFFCNet(obs_space, action_space, num_outputs, model_config, name)
    self.model = inner
    self.register_variables(inner.variables())
def __init__(self, obs_space, action_space, num_outputs, model_config, name,
             **kw):
    """Log the construction arguments, then build the action-param network.

    The inner ``FullyConnectedNetwork`` is built for
    ``FLAT_OBSERVATION_SPACE`` and emits ``num_outputs`` values.
    """
    print("{} : [INFO] ParametricActionsModel {}, {}, {}, {}, {}".format(
        datetime.now(), action_space, obs_space, num_outputs, name,
        model_config))
    super(ParametricActionsModel, self).__init__(
        obs_space, action_space, num_outputs, model_config, name, **kw)
    self.action_param_model = FullyConnectedNetwork(
        obs_space=FLAT_OBSERVATION_SPACE,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=model_config,
        name=name + "_action_param",
    )
    self.register_variables(self.action_param_model.variables())
class OwnershipActionMaskingModel(FullyConnectedNetwork):
    """Parametric action model that handles the dot product and masking.

    This assumes the outputs are logits for a single Categorical action
    distribution. ``input_dict["obs"]`` is read as a dict with three entries:
    ``"obs"``, ``"avail_actions"`` and ``"action_mask"``.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kw):
        super(OwnershipActionMaskingModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw)
        custom_cfg = model_config['custom_model_config']
        self.true_obs_shape = custom_cfg['true_obs_shape']
        self.action_embed_size = custom_cfg['action_embed_size']

        embed_obs_space = self.true_obs_shape
        if isinstance(embed_obs_space, (int, tuple, list)):
            # FullyConnectedNetwork needs a Space, not a bare shape; build the
            # Box(-1, 0, shape=true_obs_shape) the original comment described.
            # Values that are already Spaces are passed through unchanged.
            from gym.spaces import Box
            shape = ((embed_obs_space, ) if isinstance(embed_obs_space, int)
                     else tuple(embed_obs_space))
            embed_obs_space = Box(-1, 0, shape=shape)

        self.action_embed_model = FullyConnectedNetwork(
            embed_obs_space, action_space, self.action_embed_size,
            model_config, name + "_action_embed")
        self.register_variables(self.action_embed_model.variables())

    def forward(self, input_dict, state, seq_lens):
        """Score available actions by dot product with the predicted intent."""
        # Extract the available actions tensor from the observation.
        avail_actions = input_dict["obs"]["avail_actions"]
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the predicted action embedding.
        action_embed, _ = self.action_embed_model(
            {"obs": input_dict["obs"]["obs"]})

        # Expand the model output to [BATCH, 1, EMBED_SIZE]. The avail
        # actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].
        intent_vector = tf.expand_dims(action_embed, 1)

        # Batch dot product over the EMBED_SIZE axis (axis=2) => logits of
        # shape [BATCH, MAX_ACTIONS]. The previous axis=1 summed over actions
        # instead and produced [BATCH, EMBED_SIZE].
        action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2)

        # Mask out invalid actions (use tf.float32.min for stability).
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        return action_logits + inf_mask, state

    def value_function(self):
        """Delegate to the embedding network's value output."""
        return self.action_embed_model.value_function()
def __init__(self, obs_space, action_space, num_outputs, model_config, name,
             **kwargs):
    """Validate the Dict observation space, build the network, read options.

    The observation space (its ``original_space`` when present) must be a
    ``Dict`` with ``"action_mask"`` and ``"observations"`` entries.
    ``custom_model_config["no_masking"]`` (default False) is stored as
    ``self.no_masking``. ``**kwargs`` is accepted but not forwarded.
    """
    # Prefer the unflattened space when obs_space carries one.
    orig_space = getattr(obs_space, "original_space", obs_space)
    assert isinstance(orig_space, Dict)
    assert "action_mask" in orig_space.spaces
    assert "observations" in orig_space.spaces

    super().__init__(obs_space, action_space, num_outputs, model_config, name)
    self.internal_model = FullyConnectedNetwork(
        orig_space["observations"],
        action_space,
        num_outputs,
        model_config,
        name + "_internal",
    )

    # Disabling masking will likely lead to invalid actions being chosen.
    custom_cfg = model_config["custom_model_config"]
    self.no_masking = custom_cfg.get("no_masking", False)
def __init__(self, obs_space, action_space, num_outputs, model_config, name,
             true_obs_shape=(2, ), **kw):
    """Build the per-action value network over a Box(-1, 1) input.

    Args:
        true_obs_shape: Shape of the Box fed to the value network.
        **kw: Forwarded to the parent constructor.
    """
    super(ParametricActionsModel, self).__init__(
        obs_space, action_space, num_outputs, model_config, name, **kw)
    value_input_space = Box(-1, 1, shape=true_obs_shape)
    self.action_value_model = FullyConnectedNetwork(
        obs_space=value_input_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=model_config,
        name=name + "_action_values",
    )
    self.register_variables(self.action_value_model.variables())
class CustomLossModel(TFModelV2):
    """Custom model that adds an imitation loss on top of the policy loss.

    Logits and values come from an inner ``FullyConnectedNetwork``.
    ``custom_loss`` reads offline experiences with ``JsonReader`` from
    ``model_config["custom_model_config"]["input_files"]``. It then adds
    ``10 *`` the mean negative log-likelihood of the logged actions.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        self.fcnet = FullyConnectedNetwork(
            self.obs_space, self.action_space, num_outputs, model_config,
            name="fcnet")

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        # Delegate to our FCNet.
        return self.fcnet(input_dict, state, seq_lens)

    @override(ModelV2)
    def value_function(self):
        # Delegate to our FCNet.
        return self.fcnet.value_function()

    @override(ModelV2)
    def custom_loss(self, policy_loss, loss_inputs):
        """Return ``policy_loss + 10 * imitation_loss`` and record both terms."""
        # Create a new input reader per worker.
        reader = JsonReader(
            self.model_config["custom_model_config"]["input_files"])
        input_ops = reader.tf_input_ops()

        # Define a secondary loss by building a graph copy with weight sharing.
        obs = restore_original_dimensions(
            tf.cast(input_ops["obs"], tf.float32), self.obs_space)
        logits, _ = self.forward({"obs": obs}, [], None)

        # Self-supervised losses can also be built from tensors available
        # here, e.g. an autoencoder-style reconstruction of loss_inputs["obs"].
        print("FYI: You can also use these tensors: {}, ".format(loss_inputs))

        # Compute the IL loss. Categorical's second argument is the model
        # (ModelV2), not the model's config dict.
        action_dist = Categorical(logits, self)
        self.policy_loss = policy_loss
        self.imitation_loss = tf.reduce_mean(
            -action_dist.logp(input_ops["actions"]))
        return policy_loss + 10 * self.imitation_loss

    def metrics(self):
        """Report both loss terms; only valid after ``custom_loss`` has run."""
        return {
            "policy_loss": self.policy_loss,
            "imitation_loss": self.imitation_loss,
        }
class CentralizedCriticModel(TFModelV2):
    """Multi-agent model that implements a centralized value function.

    Actions come from a decentralized ``FullyConnectedNetwork``. A separate
    Keras model maps ``(obs, opp_obs, one_hot(opp_act))`` to a value
    estimate.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(CentralizedCriticModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)
        # Base (decentralized) policy network.
        self.model = FullyConnectedNetwork(
            obs_space, action_space, num_outputs, model_config, name)

        # Central VF: (obs, opp_obs, opp_act) -> vf_pred.
        own_in = tf.keras.layers.Input(shape=(6, ), name="obs")
        opp_obs_in = tf.keras.layers.Input(shape=(6, ), name="opp_obs")
        opp_act_in = tf.keras.layers.Input(shape=(2, ), name="opp_act")
        joined = tf.keras.layers.Concatenate(axis=1)(
            [own_in, opp_obs_in, opp_act_in])
        hidden = tf.keras.layers.Dense(
            16, activation=tf.nn.tanh, name="c_vf_dense")(joined)
        vf_out = tf.keras.layers.Dense(
            1, activation=None, name="c_vf_out")(hidden)
        self.central_vf = tf.keras.Model(
            inputs=[own_in, opp_obs_in, opp_act_in], outputs=vf_out)

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        return self.model.forward(input_dict, state, seq_lens)

    def central_value_function(self, obs, opponent_obs, opponent_actions):
        """Return the flattened central value for the given joint inputs."""
        opp_act_one_hot = tf.one_hot(tf.cast(opponent_actions, tf.int32), 2)
        values = self.central_vf([obs, opponent_obs, opp_act_one_hot])
        return tf.reshape(values, [-1])

    @override(ModelV2)
    def value_function(self):
        # Not used; the central value function is used instead.
        return self.model.value_function()
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
    """Hold two nested models: a tiny Keras model and an RLlib FC net."""
    super().__init__(obs_space, action_space, num_outputs, model_config, name)

    # A plain Keras model: a 3-wide input mapped through Dense(2).
    keras_in = tf.keras.layers.Input(shape=(3, ))
    keras_out = tf.keras.layers.Dense(2)(keras_in)
    self.keras_model = tf.keras.models.Model([keras_in], [keras_out])

    # An RLlib FullyConnectedNetwork (tf), which is also a Keras model,
    # with 3 outputs and an empty model config.
    self.fc_net = FullyConnectedNetwork(obs_space, action_space, 3, {}, "fc1")
class CustomModel(TFModelV2, selection_DelayedImpactEnv):
    """Example of a keras custom model that just delegates to an fc-net."""

    # Class-level defaults copied from the environment class. NOTE(review):
    # these are presumably shadowed by instance attributes set by the parent
    # constructor — confirm they are still needed.
    obs_space = selection_DelayedImpactEnv.observation_space
    action_space = selection_DelayedImpactEnv.action_space
    # num_outputs (169 in the original setup) is passed to __init__ instead.
    model_config = {}
    name = 'My_model'

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(CustomModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)
        wrapped = FullyConnectedNetwork(
            obs_space, action_space, num_outputs, model_config, name)
        self.model = wrapped
        self.register_variables(wrapped.variables())

    def forward(self, input_dict, state, seq_lens):
        """Forward straight to the wrapped network."""
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        """Return the wrapped network's value output."""
        return self.model.value_function()
class YetAnotherCentralizedCriticModel(TFModelV2):
    """Multi-agent model that implements a centralized value function.

    The observation is a dict with ``'own_obs'`` and ``'opponent_obs'``.
    ``'own_obs'`` is used to compute actions (decentralized execution), and
    the full observation is used for the value (centralized learning).

    The model has two parts:
      * an action model that looks only at ``'own_obs'``;
      * a value model that reads the ``'obs_flat'`` tensor, which also
        includes the opponent's observation and action.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(YetAnotherCentralizedCriticModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)

        own_obs_space = Box(low=0, high=1, shape=(6, ))  # one-hot Discrete(6)
        self.action_model = FullyConnectedNetwork(
            own_obs_space, action_space, num_outputs, model_config,
            name + "_action")
        self.register_variables(self.action_model.variables())

        self.value_model = FullyConnectedNetwork(
            obs_space, action_space, 1, model_config, name + "_vf")
        self.register_variables(self.value_model.variables())

    def forward(self, input_dict, state, seq_lens):
        # Centralized value from the full flattened observation; cached for
        # value_function().
        self._value_out, _ = self.value_model(
            {"obs": input_dict["obs_flat"]}, state, seq_lens)
        # Decentralized action from the agent's own observation only.
        own_obs = input_dict["obs"]["own_obs"]
        return self.action_model({"obs": own_obs}, state, seq_lens)

    def value_function(self):
        """Return the value computed in the last forward(), flattened."""
        return tf.reshape(self._value_out, [-1])
class EagerModel(TFModelV2):
    """Example of using embedded eager execution in a custom model.

    This shows how to use tf.py_function() to execute a snippet of TF code
    in eager mode. Here the `self.forward_eager` method just prints out the
    intermediate tensor for debug purposes, but you can in general perform
    any TF eager operation in tf.py_function().
    """

    def __init__(self, observation_space, action_space, num_outputs,
                 model_config, name):
        super().__init__(observation_space, action_space, num_outputs,
                         model_config, name)

        inputs = tf.keras.layers.Input(shape=observation_space.shape)
        self.fcnet = FullyConnectedNetwork(
            obs_space=self.obs_space,
            action_space=self.action_space,
            num_outputs=self.num_outputs,
            model_config=self.model_config,
            name="fc1",
        )
        out, value_out = self.fcnet.base_model(inputs)

        def lambda_(x):
            eager_out = tf.py_function(self.forward_eager, [x], tf.float32)
            # py_function drops the static shape; restore it from the input.
            with tf1.control_dependencies([eager_out]):
                eager_out.set_shape(x.shape)
            return eager_out

        out = tf.keras.layers.Lambda(lambda_)(out)
        self.base_model = tf.keras.models.Model(inputs, [out, value_out])

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        # Keras Model.__call__ is (inputs, training=None, mask=None). Passing
        # `state` / `seq_lens` positionally would bind them to `training` and
        # `mask`. The base model is stateless, so only feed observations.
        out, self._value_out = self.base_model(input_dict["obs"])
        return out, []

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])

    def forward_eager(self, feature_layer):
        """Identity on `feature_layer`; occasionally prints its mean (~1%)."""
        assert tf.executing_eagerly()
        if random.random() > 0.99:
            print(
                "Eagerly printing the feature layer mean value",
                tf.reduce_mean(feature_layer),
            )
        return feature_layer
class ActionMaskModel(TFModelV2):
    """Model that handles simple discrete action masking.

    This assumes the outputs are logits for a single Categorical action
    distribution. Extending it to more complex outputs (e.g. a tuple action
    space) is possible but not handled here.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kwargs):
        # Prefer the unflattened space when obs_space carries one.
        orig_space = getattr(obs_space, "original_space", obs_space)
        assert isinstance(orig_space, Dict)
        assert "action_mask" in orig_space.spaces
        assert "observations" in orig_space.spaces

        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        self.internal_model = FullyConnectedNetwork(
            orig_space["observations"],
            action_space,
            num_outputs,
            model_config,
            name + "_internal",
        )

        # Disabling masking will likely lead to invalid actions being chosen.
        custom_cfg = model_config["custom_model_config"]
        self.no_masking = custom_cfg.get("no_masking", False)

    def forward(self, input_dict, state, seq_lens):
        """Return logits, masked unless ``no_masking`` is set."""
        obs = input_dict["obs"]
        action_mask = obs["action_mask"]
        logits, _ = self.internal_model({"obs": obs["observations"]})

        if self.no_masking:
            return logits, state

        # [0.0 || -inf]-style mask, clamped to float32 min for stability.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        return logits + inf_mask, state

    def value_function(self):
        """Delegate to the internal network's value output."""
        return self.internal_model.value_function()