예제 #1
0
    def __init__(self,
                 input_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 *,
                 name: str,
                 max_seq_len: int = 20,
                 num_transformer_units: int = 1,
                 attention_dim: int = 64,
                 num_heads: int = 2,
                 memory_inference: int = 50,
                 memory_training: int = 50,
                 head_dim: int = 32,
                 position_wise_mlp_dim: int = 32,
                 init_gru_gate_bias: float = 2.0):
        """Builds the gated Transformer-XL (GTrXL) Keras model.

        Args:
            input_space (gym.spaces.Space): Observation space. Its first
                shape dimension is used as the per-timestep observation size.
            action_space (gym.spaces.Space): The action space. It is not
                referenced by this constructor.
            name (str): Model name, forwarded to the parent constructor.
            max_seq_len (int): Maximum sequence length. It is used as the
                `batch_repeat_value` of every `state_in_*` view requirement.
            num_transformer_units (int): How many Transformer units to stack
                (denoted L in [2]).
            attention_dim (int): Input and output dimension of each
                Transformer unit.
            num_heads (int): Number of attention heads run in parallel
                (denoted `H` in [3]).
            memory_inference (int): Number of past timesteps, concatenated
                along the time axis, that form a unit's memory at inference
                time. The first unit instead receives this many past
                observations (plus the current one).
            memory_training (int): Number of past timesteps, concatenated
                along the time axis, that form a unit's memory at training
                time (in addition to the input sequence of length
                `max_seq_len`). The first unit instead receives this many
                past observations (plus the input sequence).
            head_dim (int): Dimension of one single(!) attention head inside
                a multi-head attention unit (denoted `d` in [3]).
            position_wise_mlp_dim (int): Hidden size of the position-wise
                MLP that follows the multi-head attention in each unit. This
                is the first of the two PositionwiseFeedforward layers; the
                second one always has size `attention_dim`.
            init_gru_gate_bias (float): Initial bias of the GRU gates. Each
                Transformer unit has two: one after the MHA, one after the
                position-wise MLP.
        """

        super().__init__(name=name)

        self.num_transformer_units = num_transformer_units
        self.attention_dim = attention_dim
        self.num_heads = num_heads
        self.memory_inference = memory_inference
        self.memory_training = memory_training
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.obs_dim = input_space.shape[0]

        # Keras inputs: raw observations with a free (None) time axis, plus
        # one memory tensor per Transformer unit.
        obs_in = tf.keras.layers.Input(
            shape=(None, self.obs_dim), name="inputs")
        mem_inputs = []
        for unit_idx in range(self.num_transformer_units):
            mem_inputs.append(
                tf.keras.layers.Input(
                    shape=(None, self.attention_dim),
                    dtype=tf.float32,
                    name="memory_in_{}".format(unit_idx)))

        # Project observations into the attention dimension. The input of
        # every unit is collected in `mem_outputs`; entry i becomes the
        # memory-out for unit i.
        hidden = tf.keras.layers.Dense(self.attention_dim)(obs_in)
        mem_outputs = [hidden]

        # Stack the L Transformer units as in [2]. Sub-layers are created in
        # a fixed order, which also fixes their Keras auto-generated names.
        for unit_idx in range(self.num_transformer_units):
            layer_no = unit_idx + 1

            # Relative multi-head attention, gated back onto its input.
            attention = RelativeMultiHeadAttention(
                out_dim=self.attention_dim,
                num_heads=num_heads,
                head_dim=head_dim,
                input_layernorm=True,
                output_activation=tf.nn.relu)
            attn_gate = GRUGate(init_gru_gate_bias)
            attn_out = SkipConnection(
                attention,
                fan_in_layer=attn_gate,
                name="mha_{}".format(layer_no))(
                    hidden, memory=mem_inputs[unit_idx])

            # Layer-normalised position-wise MLP, gated the same way.
            mlp = tf.keras.Sequential(
                (tf.keras.layers.LayerNormalization(axis=-1),
                 PositionwiseFeedforward(
                     out_dim=self.attention_dim,
                     hidden_dim=position_wise_mlp_dim,
                     output_activation=tf.nn.relu)))
            mlp_gate = GRUGate(init_gru_gate_bias)
            hidden = SkipConnection(
                mlp,
                fan_in_layer=mlp_gate,
                name="pos_wise_mlp_{}".format(layer_no))(attn_out)

            # This unit's output, E(l-1), feeds the next unit and is also
            # kept as that next unit's memory-out.
            mem_outputs.append(hidden)

        self._logits = None
        self._value_out = None

        # Outputs: the final unit's output, then one memory-out per unit.
        # The last collected tensor is the final output itself, so it is
        # excluded from the memory-outs.
        self.trxl_model = tf.keras.Model(
            inputs=[obs_in] + mem_inputs,
            outputs=[hidden] + mem_outputs[:-1])

        self.view_requirements = {
            SampleBatch.OBS: ViewRequirement(space=input_space),
        }
        view_reqs = self.view_requirements
        # Per unit, `state_in_i` reads a window of the last
        # `memory_inference` `state_out_i` values.
        for unit_idx in range(self.num_transformer_units):
            mem_space = Box(-1.0, 1.0, shape=(self.attention_dim, ))
            out_key = "state_out_{}".format(unit_idx)
            view_reqs["state_in_{}".format(unit_idx)] = ViewRequirement(
                out_key,
                shift="-{}:-1".format(self.memory_inference),
                # Repeat the incoming state every `max_seq_len` timesteps.
                batch_repeat_value=self.max_seq_len,
                space=mem_space)
            view_reqs[out_key] = ViewRequirement(
                space=mem_space, used_for_training=False)
