# Inference script: run the trained model over the test reviews and write
# (id, aspect, opinion, category, polarity) rows to a CSV file.
# torch, torch.utils.data and tqdm are standard imports; the remaining names
# (BertTokenizer, load_dataset, load_review_dataset, Dataset, test_collate_fn,
# convert_tf_checkpoint_to_pytorch, Model, ID2CATEGORY, ID2POLARITY) come from
# the BERT library and project-local modules.
import torch
import torch.utils.data as torch_data
from tqdm import tqdm


def main():
    pred_file_path = 'test.csv'
    load_save_model = True
    lr = 1e-5  # unused at inference time
    batch_size = 8
    gpu = True
    torch.manual_seed(0)
    device = torch.device('cuda') if gpu else torch.device('cpu')

    tokenizer = BertTokenizer(vocab_file='publish/vocab.txt', max_len=512)
    _, known_token = load_dataset('TRAIN/Train_reviews.csv',
                                  'TRAIN/Train_labels.csv', tokenizer)
    dataset = load_review_dataset('TRAIN/TEST/Test_reviews.csv')
    dataset = Dataset(list(dataset.items()))
    dataloader = torch_data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=test_collate_fn(tokenizer, known_token))

    bert_pretraining = convert_tf_checkpoint_to_pytorch(
        './publish/bert_model.ckpt', './publish/bert_config.json')
    model = Model(bert_pretraining.bert)
    model = model.to(device)  # was model.cuda(); this honors the gpu flag
    if load_save_model:
        model.load_state_dict(torch.load('./save_model/best.model'))

    pred_file = open(pred_file_path, mode='w', encoding='utf-8')
    pbar = tqdm()
    model.eval()
    for step, (batch_X, len_X, mask, batch_idx,
               origin_batch_X) in enumerate(dataloader):
        batch_X = batch_X.to(device)
        mask = mask.to(device)
        scores, gather_idx = model(batch_X, len_X, mask, None)
        (pred_seq_target, pred_match_target,
         pred_single_aspect_category_target,
         pred_single_opinion_category_target, pred_cross_category_target,
         pred_single_aspect_polarity_target,
         pred_single_opinion_polarity_target,
         pred_cross_polarity_target) = model.infer(scores, mask)

        label = []
        aspect_idx, opinion_idx = gather_idx
        for b in range(batch_X.shape[0]):
            _aspect_idx, _opinion_idx = aspect_idx[b], opinion_idx[b]
            # No aspect or opinion found: emit a placeholder row.
            if len(_aspect_idx) == 0 and len(_opinion_idx) == 0:
                label.append((batch_idx[b], '_', '_', '_', '_'))
            _aspect_cross = [False] * len(_aspect_idx)
            _opinion_cross = [False] * len(_opinion_idx)

            # Matched aspect-opinion pairs.
            for i in range(len(_aspect_idx)):
                for j in range(len(_opinion_idx)):
                    if pred_match_target[b][i, j] == 1:
                        _aspect_cross[i] = True
                        _opinion_cross[j] = True
                        category = ID2CATEGORY[
                            pred_cross_category_target[b][i, j]]
                        polarity = ID2POLARITY[
                            pred_cross_polarity_target[b][i, j]]
                        aspect = tokenizer.decode(
                            list(origin_batch_X[b, _aspect_idx[i]]
                                 .cpu().detach().numpy())).replace(' ', '')
                        opinion = tokenizer.decode(
                            list(origin_batch_X[b, _opinion_idx[j]]
                                 .cpu().detach().numpy())).replace(' ', '')
                        # Character offsets of the spans: decode the prefix up
                        # to the span start, skipping [CLS]. Computed but not
                        # written to the output file.
                        aspect_beg = len(tokenizer.decode(
                            list(batch_X[b, 1:_aspect_idx[i][0]]
                                 .cpu().detach().numpy())).replace(' ', ''))
                        aspect_end = aspect_beg + len(aspect)
                        opinion_beg = len(tokenizer.decode(
                            list(batch_X[b, 1:_opinion_idx[j][0]]
                                 .cpu().detach().numpy())).replace(' ', ''))
                        opinion_end = opinion_beg + len(opinion)
                        label.append((batch_idx[b], aspect, opinion,
                                      category, polarity))

            # Aspects with no matching opinion.
            for i in range(len(_aspect_idx)):
                if not _aspect_cross[i]:
                    category = ID2CATEGORY[
                        pred_single_aspect_category_target[b][i]]
                    polarity = ID2POLARITY[
                        pred_single_aspect_polarity_target[b][i]]
                    aspect = tokenizer.decode(
                        list(origin_batch_X[b, _aspect_idx[i]]
                             .cpu().detach().numpy())).replace(' ', '')
                    aspect_beg = len(tokenizer.decode(
                        list(batch_X[b, 1:_aspect_idx[i][0]]
                             .cpu().detach().numpy())).replace(' ', ''))
                    aspect_end = aspect_beg + len(aspect)
                    label.append((batch_idx[b], aspect, '_',
                                  category, polarity))

            # Opinions with no matching aspect.
            for i in range(len(_opinion_idx)):
                if not _opinion_cross[i]:
                    category = ID2CATEGORY[
                        pred_single_opinion_category_target[b][i]]
                    polarity = ID2POLARITY[
                        pred_single_opinion_polarity_target[b][i]]
                    opinion = tokenizer.decode(
                        list(origin_batch_X[b, _opinion_idx[i]]
                             .cpu().detach().numpy())).replace(' ', '')
                    opinion_beg = len(tokenizer.decode(
                        list(batch_X[b, 1:_opinion_idx[i][0]]
                             .cpu().detach().numpy())).replace(' ', ''))
                    opinion_end = opinion_beg + len(opinion)
                    label.append((batch_idx[b], '_', opinion,
                                  category, polarity))

        for _label in label:
            pred_file.write(','.join(map(str, _label)) + '\n')
        pbar.update(batch_size)
        pbar.set_description('step: %d' % step)
    pred_file.close()
    pbar.close()
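
# Entry-point guard: the original file does not show one, so this is an
# assumed, standard way to invoke the script.
if __name__ == '__main__':
    main()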
# Wrapper around a BERT tokenizer; the encode/decode calls below match the
# transformers BertTokenizer API, so that import is assumed here.
from shutil import copyfile

from transformers import BertTokenizer


class BertCoder(object):
    def __init__(self, filename, bert_filename, do_lower_case=False,
                 word_boundaries=False):
        self.filename = filename
        self.bert_filename = bert_filename
        self.do_lower_case = do_lower_case
        # Hack around the fact that we need to know the word boundaries
        self.do_basic_tokenize = False
        self.word_boundaries = word_boundaries

    def __len__(self):
        return self.tokenizer.vocab_size

    def fit(self, tokens):
        # NOTE: We allow the model to use the default do_basic_tokenize.
        # This potentially splits tokens into more tokens apart from
        # subtokens, e.g. Mr.Doe -> Mr . D ##oe (note that . is not preceded
        # by ##). We take this into account when creating the token_flags in
        # text_to_token_flags.
        self.tokenizer = BertTokenizer(
            self.bert_filename,
            # do_basic_tokenize=self.do_basic_tokenize,
            do_lower_case=self.do_lower_case)
        return self

    def text_to_token_flags(self, text):
        """Return a tuple marking which subtokens begin a token.

        This is needed for NER using BERT
        (https://arxiv.org/pdf/1810.04805.pdf): "We use the representation
        of the first sub-token as the input to the token-level classifier
        over the NER label set."
        """
        text = self.tokenizer.basic_tokenizer._run_strip_accents(text)
        token_flags = []
        if self.do_lower_case:
            actual_split = text.lower().split()
        else:
            actual_split = text.split()
        bert_tokens = []
        for token in actual_split:
            local_bert_tokens = self.tokenizer.tokenize(token) or ['[UNK]']
            # 1 for the first subtoken of each word, 0 for the rest.
            token_flags.append(1)
            for more in local_bert_tokens[1:]:
                token_flags.append(0)
            bert_tokens.extend(local_bert_tokens)
        assert len(token_flags) == len(bert_tokens), [
            actual_split, bert_tokens
        ]
        assert sum(token_flags) == len(actual_split)
        return tuple(token_flags)

    def encode(self, tokens):
        # Sometimes tokens include whitespace (the AIS dataset has a token
        # ". .", for example), so strip spaces inside tokens first.
        sent_tokens_no_ws = [[token.replace(' ', '') for token in sent_tokens]
                             for sent_tokens in tokens]
        texts = (' '.join(sent_tokens) for sent_tokens in sent_tokens_no_ws)
        if self.word_boundaries:
            encoded = tuple(self.text_to_token_flags(text) for text in texts)
        else:
            # Adds CLS and SEP
            encoded = tuple(
                tuple(self.tokenizer.encode(text, add_special_tokens=True))
                for text in texts)
        return encoded

    def decode(self, ids):
        if self.word_boundaries:
            return []
        # decode returns a single string per sentence; split it back into
        # tokens. (The original indexed the string with [0], which would have
        # kept only its first character.)
        return tuple(
            tuple(
                self.tokenizer.decode(
                    sent_ids,
                    clean_up_tokenization_spaces=False).split())
            for sent_ids in ids)

    def load(self, filename):
        self.tokenizer = BertTokenizer(
            filename,
            # do_basic_tokenize=self.do_basic_tokenize,
            do_lower_case=self.do_lower_case)
        return self

    def save(self, filename):
        copyfile(self.bert_filename, filename)
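
# Minimal usage sketch for BertCoder (not part of the original source). It
# assumes a BERT vocab file at 'publish/vocab.txt' and shows the two encode
# modes: subtoken ids vs. word-boundary flags for NER-style first-subtoken
# classification. fit() ignores its argument, so any value works there.
if __name__ == '__main__':
    sentences = [['the', 'battery', 'life', 'is', 'great']]

    coder = BertCoder(filename=None, bert_filename='publish/vocab.txt',
                      do_lower_case=True).fit(sentences)
    print(coder.encode(sentences))  # id sequences with [CLS]/[SEP] added

    flag_coder = BertCoder(filename=None, bert_filename='publish/vocab.txt',
                           do_lower_case=True,
                           word_boundaries=True).fit(sentences)
    print(flag_coder.encode(sentences))  # 1 marks each word's first subtoken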