Пример #1
0
    def forward(self, inputs: TensorType,
                memory: TensorType = None) -> TensorType:
        """Apply relative-position multi-head attention to one segment.

        Args:
            inputs: Segment to attend from. Dim 1 is treated as the time
                axis (length T); dim 0 is presumably the batch axis.
            memory: Optional chunk of earlier time steps. It is prepended
                to `inputs` along dim 1 and used as a constant, i.e. no
                gradient flows back into it. Its dim-1 length is Tau.

        Returns:
            The result of `self._linear_layer` applied to the attention
            output, after the head and head-dim axes were merged into a
            single last axis of size `num_heads * head_dim`.
        """
        T = inputs.shape[1]  # length of segment (time)
        H = self._num_heads  # number of attention heads
        d = self._head_dim  # attention head dimension

        # Add previous memory chunk (as const, w/o gradient) to input.
        # Tau (number of (prev) time slices in each memory chunk).
        Tau = 0
        if memory is not None:
            Tau = memory.shape[1]
            # detach() instead of requires_grad_(False): the in-place flag
            # change would modify the caller's tensor and raises for
            # non-leaf tensors (e.g. outputs of a previous forward pass).
            inputs = torch.cat((memory.detach(), inputs), dim=1)

        # Apply the Layer-Norm.
        if self._input_layernorm is not None:
            inputs = self._input_layernorm(inputs)

        qkv = self._qkv_layer(inputs)

        queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1)
        # Cut out Tau memory timesteps from query: only the T steps of the
        # current segment act as queries, keys/values span memory + segment.
        queries = queries[:, -T:]

        queries = torch.reshape(queries, [-1, T, H, d])
        keys = torch.reshape(keys, [-1, T + Tau, H, d])
        values = torch.reshape(values, [-1, T + Tau, H, d])

        # Project the relative position encoding and split it per head.
        R = self._pos_proj(self._rel_pos_encoder)
        R = torch.reshape(R, [T + Tau, H, d])

        # b=batch
        # i and j=time indices (i=max-timesteps (inputs); j=Tau memory space)
        # h=head
        # d=head-dim (over which we will reduce-sum)
        score = torch.einsum("bihd,bjhd->bijh", queries + self._uvar, keys)
        pos_score = torch.einsum("bihd,jhd->bijh", queries + self._vvar, R)
        score = score + self.rel_shift(pos_score)
        score = score / d**0.5

        # Causal mask of the same length as the sequence.
        # NOTE(review): presumably row i enables the first Tau + i + 1 key
        # positions -- confirm against `sequence_mask`.
        mask = sequence_mask(
            torch.arange(Tau + 1, T + Tau + 1), dtype=score.dtype)
        # The mask is built from a CPU `arange`; move it next to the scores
        # so the product below also works when the module runs on a GPU.
        mask = mask.to(score.device)[None, :, :, None]

        # Entries with mask == 0 become -1e30, so the softmax over the key
        # axis (j, dim=2) assigns them (numerically) zero weight.
        masked_score = score * mask + 1e30 * (mask.to(torch.float32) - 1.)
        wmat = nn.functional.softmax(masked_score, dim=2)

        out = torch.einsum("bijh,bjhd->bihd", wmat, values)
        # Merge the head and head-dim axes back into one feature axis.
        shape = list(out.shape)[:2] + [H * d]
        out = torch.reshape(out, shape)

        return self._linear_layer(out)
Пример #2
0
    def forward_rnn(self, inputs: TensorType, state: List[TensorType],
                    seq_lens: TensorType) -> (TensorType, List[TensorType]):
        """Run the layer stack on the observation history plus new inputs.

        To make Attention work with current RLlib's ModelV2 API, `state` is
        assumed to hold the history of recent observations (all concatenated
        into one tensor) followed by one memory tensor per memory-consuming
        layer. The current inputs are appended to the history and only the
        most recent `max_seq_len` steps are kept, so that timestep-wise
        inference and full sequence training share the same logic.

        Args:
            inputs: New observations; reshaped to
                `[1, -1, observations.shape[-1]]` before use.
            state: `state[0]` is the observation history, `state[1:]` are
                the memory tensors handed to the odd-indexed layers
                (`memory[i // 2]` for layer `i`). Items may be numpy arrays
                or tensors.
            seq_lens: Not used by this method.

        Returns:
            A tuple `(logits, state_out)`: `logits` are cut to the last T
            time steps (T = time length of the reshaped `inputs`);
            `state_out` is the updated observation history followed by one
            new memory tensor per memory-consuming layer.

        Side effects:
            Stores the value-branch output (same last T steps) in
            `self._value_out`.
        """
        # as_tensor accepts numpy arrays (sharing their memory, as
        # from_numpy does) and also values that already are tensors.
        state = [torch.as_tensor(item) for item in state]
        observations = state[0]
        memory = state[1:]

        inputs = torch.reshape(inputs, [1, -1, observations.shape[-1]])
        observations = torch.cat(
            (observations, inputs), dim=1)[:, -self.max_seq_len:]

        all_out = observations
        memory_outs = []
        for i, layer in enumerate(self.layers):
            # MHA layers which need memory passed in.
            if i % 2 == 1:
                # The input of this layer is what the next call must supply
                # as `memory[i // 2]`, so record it before it is consumed.
                memory_outs.append(all_out)
                all_out = layer(all_out, memory=memory[i // 2])
            # Either linear layers or MultiLayerPerceptrons.
            else:
                all_out = layer(all_out)

        logits = self.logits(all_out)
        self._value_out = self.values_out(all_out)

        # If memory_tau > max_seq_len -> overlap w/ previous `memory` input.
        if self.memory_tau > self.max_seq_len:
            overlap = self.memory_tau - self.max_seq_len
            memory_outs = [
                torch.cat([memory[i][:, -overlap:], m], dim=1)
                for i, m in enumerate(memory_outs)
            ]
        else:
            memory_outs = [m[:, -self.memory_tau:] for m in memory_outs]

        T = inputs.shape[1]  # Length of input segment (time).

        # Postprocessing final output: only the new T steps are returned.
        logits = logits[:, -T:]
        self._value_out = self._value_out[:, -T:]

        return logits, [observations] + memory_outs