def _embedding_fn(self, tokens: torch.LongTensor,
                  positions: torch.LongTensor) -> torch.Tensor:
    # Under manual multi-GPU scattering, the embedding table can end up on a
    # different device than the incoming token batch; move it over before the
    # lookup. (`Module.to()` moves the parameters in place.)
    if tokens.device != self.word_embedder.embedding.device:
        self.word_embedder.to(tokens.device)
    word_embed = self.word_embedder(tokens)
    # Scale word embeddings by sqrt(hidden_dim), as in the Transformer paper,
    # before adding the (unscaled) positional embeddings.
    scale = self.config_model.hidden_dim ** 0.5
    pos_embed = self.pos_embedder(positions)
    return word_embed * scale + pos_embed
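
# Usage sketch (illustrative only): `model`, `vocab_size`, and the shapes
# below are assumptions, not part of the original file. Positions are the
# usual 0..seq_len-1 indices expected by the position embedder.
#
#   batch_size, seq_len, vocab_size = 2, 16, 30000
#   tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
#   positions = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
#   inputs = model._embedding_fn(tokens, positions)  # (batch, seq, hidden_dim)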