# Shared imports assumed by the snippets in this file.
from typing import List, Optional, Tuple

import torch


def forward(
    self,
    texts: Optional[List[str]] = None,
    multi_texts: Optional[List[List[str]]] = None,
    tokens: Optional[List[List[str]]] = None,
    languages: Optional[List[str]] = None,
):
    if tokens is None:
        raise RuntimeError("tokens is required")

    # Trim over-long sequences, then build padded word ids and per-token
    # byte inputs for the byte-aware model.
    tokens = truncate_tokens(tokens, self.max_seq_len, self.vocab.pad_token)
    seq_lens = make_sequence_lengths(tokens)
    word_ids = self.vocab.lookup_indices_2d(tokens)
    word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
    token_bytes, _ = make_byte_inputs(
        tokens, self.max_byte_len, self.byte_offset_for_non_padding
    )
    logits = self.model(torch.tensor(word_ids), token_bytes, torch.tensor(seq_lens))
    return self.output_layer(logits)
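# A minimal sketch of what make_byte_inputs is assumed to produce above:
# per-token UTF-8 byte ids shifted by `offset` so that 0 can remain the pad
# id, with each token padded or truncated to max_byte_len. The real helper
# may differ; make_byte_inputs_sketch is a hypothetical stand-in, not the
# library code.
def make_byte_inputs_sketch(
    tokens: List[List[str]], max_byte_len: int, offset: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    batch_size = len(tokens)
    max_tokens = max(len(row) for row in tokens)
    byte_ids = torch.zeros(batch_size, max_tokens, max_byte_len, dtype=torch.long)
    byte_lens = torch.zeros(batch_size, max_tokens, dtype=torch.long)
    for i, row in enumerate(tokens):
        for j, token in enumerate(row):
            raw = token.encode("utf-8")[:max_byte_len]
            for k, b in enumerate(raw):
                byte_ids[i, j, k] = b + offset
            byte_lens[i, j] = len(raw)
    return byte_ids, byte_lens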
def forward(
    self,
    texts: Optional[List[str]] = None,
    tokens: Optional[List[List[str]]] = None,
    languages: Optional[List[str]] = None,
):
    if tokens is None:
        raise RuntimeError("tokens is required")

    # Trim each sequence to max_seq_len; a negative max_seq_len disables trimming.
    trimmed_tokens: List[List[str]] = []
    if self.max_seq_len >= 0:
        for token in tokens:
            trimmed_tokens.append(token[0 : self.max_seq_len])
    else:
        trimmed_tokens = tokens

    seq_lens = make_sequence_lengths(trimmed_tokens)
    word_ids = self.vocab.lookup_indices_2d(trimmed_tokens)
    word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
    logits = self.model(torch.tensor(word_ids), torch.tensor(seq_lens))
    return self.output_layer(logits)
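# Hedged sketches of the two helpers the forwards in this file lean on
# (assumed behavior only; the real make_sequence_lengths / pad_2d may
# differ): make_sequence_lengths records each row's true length before
# padding, and pad_2d right-pads every row with pad_idx to the longest row.
def make_sequence_lengths_sketch(batch: List[List[str]]) -> List[int]:
    return [len(row) for row in batch]


def pad_2d_sketch(
    batch: List[List[int]], seq_lens: List[int], pad_idx: int
) -> List[List[int]]:
    max_len = max(seq_lens) if len(seq_lens) > 0 else 0
    return [row + [pad_idx] * (max_len - len(row)) for row in batch]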
def forward(
    self,
    texts: Optional[List[str]] = None,
    tokens: Optional[List[List[str]]] = None,
    languages: Optional[List[str]] = None,
    dense_feat: Optional[List[List[float]]] = None,
):
    # Fail fast on both required inputs before doing any work.
    if tokens is None:
        raise RuntimeError("tokens is required")
    if dense_feat is None:
        raise RuntimeError("dense_feat is required")

    seq_lens = make_sequence_lengths(tokens)
    word_ids = self.vocab.lookup_indices_2d(tokens)
    word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
    # Normalize the dense features before they reach the model.
    dense_feat = self.normalizer.normalize(dense_feat)
    logits = self.model(
        torch.tensor(word_ids),
        torch.tensor(seq_lens),
        torch.tensor(dense_feat, dtype=torch.float),
    )
    return self.output_layer(logits)
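# A sketch of the normalization step above, assuming the normalizer performs
# feature-wise standardization with precomputed statistics. The function name
# and the means/stds parameters are hypothetical; the real normalizer's
# behavior may differ.
def normalize_dense_sketch(
    dense_feat: List[List[float]], means: List[float], stds: List[float]
) -> List[List[float]]:
    return [
        [(v - m) / s if s != 0.0 else 0.0 for v, m, s in zip(row, means, stds)]
        for row in dense_feat
    ]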
def tensorize(
    self,
    texts: Optional[List[List[str]]] = None,
    tokens: Optional[List[List[List[str]]]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    tokens_2d: List[List[int]] = []
    seq_len_2d: List[int] = []
    start_indices_2d: List[List[int]] = []
    end_indices_2d: List[List[int]] = []
    positions_2d: List[List[int]] = []

    # Numberize one example at a time, collecting per-example token ids,
    # lengths, start/end indices, and positions.
    for idx in range(self.batch_size(texts, tokens)):
        numberized: Tuple[
            List[int], int, List[int], List[int], List[int]
        ] = self.numberize(
            self.get_texts_by_index(texts, idx),
            self.get_tokens_by_index(tokens, idx),
        )
        tokens_2d.append(numberized[0])
        seq_len_2d.append(numberized[1])
        start_indices_2d.append(numberized[2])
        end_indices_2d.append(numberized[3])
        positions_2d.append(numberized[4])

    # Pad token ids (honoring any sequence/batch padding controls) and build
    # the padding mask; named padded_tokens to avoid shadowing the `tokens`
    # argument with a tensor of a different type.
    padded_tokens, pad_mask = pad_2d_mask(
        tokens_2d,
        pad_value=self.vocab.pad_idx,
        seq_padding_control=self.seq_padding_control,
        max_seq_pad_len=self.max_seq_len,
        batch_padding_control=self.batch_padding_control,
    )
    start_indices = torch.tensor(
        pad_2d(
            start_indices_2d,
            seq_lens=seq_len_2d,
            pad_idx=self.vocab.pad_idx,
            max_len=self.max_seq_len,
        ),
        dtype=torch.long,
    )
    end_indices = torch.tensor(
        pad_2d(
            end_indices_2d,
            seq_lens=seq_len_2d,
            pad_idx=self.vocab.pad_idx,
            max_len=self.max_seq_len,
        ),
        dtype=torch.long,
    )
    positions = torch.tensor(
        pad_2d(
            positions_2d,
            seq_lens=seq_len_2d,
            pad_idx=self.vocab.pad_idx,
            max_len=self.max_seq_len,
        ),
        dtype=torch.long,
    )

    if self.device == "":
        return padded_tokens, pad_mask, start_indices, end_indices, positions
    return (
        padded_tokens.to(self.device),
        pad_mask.to(self.device),
        start_indices.to(self.device),
        end_indices.to(self.device),
        positions.to(self.device),
    )
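# A simplified sketch of pad_2d_mask as used above, ignoring the sequence/
# batch padding-control bucketing. Assumed semantics only: it returns padded
# token ids plus a mask with 1 for real tokens and 0 for padding; the real
# helper may differ.
def pad_2d_mask_sketch(
    batch: List[List[int]], pad_value: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    max_len = max(len(row) for row in batch)
    padded = [row + [pad_value] * (max_len - len(row)) for row in batch]
    # Build the mask from true lengths so real tokens equal to pad_value
    # are not mistaken for padding.
    mask_rows = [[1] * len(row) + [0] * (max_len - len(row)) for row in batch]
    return torch.tensor(padded, dtype=torch.long), torch.tensor(
        mask_rows, dtype=torch.long
    )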
def forward(self, tokens: List[List[str]]):
    seq_lens = make_sequence_lengths(tokens)
    word_ids = self.vocab.lookup_indices_2d(tokens)
    word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
    logits = self.model(torch.tensor(word_ids), torch.tensor(seq_lens))
    return self.output_layer(logits)
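# Toy end-to-end run of the token path shared by these forwards, using the
# sketch helpers above and a hypothetical dict vocab in place of self.vocab
# (all values illustrative):
pad_idx = 0
toy_vocab = {"<pad>": 0, "<unk>": 1, "hello": 2, "world": 3}
batch = [["hello", "world"], ["world"]]
seq_lens = make_sequence_lengths_sketch(batch)            # [2, 1]
word_ids = [[toy_vocab.get(t, 1) for t in row] for row in batch]
word_ids = pad_2d_sketch(word_ids, seq_lens, pad_idx)     # [[2, 3], [3, 0]]
# The model would then consume torch.tensor(word_ids) and torch.tensor(seq_lens).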