Example #1
0
    def infer(self, test_sentence):
        """Decode a single raw sentence and return the final predicted string.

        The sentence is converted and preprocessed, wrapped in a one-element
        dataset, decoded greedily (top-1 over the target vocabulary), then
        refined via beam search and case/punctuation matching against the
        original raw text.
        """
        # Normalise the raw input and keep both raw and preprocessed forms.
        test_sentence = self.convert(test_sentence)
        print(test_sentence)
        raw_lines = [test_sentence]
        prepped_lines = [self.preprocess(test_sentence)]

        # Build a one-example dataset that reuses the training source vocab.
        print("Reading test data...")
        dataset = Seq2SeqDataset.from_list(prepped_lines)
        dataset.src_field.vocab = self.src_vocab

        # Single batch, original order preserved (no sorting or shuffling).
        iterator = BucketIterator(dataset=dataset,
                                  batch_size=1,
                                  train=False,
                                  sort=False,
                                  sort_within_batch=False,
                                  shuffle=False,
                                  device=device)

        with torch.no_grad():
            for idx, batch in enumerate(iterator):
                # Forward pass without targets (free decoding).
                _, _, logits = self.model(batch,
                                          has_targets=False,
                                          mask_softmax=1.0,
                                          teacher_forcing=1.0)

                # Greedy decode: best score and vocab index per step.
                top_values, top_indices = torch.max(logits, dim=-1)

                # Map vocabulary indices back to tokens.
                decoded = [self.tgt_vocab.itos[token_id]
                           for token_id in top_indices.squeeze(0).tolist()]

                # Model emits log-softmax scores; exp() recovers probabilities.
                probs = top_values.exp().squeeze(0).tolist()

                # Beam search over the language model, dropping the
                # leading/trailing special tokens from the greedy decode.
                result = self.beam_lm(''.join(decoded[1:-1]),
                                      probs[1:-1],
                                      raw_lines[idx])

                # Restore original casing and punctuation.
                result = self.match_case(result, raw_lines[idx])
                print("{} {}".format(idx, result))
        return result
Example #2
0
    def predict(self, test_path, test_cleaned_path, out_path):
        """Decode every line of *test_path* and write submission rows to *out_path*.

        Each output row is "<id>,<prediction>". Predictions are produced by a
        greedy decode followed by beam search, case matching against the raw
        line, and output matching against the corresponding cleaned line from
        *test_cleaned_path*.
        """
        # Collect ids, raw sentences, preprocessed sentences and cleaned lines.
        ids, raw_lines, cleaned_lines, prepped_lines = [], [], [], []
        with open(test_path, 'r') as f, open(test_cleaned_path, 'r') as fc:
            for line in f:
                # NOTE(review): assumes "<3-char id><sep><sentence>" layout.
                ids.append(line[:3])
                sentence = line[4:]
                raw_lines.append(sentence)
                prepped_lines.append(self.preprocess(sentence))
            for line in fc:
                cleaned_lines.append(line[4:])

        # Dataset over the preprocessed lines, reusing the training vocab.
        print("Reading test data...")
        dataset = Seq2SeqDataset.from_list(prepped_lines)
        dataset.src_field.vocab = self.src_vocab

        # Batch size 1, original order preserved (no sorting or shuffling).
        iterator = BucketIterator(dataset=dataset,
                                  batch_size=1,
                                  train=False,
                                  sort=False,
                                  sort_within_batch=False,
                                  shuffle=False,
                                  device=device)

        with open(out_path, 'w') as writer, torch.no_grad():
            for idx, batch in enumerate(iterator):
                # Forward pass without targets (free decoding).
                _, _, logits = self.model(batch,
                                          has_targets=False,
                                          mask_softmax=1.0,
                                          teacher_forcing=1.0)
                print(logits.shape)

                # Greedy decode: best score and vocab index per step.
                top_values, top_indices = torch.max(logits, dim=-1)
                print(top_values.shape)
                print(top_indices.shape)

                # Map vocabulary indices back to tokens.
                decoded = [self.tgt_vocab.itos[token_id]
                           for token_id in top_indices.squeeze(0).tolist()]

                # Model emits log-softmax scores; exp() recovers probabilities.
                probs = top_values.exp().squeeze(0).tolist()

                # Beam search over the language model, dropping the
                # leading/trailing special tokens from the greedy decode.
                result = self.beam_lm(''.join(decoded[1:-1]),
                                      probs[1:-1],
                                      raw_lines[idx])

                # Restore original casing and punctuation.
                result = self.match_case(result, raw_lines[idx])

                # Shape the prediction to the expected submission output.
                result = self.match_output(result, cleaned_lines[idx])
                print("{} {}".format(idx, result))

                # Emit one "<id>,<prediction>" row per input line.
                writer.write(ids[idx] + ',' + result + '\n')