def forward(self, input, incremental_state=None, timestep=None):
    """Look up sinusoidal position embeddings for a batch of token ids.

    Args:
        input: 2-D tensor, expected to be of size [bsz x seqlen].
        incremental_state: when not None, a single position embedding is
            returned for every batch entry (single-step decoding) instead
            of one embedding per input token.
        timestep: optional tensor used only in incremental mode; the
            position looked up is ``timestep + 1``. When None, ``seq_len``
            is used as the position.

    Returns:
        Incremental mode: the selected row of ``self.weights`` broadcast
        over the batch (``expand(bsz, 1, -1)`` in the eager path).
        Otherwise: a detached tensor reshaped to ``(bsz, seq_len, -1)``.
    """
    bsz, seq_len = input.size()
    max_pos = self.padding_idx + 1 + seq_len

    if self.weights is None or max_pos > self.weights.size(0):
        # Recompute/expand the embedding table if it is missing or too short.
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            max_pos,
            self.embedding_dim,
            self.padding_idx,
        )
    # Cast the table to the tensor type of the module's reference tensor.
    self.weights = self.weights.type_as(self._float_tensor)

    if incremental_state is not None:
        # The position is the same for every token when decoding a single step.
        if timestep is not None:
            pos = (timestep.int() + 1).long()
        else:
            pos = seq_len
        step_embedding = self.weights[self.padding_idx + pos, :]
        if self.onnx_trace:
            return step_embedding.unsqueeze(1).repeat(bsz, 1, 1)
        return step_embedding.expand(bsz, 1, -1)

    positions = make_positions(input, self.padding_idx, self.left_pad, self.onnx_trace)
    flat_positions = positions.view(-1)

    if self.onnx_trace:
        flat_embeddings = self.weights.detach().index_select(0, flat_positions)
        # input.size() yields plain ints outside an active trace, and ints
        # have no .view(); torch.as_tensor wraps ints and returns existing
        # tensors as-is, so both cases produce a 1-element shape entry.
        embedding_shape = torch.cat((
            torch.as_tensor(bsz).view(1),
            torch.as_tensor(seq_len).view(1),
            torch.LongTensor([-1]),
        ))
        return torch._reshape_from_tensor(flat_embeddings, embedding_shape)

    return self.weights.index_select(0, flat_positions).view(bsz, seq_len, -1).detach()
def reshape_from_tensor_shape(x, shape):
    """Reshape ``x`` using a target shape supplied as a tensor.

    Delegates directly to ``torch._reshape_from_tensor``.

    Args:
        x: tensor to reshape.
        shape: tensor holding the target dimensions.

    Returns:
        The result of ``torch._reshape_from_tensor(x, shape)``.
    """
    reshaped = torch._reshape_from_tensor(x, shape)
    return reshaped