예제 #2
0
    def __init__(self,
                 observation_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 num_transformer_units,
                 attn_dim,
                 num_heads,
                 memory_tau,
                 head_dim,
                 ff_hidden_dim,
                 init_gate_bias=2.0):
        """Builds the gated Transformer-XL (GTrXL) policy/value model.

        Args:
            observation_space (gym.spaces.Space): Observation space. Its
                first shape dimension is used as the per-timestep
                observation size.
            action_space (gym.spaces.Space): Action space, forwarded to the
                parent constructor.
            num_outputs (int): Forwarded to the parent constructor. The
                logits layer is sized by `self.num_outputs`.
            model_config (dict): Model config, forwarded to the parent
                constructor. Its "max_seq_len" entry fixes the length of
                the input sequence.
            name (str): Model name, forwarded to the parent constructor.
            num_transformer_units (int): How many Transformer units to stack
                (denoted L in [2]).
            attn_dim (int): Input and output dimension of each Transformer
                unit.
            num_heads (int): Number of attention heads run in parallel
                (denoted `H` in [3]).
            memory_tau (int): Number of timesteps held in each unit's memory
                M. These are concatenated over time and fed to the next
                unit. Also sets the memory inputs' time length.
            head_dim (int): Dimension of one single(!) attention head
                (denoted `d` in [3]).
            ff_hidden_dim (int): Hidden size of the position-wise MLP that
                follows the multi-head attention in each unit. This is the
                first of the two PositionwiseFeedforward layers; the second
                one always has size `attn_dim`.
            init_gate_bias (float): Initial bias of the GRU gates. Each
                Transformer unit has two: one after the MHA, one after the
                position-wise MLP.
        """

        super().__init__(observation_space, action_space, num_outputs,
                         model_config, name)

        self.num_transformer_units = num_transformer_units
        self.attn_dim = attn_dim
        self.num_heads = num_heads
        self.memory_tau = memory_tau
        self.head_dim = head_dim
        self.max_seq_len = model_config["max_seq_len"]
        self.obs_dim = observation_space.shape[0]

        # Sinusoid relative-position encoding matrix spanning the input
        # sequence plus the memory. It is shared by all attention layers.
        rel_pos_table = relative_position_embedding(
            self.max_seq_len + self.memory_tau, self.attn_dim)

        # Keras inputs: a fixed-length observation sequence, plus one
        # fixed-length memory tensor per Transformer unit.
        obs_in = tf.keras.layers.Input(
            shape=(self.max_seq_len, self.obs_dim), name="inputs")
        mem_inputs = []
        for unit_idx in range(self.num_transformer_units):
            mem_inputs.append(
                tf.keras.layers.Input(
                    shape=(self.memory_tau, self.attn_dim),
                    dtype=tf.float32,
                    name="memory_in_{}".format(unit_idx)))

        # Project observations into the attention dimension. The input of
        # every unit is collected in `mem_outputs`; entry i becomes the
        # memory-out for unit i.
        hidden = tf.keras.layers.Dense(self.attn_dim)(obs_in)
        mem_outputs = [hidden]

        # Stack the L Transformer units as in [2]. Sub-layers are created in
        # a fixed order, which also fixes their Keras auto-generated names.
        for unit_idx in range(self.num_transformer_units):
            layer_no = unit_idx + 1

            # Relative multi-head attention, gated back onto its input.
            attention = RelativeMultiHeadAttention(
                out_dim=self.attn_dim,
                num_heads=num_heads,
                head_dim=head_dim,
                rel_pos_encoder=rel_pos_table,
                input_layernorm=True,
                output_activation=tf.nn.relu)
            attn_gate = GRUGate(init_gate_bias)
            attn_out = SkipConnection(
                attention,
                fan_in_layer=attn_gate,
                name="mha_{}".format(layer_no))(
                    hidden, memory=mem_inputs[unit_idx])

            # Layer-normalised position-wise MLP, gated the same way.
            mlp = tf.keras.Sequential(
                (tf.keras.layers.LayerNormalization(axis=-1),
                 PositionwiseFeedforward(
                     out_dim=self.attn_dim,
                     hidden_dim=ff_hidden_dim,
                     output_activation=tf.nn.relu)))
            mlp_gate = GRUGate(init_gate_bias)
            hidden = SkipConnection(
                mlp,
                fan_in_layer=mlp_gate,
                name="pos_wise_mlp_{}".format(layer_no))(attn_out)

            # This unit's output, E(l-1), feeds the next unit and is also
            # kept as that next unit's memory-out.
            mem_outputs.append(hidden)

        # Output heads applied directly to the final unit's output: linear
        # `logits` of width `self.num_outputs` and a 1-unit `values` head.
        logits = tf.keras.layers.Dense(
            self.num_outputs,
            activation=tf.keras.activations.linear,
            name="logits")(hidden)

        self._value_out = None
        values_out = tf.keras.layers.Dense(
            1, activation=None, name="values")(hidden)

        # Outputs: logits, values, then one memory-out per unit. The last
        # collected tensor is the final output itself, so it is excluded.
        self.trxl_model = tf.keras.Model(
            inputs=[obs_in] + mem_inputs,
            outputs=[logits, values_out] + mem_outputs[:-1])

        self.register_variables(self.trxl_model.variables)
        self.trxl_model.summary()