def __init__(self,
             observation_space: gym.spaces.Space,
             action_space: gym.spaces.Space,
             num_outputs: int,
             model_config: ModelConfigDict,
             name: str,
             num_transformer_units: int,
             attn_dim: int,
             num_heads: int,
             memory_tau: int,
             head_dim: int,
             ff_hidden_dim: int,
             init_gate_bias: float = 2.0):
    """Initializes a GTrXLNet.

    Args:
        observation_space (gym.spaces.Space): The observation space. Only
            `observation_space.shape[0]` is read here (as the input size of
            the first linear layer), so a flat 1D space is presumably
            expected.
        action_space (gym.spaces.Space): The action space (passed through
            to the parent constructor).
        num_outputs (int): The output size of the `logits` layer.
        model_config (ModelConfigDict): The model config dict. Must contain
            the key "max_seq_len".
        name (str): The model's name (passed through to the parent
            constructor).
        num_transformer_units (int): The number of Transformer repeats to
            use (denoted L in [2]).
        attn_dim (int): The input and output dimensions of one Transformer
            unit.
        num_heads (int): The number of attention heads to use in parallel.
            Denoted as `H` in [3].
        memory_tau (int): The number of timesteps to store in each
            transformer block's memory M (concat'd over time and fed into
            next transformer block as input).
        head_dim (int): The dimension of a single(!) head.
            Denoted as `d` in [3].
        ff_hidden_dim (int): The dimension of the hidden layer within
            the position-wise MLP (after the multi-head attention block
            within one Transformer unit). This is the size of the first
            of the two layers within the PositionwiseFeedforward. The
            second layer always has size=`attn_dim`.
        init_gate_bias (float): Initial bias values for the GRU gates (two
            GRUs per Transformer unit, one after the MHA, one after the
            position-wise MLP).
    """
    super().__init__(observation_space, action_space, num_outputs,
                     model_config, name)
    nn.Module.__init__(self)

    self.num_transformer_units = num_transformer_units
    self.attn_dim = attn_dim
    self.num_heads = num_heads
    self.memory_tau = memory_tau
    self.head_dim = head_dim
    self.max_seq_len = model_config["max_seq_len"]
    self.obs_dim = observation_space.shape[0]

    # Constant (non-trainable) sinusoid rel pos encoding matrix, shared by
    # all attention layers. It covers the input sequence plus the memory.
    Phi = relative_position_embedding(self.max_seq_len + self.memory_tau,
                                      self.attn_dim)

    # 1) Project the observation (obs_dim) to the Transformer width.
    self.linear_layer = SlimFC(in_size=self.obs_dim, out_size=self.attn_dim)
    self.layers = [self.linear_layer]

    attention_layers = []
    # 2) Create L Transformer blocks according to [2].
    for _ in range(self.num_transformer_units):
        # RelativeMultiHeadAttention part.
        MHA_layer = SkipConnection(
            RelativeMultiHeadAttention(
                in_dim=self.attn_dim,
                out_dim=self.attn_dim,
                num_heads=num_heads,
                head_dim=head_dim,
                rel_pos_encoder=Phi,
                input_layernorm=True,
                output_activation=nn.ReLU,
            ),
            fan_in_layer=GRUGate(self.attn_dim, init_gate_bias),
        )
        # Position-wise MultiLayerPerceptron part.
        E_layer = SkipConnection(
            nn.Sequential(
                torch.nn.LayerNorm(self.attn_dim),
                SlimFC(
                    in_size=self.attn_dim,
                    out_size=ff_hidden_dim,
                    use_bias=False,
                    activation_fn=nn.ReLU,
                ),
                SlimFC(
                    in_size=ff_hidden_dim,
                    out_size=self.attn_dim,
                    use_bias=False,
                    activation_fn=nn.ReLU,
                ),
            ),
            fan_in_layer=GRUGate(self.attn_dim, init_gate_bias),
        )
        # Build a list of all attention layers in order.
        attention_layers.extend([MHA_layer, E_layer])

    # nn.Module does NOT register submodules stored in a plain Python list,
    # so without this their parameters would be missing from
    # `parameters()` (never optimized), `state_dict()` and `.to(device)`.
    # Holding them in an nn.Sequential attribute registers them with this
    # top-level model. `self.layers` stays a plain list for backward
    # compatibility with existing users of it.
    self.attention_layers = nn.Sequential(*attention_layers)
    self.layers.extend(attention_layers)

    # Postprocess GTrXL output with another hidden layer.
    # NOTE(review): `activation_fn=nn.ReLU` here presumably clamps the
    # logits to >= 0; confirm this is intended.
    self.logits = SlimFC(in_size=self.attn_dim, out_size=self.num_outputs,
                         activation_fn=nn.ReLU)

    # Value function used by all RLlib Torch RL implementations.
    self._value_out = None
    self.values_out = SlimFC(in_size=self.attn_dim, out_size=1,
                             activation_fn=None)
