Example #1
# Third-party imports used below; BertTokenizer and the project-specific
# helpers (load_dataset, load_review_dataset, Dataset, test_collate_fn,
# convert_tf_checkpoint_to_pytorch, Model, ID2CATEGORY, ID2POLARITY) are
# assumed to be importable from the surrounding repository.
import torch
import torch.utils.data as torch_data
from tqdm import tqdm


def main():
    pred_file_path = 'test.csv'
    load_save_model = True
    lr = 1e-5
    batch_size = 8
    gpu = True
    torch.manual_seed(0)
    device = torch.device('cpu')
    if gpu:
        device = torch.device('cuda')

    tokenizer = BertTokenizer(vocab_file='publish/vocab.txt', max_len=512)
    _, known_token = load_dataset('TRAIN/Train_reviews.csv',
                                  'TRAIN/Train_labels.csv', tokenizer)
    dataset = load_review_dataset('TRAIN/TEST/Test_reviews.csv')
    dataset = Dataset(list(dataset.items()))
    dataloader = torch_data.DataLoader(dataset=dataset,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       collate_fn=test_collate_fn(
                                           tokenizer, known_token))
    bert_pretraining = convert_tf_checkpoint_to_pytorch(
        './publish/bert_model.ckpt', './publish/bert_config.json')
    model = Model(bert_pretraining.bert)

    model = model.to(device)
    if load_save_model:
        model.load_state_dict(
            torch.load('./save_model/best.model', map_location=device))

    pred_file = open(pred_file_path, mode='w', encoding='utf-8')

    pbar = tqdm()
    model.eval()

    def decode_span(ids):
        # Decode a span of token ids back to text and drop the spaces the
        # tokenizer inserts between (sub)tokens.
        return tokenizer.decode(list(ids.cpu().detach().numpy())).replace(' ', '')

    for step, (batch_X, len_X, mask, batch_idx,
               origin_batch_X) in enumerate(dataloader):
        batch_X = batch_X.to(device)
        mask = mask.to(device)

        scores, gather_idx = model(batch_X, len_X, mask, None)
        (pred_seq_target, pred_match_target,
         pred_single_aspect_category_target,
         pred_single_opinion_category_target, pred_cross_category_target,
         pred_single_aspect_polarity_target,
         pred_single_opinion_polarity_target,
         pred_cross_polarity_target) = model.infer(scores, mask)

        label = []

        aspect_idx, opinion_idx = gather_idx
        for b in range(batch_X.shape[0]):
            _aspect_idx, _opinion_idx = aspect_idx[b], opinion_idx[b]
            if len(_aspect_idx) == 0 and len(_opinion_idx) == 0:
                label.append((batch_idx[b], '_', '_', '_', '_'))

            _aspect_cross = [False] * len(_aspect_idx)
            _opinion_cross = [False] * len(_opinion_idx)
            for i in range(len(_aspect_idx)):
                for j in range(len(_opinion_idx)):
                    if pred_match_target[b][i, j] == 1:
                        _aspect_cross[i] = True
                        _opinion_cross[j] = True
                        category = ID2CATEGORY[pred_cross_category_target[b][i, j]]
                        polarity = ID2POLARITY[pred_cross_polarity_target[b][i, j]]
                        aspect = decode_span(origin_batch_X[b, _aspect_idx[i]])
                        opinion = decode_span(origin_batch_X[b, _opinion_idx[j]])
                        # Character offsets of the spans in the original text
                        # (computed here but not written to the prediction file).
                        aspect_beg = len(decode_span(batch_X[b, 1:_aspect_idx[i][0]]))
                        aspect_end = aspect_beg + len(aspect)
                        opinion_beg = len(decode_span(batch_X[b, 1:_opinion_idx[j][0]]))
                        opinion_end = opinion_beg + len(opinion)
                        label.append(
                            (batch_idx[b], aspect, opinion, category, polarity))
            for i in range(len(_aspect_idx)):
                if not _aspect_cross[i]:
                    category = ID2CATEGORY[pred_single_aspect_category_target[b][i]]
                    polarity = ID2POLARITY[pred_single_aspect_polarity_target[b][i]]
                    aspect = decode_span(origin_batch_X[b, _aspect_idx[i]])
                    aspect_beg = len(decode_span(batch_X[b, 1:_aspect_idx[i][0]]))
                    aspect_end = aspect_beg + len(aspect)
                    label.append(
                        (batch_idx[b], aspect, '_', category, polarity))
            for i in range(len(_opinion_idx)):
                if not _opinion_cross[i]:
                    category = ID2CATEGORY[pred_single_opinion_category_target[b][i]]
                    polarity = ID2POLARITY[pred_single_opinion_polarity_target[b][i]]
                    opinion = decode_span(origin_batch_X[b, _opinion_idx[i]])
                    opinion_beg = len(decode_span(batch_X[b, 1:_opinion_idx[i][0]]))
                    opinion_end = opinion_beg + len(opinion)
                    label.append(
                        (batch_idx[b], '_', opinion, category, polarity))

        for _label in label:
            pred_file.write(','.join(map(str, _label)) + '\n')
        pbar.update(batch_size)
        pbar.set_description('step: %d' % step)
    pred_file.close()
    pbar.close()
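
Each line that main() writes to test.csv has the form id,aspect,opinion,category,polarity, with '_' filling any slot that was not predicted. A minimal entry point for running the example, assuming the project-specific helpers used above are importable from the surrounding repository:

if __name__ == '__main__':
    main()
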
Example #2
# Assumes the Hugging Face BertTokenizer API (encode with add_special_tokens,
# decode with clean_up_tokenization_spaces).
from shutil import copyfile

from transformers import BertTokenizer


class BertCoder(object):
    def __init__(self,
                 filename,
                 bert_filename,
                 do_lower_case=False,
                 word_boundaries=False):
        self.filename = filename
        self.bert_filename = bert_filename
        self.do_lower_case = do_lower_case
        self.do_basic_tokenize = False
        # Hack around the fact that we need to know the word boundaries
        self.word_boundaries = word_boundaries

    def __len__(self):
        return self.tokenizer.vocab_size

    def fit(self, tokens):
        # NOTE: We let the tokenizer use its default do_basic_tokenize=True.
        # This can split a token into more pieces than just subtokens,
        # e.g. Mr.Doe -> Mr . D ##oe  (note that '.' is not prefixed with ##).
        # We account for this when building the token_flags in
        # text_to_token_flags.
        self.tokenizer = BertTokenizer(
            self.bert_filename,
            # do_basic_tokenize=self.do_basic_tokenize,
            do_lower_case=self.do_lower_case)
        return self

    def text_to_token_flags(self, text):
        """Return a tuple representing which subtokens are the beginning of a
        token. This is needed for NER using BERT:

        https://arxiv.org/pdf/1810.04805.pdf:

        "We use the representation of the first sub-token as the input to the
        token-level classifier over the NER label set."

        """
        text = self.tokenizer.basic_tokenizer._run_strip_accents(text)
        token_flags = []
        if self.do_lower_case:
            actual_split = text.lower().split()
        else:
            actual_split = text.split()

        bert_tokens = []
        for token in actual_split:
            local_bert_tokens = self.tokenizer.tokenize(token) or ['[UNK]']
            token_flags.append(1)
            for more in local_bert_tokens[1:]:
                token_flags.append(0)
            bert_tokens.extend(local_bert_tokens)
        # assert len(actual_tokens) == 0, [actual_tokens, actual_split, bert_tokens]
        assert len(token_flags) == len(bert_tokens), [
            actual_split, bert_tokens
        ]
        assert sum(token_flags) == len(actual_split)
        return tuple(token_flags)

    def encode(self, tokens):
        # Sometimes tokens include whitespace!
        # for sent_tokens in tokens:
        #     for token in sent_tokens:
        #         if ' ' in token:
        #             print(token)
        # The AIS dataset has a token ". .", for example.
        sent_tokens_no_ws = [[token.replace(' ', '') for token in sent_tokens]
                             for sent_tokens in tokens]
        texts = (' '.join(sent_tokens) for sent_tokens in sent_tokens_no_ws)
        if self.word_boundaries:
            encoded = tuple(self.text_to_token_flags(text) for text in texts)
            # encoded = tuple(tuple(0 if token.startswith('##') else 1
            #                       for token in self.tokenizer.tokenize(text))
            #                 for text in texts)
        else:
            # Adds CLS and SEP
            encoded = tuple(
                tuple(self.tokenizer.encode(text, add_special_tokens=True))
                for text in texts)
        return encoded

    def decode(self, ids):
        if self.word_boundaries:
            return []
        else:
            # NOTE: we only encode a single sentence, so use [0]
            return tuple(
                tuple(self.tokenizer.decode(
                    sent_ids, clean_up_tokenization_spaces=False)[0].split())
                for sent_ids in ids)

    def load(self, filename):
        self.tokenizer = BertTokenizer(
            filename,
            # do_basic_tokenize=self.do_basic_tokenize,
            do_lower_case=self.do_lower_case)
        return self

    def save(self, filename):
        copyfile(self.bert_filename, filename)
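
A minimal usage sketch for BertCoder; the vocabulary file path and the example sentence are hypothetical, and the exact ids and flags depend on how the vocabulary splits the words:

# Hypothetical usage; 'vocab.txt' stands in for a real BERT vocabulary file.
sentences = [['John', 'lives', 'in', 'Springfield']]

# Subtoken ids, with [CLS]/[SEP] added by encode().
id_coder = BertCoder('ids', 'vocab.txt').fit(sentences)
print(id_coder.encode(sentences)[0])

# Per-subtoken flags: 1 marks the first subtoken of each word, 0 a
# continuation subtoken (the representation used for NER in the BERT paper).
flag_coder = BertCoder('flags', 'vocab.txt', word_boundaries=True).fit(sentences)
print(flag_coder.encode(sentences)[0])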