def forward(self, inputs: TensorType, memory: TensorType = None) -> TensorType:
    T = list(inputs.size())[1]  # Length of segment (time).
    H = self._num_heads  # Number of attention heads.
    d = self._head_dim  # Attention head dimension.

    # Add previous memory chunk (as const, w/o gradient) to input.
    # Tau (number of (prev) time slices in each memory chunk).
    Tau = list(memory.shape)[1] if memory is not None else 0
    if memory is not None:
        memory.requires_grad_(False)
        inputs = torch.cat((memory, inputs), dim=1)

    # Apply the Layer-Norm.
    if self._input_layernorm is not None:
        inputs = self._input_layernorm(inputs)

    qkv = self._qkv_layer(inputs)

    queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1)
    # Cut out the Tau memory timesteps from the queries.
    queries = queries[:, -T:]

    queries = torch.reshape(queries, [-1, T, H, d])
    keys = torch.reshape(keys, [-1, T + Tau, H, d])
    values = torch.reshape(values, [-1, T + Tau, H, d])

    R = self._pos_proj(self._rel_pos_encoder)
    R = torch.reshape(R, [T + Tau, H, d])

    # b=batch
    # i and j=time indices (i=max-timesteps (inputs); j=Tau memory space)
    # h=head
    # d=head-dim (over which we will reduce-sum)
    score = torch.einsum("bihd,bjhd->bijh", queries + self._uvar, keys)
    pos_score = torch.einsum("bihd,jhd->bijh", queries + self._vvar, R)
    score = score + self.rel_shift(pos_score)
    score = score / d**0.5

    # Causal mask of the same length as the sequence.
    mask = sequence_mask(
        torch.arange(Tau + 1, T + Tau + 1), dtype=score.dtype)
    mask = mask[None, :, :, None]

    masked_score = score * mask + 1e30 * (mask.to(torch.float32) - 1.)
    wmat = nn.functional.softmax(masked_score, dim=2)

    out = torch.einsum("bijh,bjhd->bihd", wmat, values)
    shape = list(out.shape)[:2] + [H * d]
    out = torch.reshape(out, shape)
    return self._linear_layer(out)
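

# Illustrative sketch (an assumption, not the class' own implementation): the
# `self.rel_shift(...)` call in `forward()` above is assumed to perform the
# Transformer-XL relative-position shift, which realigns the positional
# scores so that column j of query row i refers to relative distance (i - j).
# A minimal, self-contained version of that trick could look like the
# hypothetical helper below (the name `_rel_shift_sketch` and the exact
# pad/reshape layout are illustrative assumptions only).
import torch


def _rel_shift_sketch(x: torch.Tensor) -> torch.Tensor:
    """x: positional scores of shape [batch, T, T+Tau, heads]."""
    b, i, j, h = x.shape
    # Pad one zero column on the left of the relative-position (j) axis.
    x = torch.nn.functional.pad(x, (0, 0, 1, 0))  # -> [b, i, j + 1, h]
    # Viewing the padded tensor as [b, j + 1, i, h] and dropping the first
    # row shifts query row i left by i positions; then reshape back to the
    # original [b, i, j, h] layout.
    x = x.reshape(b, j + 1, i, h)[:, 1:].reshape(b, i, j, h)
    return x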
def forward_rnn(self, inputs: TensorType, state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
    # To make attention work with RLlib's current ModelV2 API:
    # We assume `state` holds the history of the most recent observations
    # (all concatenated into one tensor) plus one memory tensor per
    # attention layer. We append the current inputs to that history and
    # only keep the most recent timesteps (up to `max_seq_len`). This
    # allows us to deal with timestep-wise inference and full-sequence
    # training within the same logic.
    state = [torch.from_numpy(item) for item in state]
    observations = state[0]
    memory = state[1:]

    inputs = torch.reshape(inputs, [1, -1, observations.shape[-1]])
    observations = torch.cat(
        (observations, inputs), dim=1)[:, -self.max_seq_len:]

    all_out = observations
    memory_outs = []
    for i in range(len(self.layers)):
        # MHA layers, which need their memory chunk passed in.
        if i % 2 == 1:
            all_out = self.layers[i](all_out, memory=memory[i // 2])
        # Either the initial linear layer or MultiLayerPerceptrons. Their
        # outputs are carried over as the memory for the next forward pass
        # (one entry per attention layer).
        else:
            all_out = self.layers[i](all_out)
            memory_outs.append(all_out)

    logits = self.logits(all_out)
    self._value_out = self.values_out(all_out)
    # The output of the final MLP is not needed as memory (no attention
    # layer follows it).
    memory_outs = memory_outs[:-1]

    # If memory_tau > max_seq_len -> overlap w/ previous `memory` input.
    if self.memory_tau > self.max_seq_len:
        memory_outs = [
            torch.cat(
                [memory[i][:, -(self.memory_tau - self.max_seq_len):], m],
                dim=1) for i, m in enumerate(memory_outs)
        ]
    else:
        memory_outs = [m[:, -self.memory_tau:] for m in memory_outs]

    T = list(inputs.size())[1]  # Length of the new input segment (time).

    # Postprocess the final outputs: only return logits/values for the new
    # (non-memory) timesteps.
    logits = logits[:, -T:]
    self._value_out = self._value_out[:, -T:]

    return logits, [observations] + memory_outs
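

# Hypothetical usage sketch (an assumption, not part of the original model
# class): `forward_rnn()` above expects `state` to contain the running
# observation history at index 0 plus one memory tensor per attention layer
# at indices 1..N. On the very first timestep these would typically be
# zero-filled, e.g. as built by the helper below. The function name, the
# argument names (`obs_dim`, `attn_dim`, `max_seq_len`, `memory_tau`,
# `num_transformer_units`), and the leading batch dimension of 1 are
# illustrative assumptions.
from typing import List

import numpy as np


def _initial_state_sketch(obs_dim: int, attn_dim: int, max_seq_len: int,
                          memory_tau: int,
                          num_transformer_units: int) -> List[np.ndarray]:
    # One zero observation-history tensor, then one zero memory chunk per
    # attention (MHA) layer.
    return [np.zeros((1, max_seq_len, obs_dim), dtype=np.float32)] + [
        np.zeros((1, memory_tau, attn_dim), dtype=np.float32)
        for _ in range(num_transformer_units)
    ]


# During a rollout, the list returned by `forward_rnn()` (observation history
# plus updated memories) would be fed back in as `state` on the next call.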