def __init__(self,
             observation_space: gym.spaces.Space,
             action_space: gym.spaces.Space,
             num_outputs: Optional[int],
             model_config: ModelConfigDict,
             name: str,
             *,
             num_transformer_units: int = 1,
             attention_dim: int = 64,
             num_heads: int = 2,
             memory_inference: int = 50,
             memory_training: int = 50,
             head_dim: int = 32,
             position_wise_mlp_dim: int = 32,
             init_gru_gate_bias: float = 2.0):
    """Initializes a GTrXLNet.

    Builds an input projection, `num_transformer_units` gated Transformer
    units (each one attention block plus one position-wise MLP block), and,
    when `num_outputs` is given, a logits head and a value head. Also
    registers one `state_in_{i}`/`state_out_{i}` view-requirement pair per
    Transformer unit.

    Args:
        observation_space (gym.spaces.Space): The observation space. Only
            `observation_space.shape[0]` is read here (as the input size of
            the first linear layer).
        action_space (gym.spaces.Space): The action space (forwarded to the
            parent constructor).
        num_outputs (Optional[int]): Size of the logits head. If None, no
            logits/value heads are built and `self.num_outputs` is set to
            `attention_dim` instead.
        model_config (ModelConfigDict): The model config. Must contain the
            key "max_seq_len".
        name (str): The model's name (forwarded to the parent constructor).
        num_transformer_units (int): How many Transformer units to stack
            (denoted L in [2]).
        attention_dim (int): Input and output width of every Transformer
            unit.
        num_heads (int): How many attention heads run in parallel
            (denoted `H` in [3]).
        memory_inference (int): How many past timesteps are concatenated
            along the time axis and fed into the next Transformer unit at
            inference time. The first unit receives that many past
            observations (plus the current one) instead.
        memory_training (int): How many past timesteps are concatenated
            along the time axis and fed into the next Transformer unit at
            training time (in addition to the input sequence of length
            max_seq_len). The first unit receives that many past
            observations (plus the input sequence) instead.
        head_dim (int): Width of one single(!) attention head inside a
            multi-head attention unit (denoted `d` in [3]).
        position_wise_mlp_dim (int): Hidden width of the position-wise MLP
            that follows the attention block in each Transformer unit. This
            is the size of the first of its two layers; the second layer
            always has size=`attention_dim`.
        init_gru_gate_bias (float): Initial bias of the GRU gates (two per
            Transformer unit: one after the attention block, one after the
            position-wise MLP).
    """
    super().__init__(observation_space, action_space, num_outputs,
                     model_config, name)
    nn.Module.__init__(self)

    # Hyperparameters kept on the instance.
    self.num_transformer_units = num_transformer_units
    self.attention_dim = attention_dim
    self.num_heads = num_heads
    self.memory_inference = memory_inference
    self.memory_training = memory_training
    self.head_dim = head_dim
    self.max_seq_len = model_config["max_seq_len"]
    self.obs_dim = observation_space.shape[0]

    d_model = self.attention_dim

    # Input projection: obs_dim -> attention_dim.
    self.linear_layer = SlimFC(in_size=self.obs_dim, out_size=d_model)
    self.layers = [self.linear_layer]

    # Stack L gated Transformer units according to [2]. Modules are built in
    # the same order as before: attention, its GRU gate, then LayerNorm,
    # both MLP layers, and the MLP's GRU gate.
    blocks = []
    for _ in range(self.num_transformer_units):
        attn_block = SkipConnection(
            RelativeMultiHeadAttention(
                in_dim=d_model,
                out_dim=d_model,
                num_heads=num_heads,
                head_dim=head_dim,
                input_layernorm=True,
                output_activation=nn.ReLU,
            ),
            fan_in_layer=GRUGate(d_model, init_gru_gate_bias),
        )
        position_wise_mlp = nn.Sequential(
            torch.nn.LayerNorm(d_model),
            SlimFC(in_size=d_model, out_size=position_wise_mlp_dim,
                   use_bias=False, activation_fn=nn.ReLU),
            SlimFC(in_size=position_wise_mlp_dim, out_size=d_model,
                   use_bias=False, activation_fn=nn.ReLU),
        )
        mlp_block = SkipConnection(
            position_wise_mlp,
            fan_in_layer=GRUGate(d_model, init_gru_gate_bias),
        )
        blocks += [attn_block, mlp_block]

    # A plain list does not register submodules with nn.Module; an
    # nn.Sequential attribute does, so the blocks' parameters belong to
    # this model. `self.layers` keeps the same modules in order.
    self.attention_layers = nn.Sequential(*blocks)
    self.layers.extend(blocks)

    # Output heads default to "absent"; `_value_out` caches the last value.
    self.logits = None
    self.values_out = None
    self._value_out = None

    if self.num_outputs is None:
        # Without heads, the model's output is the raw Transformer output.
        self.num_outputs = d_model
    else:
        # Postprocess GTrXL output with another hidden layer.
        # NOTE(review): `activation_fn=nn.ReLU` presumably clamps the
        # logits to >= 0; confirm this is intended.
        self.logits = SlimFC(
            in_size=d_model,
            out_size=self.num_outputs,
            activation_fn=nn.ReLU,
        )
        # Value head used by RLlib's Torch RL implementations.
        self.values_out = SlimFC(in_size=d_model, out_size=1,
                                 activation_fn=None)

    # Trajectory views: each unit's memory input is the last
    # `memory_inference` outputs of that same unit.
    for unit_idx in range(self.num_transformer_units):
        memory_space = Box(-1.0, 1.0, shape=(d_model, ))
        in_key = "state_in_{}".format(unit_idx)
        out_key = "state_out_{}".format(unit_idx)
        self.view_requirements[in_key] = ViewRequirement(
            out_key,
            shift="-{}:-1".format(self.memory_inference),
            # Repeat the incoming state every max-seq-len times.
            batch_repeat_value=self.max_seq_len,
            space=memory_space,
        )
        self.view_requirements[out_key] = ViewRequirement(
            space=memory_space, used_for_training=False)