def test_make_bytes_input(self):
    s1 = "I want some coffee today"
    s2 = "Turn it up"
    max_char_length = 5

    batch = [s1.split(), s2.split()]
    bytes, seq_lens = make_byte_inputs(batch, max_char_length)

    def to_bytes(word, pad_to):
        return list(word.encode()) + [0] * (pad_to - len(word))

    # "coffee" is truncated to the 5-byte budget; the shorter sentence is
    # padded out with empty words to match the longest sentence.
    expected_bytes = [
        [
            to_bytes("I", 5),
            to_bytes("want", 5),
            to_bytes("some", 5),
            to_bytes("coffe", 5),
            to_bytes("today", 5),
        ],
        [
            to_bytes("Turn", 5),
            to_bytes("it", 5),
            to_bytes("up", 5),
            to_bytes("", 5),
            to_bytes("", 5),
        ],
    ]
    expected_seq_lens = [5, 3]

    self.assertIsInstance(bytes, torch.LongTensor)
    self.assertIsInstance(seq_lens, torch.LongTensor)
    self.assertEqual(bytes.tolist(), expected_bytes)
    self.assertEqual(seq_lens.tolist(), expected_seq_lens)
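
# The test above exercises make_byte_inputs, whose real implementation is not
# shown in this section. The following is a hypothetical minimal sketch of the
# behavior the assertions assume: UTF-8 bytes per word, truncated/zero-padded
# to max_byte_len, with sentences padded by empty words up to the longest
# sentence in the batch. The name make_byte_inputs_sketch and the optional
# offset argument are illustrative only.
def make_byte_inputs_sketch(batch, max_byte_len, offset=0):
    seq_lens = torch.tensor([len(sentence) for sentence in batch], dtype=torch.long)
    max_seq_len = int(seq_lens.max())
    byte_ids = torch.zeros(len(batch), max_seq_len, max_byte_len, dtype=torch.long)
    for i, sentence in enumerate(batch):
        for j, word in enumerate(sentence):
            # Encode to UTF-8, truncate to the byte budget, and shift by
            # `offset` so that 0 stays reserved for padding.
            for k, b in enumerate(word.encode()[:max_byte_len]):
                byte_ids[i, j, k] = b + offset
    return byte_ids, seq_lens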
def forward(self, tokens: List[List[str]]):
    # Per-sentence lengths, then word ids padded to a rectangular batch.
    seq_lens = make_sequence_lengths(tokens)
    word_ids = self.vocab.lookup_indices_2d(tokens)
    word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
    # Byte-level inputs for the byte/character embedding path.
    token_bytes, _ = make_byte_inputs(
        tokens, self.max_byte_len, self.byte_offset_for_non_padding
    )
    logits = self.model(
        torch.tensor(word_ids), token_bytes, torch.tensor(seq_lens)
    )
    return self.output_layer(logits)
def forward(self, tokens: List[List[str]], dense_feat: List[List[float]]):
    seq_lens = make_sequence_lengths(tokens)
    word_ids = self.vocab.lookup_indices_2d(tokens)
    word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
    token_bytes, _ = make_byte_inputs(
        tokens, self.max_byte_len, self.byte_offset_for_non_padding
    )
    # Dense features are normalized before being passed to the model.
    dense_feat = self.normalizer.normalize(dense_feat)
    logits = self.model(
        torch.tensor(word_ids),
        token_bytes,
        torch.tensor(seq_lens),
        torch.tensor(dense_feat, dtype=torch.float),
    )
    return self.output_layer(logits)
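
# Hypothetical usage of the dense-feature forward above, assuming `module`
# has been constructed elsewhere with a vocab, normalizer, model, and output
# layer as implied by the method body; the token and feature values below are
# illustrative only.
tokens = [["turn", "it", "up"], ["play", "some", "music"]]
dense_feat = [[0.5, 1.0], [0.2, 0.0]]
scores = module.forward(tokens, dense_feat)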