def infer(self, test_sentence):
    """Run the model on a single raw sentence and return the corrected string.

    Pipeline: convert -> preprocess -> dataset/iterator (batch_size=1) ->
    model forward -> greedy top-1 decode -> beam_lm rescoring ->
    case/punctuation restoration against the raw input.

    Args:
        test_sentence: raw input string to be normalized by self.convert.

    Returns:
        The final predicted sequence (str) for the (single) input sentence.
    """
    # Normalize the raw input, then keep both the raw and preprocessed
    # forms: the raw form is needed later for case/punctuation matching.
    test_sentence = self.convert(test_sentence)
    lines_raw = [test_sentence]
    lines_prep = [self.preprocess(test_sentence)]

    # Build a one-sentence dataset that reuses the training vocabulary.
    print("Reading test data...")
    test = Seq2SeqDataset.from_list(lines_prep)
    test.src_field.vocab = self.src_vocab

    # batch_size=1 and all shuffling/sorting disabled so batch i lines up
    # with lines_raw[i].
    test_iterator = BucketIterator(dataset=test, batch_size=1,
                                   train=False,
                                   sort=False,
                                   sort_within_batch=False,
                                   shuffle=False,
                                   device=device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        for i, batch in enumerate(test_iterator):
            # Forward pass; model is run without targets (free decoding).
            _, _, output = self.model(batch,
                                      has_targets=False,
                                      mask_softmax=1.0,
                                      teacher_forcing=1.0)

            # Greedy top-1 over the vocabulary dimension.
            predicted_values, predicted_indices = torch.max(output, dim=-1)

            # Map vocab indices back to tokens.
            predicted_seq = [
                self.tgt_vocab.itos[c]
                for c in predicted_indices.squeeze(0).tolist()
            ]

            # Model emits log-probabilities; exponentiate to get probs.
            predicted_values = predicted_values.exp()
            predicted_values_ = predicted_values.squeeze(0).tolist()

            # Beam-search rescoring with the language model.
            # [1:-1] strips the <sos>/<eos> positions — TODO confirm the
            # model always emits exactly one of each.
            predicted_seq = self.beam_lm(''.join(predicted_seq[1:-1]),
                                         predicted_values_[1:-1],
                                         lines_raw[i])

            # Restore original casing and punctuation from the raw input.
            predicted_seq = self.match_case(predicted_seq, lines_raw[i])

            print("{} {}".format(i, predicted_seq))

    # Single-sentence input, so returning after the loop yields the
    # prediction for that sentence.
    return predicted_seq
def predict(self, test_path, test_cleaned_path, out_path):
    """Predict corrections for every line of a test file and write a submission file.

    Each input line is assumed to be formatted as ``<3-char id>,<sequence>``
    (id at line[:3], sequence at line[4:]) — NOTE(review): hard-coded column
    positions, confirm against the dataset format.

    Args:
        test_path: path to the raw test file (id + raw sequence per line).
        test_cleaned_path: path to the cleaned test file (same ids/order).
        out_path: output path; each line written as ``<id>,<prediction>``.
    """
    # Read raw, cleaned, and preprocessed variants of every line.
    # Raw lines drive case/punctuation restoration; cleaned lines drive
    # the final submission-format matching.
    lines_id = []
    lines_raw = []
    lines_cleaned = []
    lines_prep = []
    with open(test_path, 'r') as f, open(test_cleaned_path, 'r') as fc:
        for line in f:
            line_id = line[:3]
            line_seq = line[4:]
            lines_id.append(line_id)
            lines_raw.append(line_seq)
            lines_prep.append(self.preprocess(line_seq))
        for line in fc:
            lines_cleaned.append(line[4:])

    # Build the dataset with the training-time source vocabulary.
    print("Reading test data...")
    test = Seq2SeqDataset.from_list(lines_prep)
    test.src_field.vocab = self.src_vocab

    # batch_size=1 and all shuffling/sorting disabled so batch i lines up
    # with lines_raw[i] / lines_cleaned[i] / lines_id[i].
    test_iterator = BucketIterator(dataset=test, batch_size=1,
                                   train=False,
                                   sort=False,
                                   sort_within_batch=False,
                                   shuffle=False,
                                   device=device)

    # Inference only — no gradients needed.
    with open(out_path, 'w') as writer:
        with torch.no_grad():
            for i, batch in enumerate(test_iterator):
                # Forward pass without targets (free decoding).
                _, _, output = self.model(batch,
                                          has_targets=False,
                                          mask_softmax=1.0,
                                          teacher_forcing=1.0)

                # Greedy top-1 over the vocabulary dimension.
                predicted_values, predicted_indices = torch.max(output, dim=-1)

                # Map vocab indices back to tokens.
                predicted_seq = [
                    self.tgt_vocab.itos[c]
                    for c in predicted_indices.squeeze(0).tolist()
                ]

                # Model emits log-probabilities; exponentiate to get probs.
                predicted_values = predicted_values.exp()
                predicted_values_ = predicted_values.squeeze(0).tolist()

                # Beam-search rescoring with the language model; [1:-1]
                # strips the <sos>/<eos> positions — TODO confirm the model
                # always emits exactly one of each.
                predicted_seq = self.beam_lm(''.join(predicted_seq[1:-1]),
                                             predicted_values_[1:-1],
                                             lines_raw[i])

                # Restore original casing and punctuation from the raw input.
                predicted_seq = self.match_case(predicted_seq, lines_raw[i])

                # Post-process to match the expected submission output.
                predicted_seq = self.match_output(predicted_seq, lines_cleaned[i])

                print("{} {}".format(i, predicted_seq))

                # Write one submission line: <id>,<prediction>.
                writer.write(lines_id[i] + ',' + predicted_seq + '\n')