def classify(extractions, config, model, sents=None):
    """Apply the classification models (for Sentiment and Attribute Classification).

    Args:
        extractions (list of dict): the partial extraction results by the pairing model
        config (dict): the model configuration
        model (MultiTaskNet): the model in pytorch
        sents (list of list of str, optional): the tokenized sentences,
            required by the sentiment classification ('asc') tasks

    Returns:
        list of dict: the extraction results with the predicted attribute name or
            sentiment label assigned to the field "attribute" or "sentiment".
    """
    phrases = []
    index = []
    # build one classification input per extraction
    for sid, sent in enumerate(extractions):
        for eid, ext in enumerate(sent['extractions']):
            if 'asc' in config['name']:
                # sentiment classification: sentence + TAB + aspect/opinion term
                if 'aspect' in ext:
                    phrase = ' '.join(sents[ext['sid']]) + '\t' + ext['aspect']
                else:
                    phrase = ' '.join(sents[ext['sid']]) + '\t' + ext['opinion']
            else:
                # attribute classification: the opinion/aspect phrase itself
                if 'aspect' in ext and 'opinion' in ext:
                    phrase = ext['opinion'] + ' ' + ext['aspect']
                elif 'aspect' in ext:
                    phrase = ext['aspect']
                else:
                    phrase = ext['opinion']
            phrases.append(phrase)
            index.append((sid, eid))

    dataset = SnippextDataset(phrases, config['vocab'], config['name'])
    iterator = data.DataLoader(dataset=dataset,
                               batch_size=32,
                               shuffle=False,
                               num_workers=0,
                               collate_fn=SnippextDataset.pad)

    # prediction
    model.eval()
    Y_hat = []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            words, x, is_heads, tags, mask, y, seqlens, taskname = batch
            taskname = taskname[0]
            _, _, y_hat = model(x, y, task=taskname)  # y_hat: (N, T)
            Y_hat.extend(y_hat.cpu().numpy().tolist())

    # write the predicted labels back into the extractions
    for i in range(len(phrases)):
        attr = dataset.idx2tag[Y_hat[i]]
        sid, eid = index[i]
        if 'asc' in config['name']:
            extractions[sid]['extractions'][eid]['sentiment'] = attr
        else:
            extractions[sid]['extractions'][eid]['attribute'] = attr

    return extractions
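# Illustrative sketch (assumption, not part of the pipeline): the input/output
# shapes that classify() works with. `asc_config` stands for a configs.json
# entry whose name contains 'asc' (sentiment classification) and `model` for a
# trained MultiTaskNet checkpoint; both are placeholders here, as is the
# example sentence.
def _example_classify(asc_config, model):
    sents = [['the', 'pizza', 'was', 'great']]
    extractions = [{'sentence': 'the pizza was great',
                    'extractions': [{'aspect': 'pizza', 'opinion': 'great',
                                     'sid': 0,
                                     'asp_start': 1, 'asp_end': 1,
                                     'op_start': 3, 'op_end': 3}]}]
    extractions = classify(extractions, asc_config, model, sents=sents)
    # with an 'asc' config each extraction now carries a 'sentiment' label;
    # with any other config it gets an 'attribute' label instead
    return extractions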
def do_tagging(text, config, model):
    """Apply the tagging model.

    Args:
        text (str): the input paragraph
        config (dict): the model configuration
        model (MultiTaskNet): the model in pytorch

    Returns:
        list of list of str: the tokens in each sentence
        list of list of int: each token's starting position in the original text
        list of list of str: the tags assigned to each token
    """
    # sentence-split and tokenize the input paragraph
    source = []
    token_pos_list = []
    for sent in sent_tokenizer(text):
        tokens = [token.text for token in sent]
        token_pos = [token.idx for token in sent]
        source.append(tokens)
        token_pos_list.append(token_pos)

    dataset = SnippextDataset(source, config['vocab'], config['name'], max_len=64)
    iterator = data.DataLoader(dataset=dataset,
                               batch_size=32,
                               shuffle=False,
                               num_workers=0,
                               collate_fn=SnippextDataset.pad)

    # prediction
    model.eval()
    Words, Is_heads, Tags, Y, Y_hat = [], [], [], [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            try:
                words, x, is_heads, tags, mask, y, seqlens, taskname = batch
                taskname = taskname[0]
                _, _, y_hat = model(x, y, task=taskname)  # y_hat: (N, T)
                Words.extend(words)
                Is_heads.extend(is_heads)
                Tags.extend(tags)
                Y.extend(y.numpy().tolist())
                Y_hat.extend(y_hat.cpu().numpy().tolist())
            except Exception as e:
                # skip batches that fail, but report what went wrong
                print('error @', batch, e)

    # map sub-word predictions back to tokens and collect the results
    results = []
    for words, is_heads, tags, y_hat in zip(Words, Is_heads, Tags, Y_hat):
        # keep only the prediction on the first sub-word of each token
        y_hat = [hat for head, hat in zip(is_heads, y_hat) if head == 1]
        # remove the first and the last token
        preds = [dataset.idx2tag[hat] for hat in y_hat][1:-1]
        results.append(preds)

    return source, token_pos_list, results
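# Illustrative sketch (assumption): how the tagging step might be called.
# `tagging_config` and `model` are placeholders for a configs.json entry and a
# trained MultiTaskNet; loading them is outside this snippet. Note that
# do_tagging() expects sent_tokenizer() to yield spaCy-style tokens exposing
# `.text` and `.idx`.
def _example_do_tagging(tagging_config, model):
    text = "The pizza was great but the service was slow."
    tokens, token_pos, tags = do_tagging(text, tagging_config, model)
    # tokens[i]    -> the tokens of the i-th sentence
    # token_pos[i] -> each token's character offset in the original text
    # tags[i]      -> one BIO tag per token, e.g. 'B-AS' for the start of an aspect
    for sent_tokens, sent_tags in zip(tokens, tags):
        print(list(zip(sent_tokens, sent_tags)))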
configs = {conf['name']: conf for conf in configs}
config = configs[task]
config_list = [config]

trainset = config['trainset']
validset = config['validset']
testset = config['testset']
unlabeled = config['unlabeled']
task_type = config['task_type']
vocab = config['vocab']
tasknames = [task]

# train dataset
train_dataset = SnippextDataset(trainset, vocab, task,
                                lm=hp.lm,
                                max_len=hp.max_len)
# train dataset augmented
augment_dataset = SnippextDataset(trainset, vocab, task,
                                  lm=hp.lm,
                                  max_len=hp.max_len,
                                  augment_index=hp.augment_index,
                                  augment_op=hp.augment_op)
# dev set
valid_dataset = SnippextDataset(validset, vocab, task, lm=hp.lm)
# test set
test_dataset = SnippextDataset(testset, vocab, task, lm=hp.lm)
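# Illustrative sketch (assumption): the shape of one configs.json entry that
# the block above reads. configs.json is expected to be a list of such entries
# keyed by "name"; the task name, file paths, and label vocabulary below are
# made-up placeholders, only the key names follow the fields accessed above.
_example_config_entry = {
    "name": "hotel_tagging",
    "task_type": "tagging",
    "vocab": ["O", "B-AS", "I-AS", "B-OP", "I-OP"],
    "trainset": "data/hotel/train.txt",
    "validset": "data/hotel/dev.txt",
    "testset": "data/hotel/test.txt",
    "unlabeled": "data/hotel/unlabeled.txt",
}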
task = hp.task

# create the tag of the run
run_tag = 'baseline_task_%s_lm_%s_batch_size_%d_run_id_%d' % (
    task, hp.lm, hp.batch_size, hp.run_id)

# load task configuration
configs = json.load(open('configs.json'))
configs = {conf['name']: conf for conf in configs}
config = configs[task]

trainset = config['trainset']
validset = config['validset']
testset = config['testset']
task_type = config['task_type']
vocab = config['vocab']
tasknames = [task]

# load train/dev/test sets
train_dataset = SnippextDataset(trainset, vocab, task,
                                lm=hp.lm,
                                max_len=hp.max_len)
valid_dataset = SnippextDataset(validset, vocab, task, lm=hp.lm)
test_dataset = SnippextDataset(testset, vocab, task, lm=hp.lm)

# run the training process
initialize_and_train(config, train_dataset, valid_dataset,
                     test_dataset, hp, run_tag)
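# Illustrative sketch (assumption): a minimal argument parser that would supply
# the hyperparameters read above (hp.task, hp.lm, hp.batch_size, hp.max_len,
# hp.run_id). The flag names, defaults, and the example task name are
# placeholders; the original script's parser is not shown in this snippet.
def _example_parse_hp():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='hotel_tagging')
    parser.add_argument('--lm', type=str, default='bert')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--max_len', type=int, default=64)
    parser.add_argument('--run_id', type=int, default=0)
    return parser.parse_args()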
def do_pairing(all_tokens, all_tags, config, model):
    """Apply the pairing model.

    Args:
        all_tokens (list of list of str): the tokenized text
        all_tags (list of list of str): the tags assigned to each token
        config (dict): the model configuration
        model (MultiTaskNet): the model in pytorch

    Returns:
        list of dict: for each sentence, the list of extracted opinions/experiences.
            Each dictionary contains an aspect term and/or an opinion term together
            with the start/end positions of the spans; spans that could not be
            paired are kept as single-term extractions.
    """
    samples = []
    sent_ids = []
    candidates = []
    positions = []
    all_spans = {}
    sid = 0
    for tokens, tags in zip(all_tokens, all_tags):
        aspects = []
        opinions = []
        # find aspect and opinion spans from the BIO tags
        for i, tag in enumerate(tags):
            if tag[0] == 'B':
                start = i
                end = i
                while end + 1 < len(tags) and tags[end + 1][0] == 'I':
                    end += 1
                if tag == 'B-AS':
                    aspects.append((start, end))
                    all_spans[(sid, start, end)] = {'aspect': ' '.join(tokens[start:end+1]),
                                                    'sid': sid,
                                                    'asp_start': start,
                                                    'asp_end': end}
                else:
                    opinions.append((start, end))
                    all_spans[(sid, start, end)] = {'opinion': ' '.join(tokens[start:end+1]),
                                                    'sid': sid,
                                                    'op_start': start,
                                                    'op_end': end}

        # enumerate candidate (aspect, opinion) pairs, closest pairs first
        candidate_pairs = []
        for asp in aspects:
            for opi in opinions:
                candidate_pairs.append((asp, opi))
        candidate_pairs.sort(key=lambda ao: abs(ao[0][0] - ao[1][0]))

        for asp, opi in candidate_pairs:
            asp_start, asp_end = asp
            op_start, op_end = opi
            token_ids = []
            for i in range(asp_start, asp_end + 1):
                token_ids.append((sid, i))
            for i in range(op_start, op_end + 1):
                token_ids.append((sid, i))

            # pairing input: sentence [SEP] candidate phrase (in textual order)
            if op_start < asp_start:
                samples.append(' '.join(tokens) + ' [SEP] ' + \
                               ' '.join(tokens[op_start:op_end+1]) + ' ' + \
                               ' '.join(tokens[asp_start:asp_end+1]))
            else:
                samples.append(' '.join(tokens) + ' [SEP] ' + \
                               ' '.join(tokens[asp_start:asp_end+1]) + ' ' + \
                               ' '.join(tokens[op_start:op_end+1]))

            sent_ids.append(sid)
            candidates.append({'opinion': ' '.join(tokens[op_start:op_end+1]),
                               'aspect': ' '.join(tokens[asp_start:asp_end+1]),
                               'sid': sid,
                               'asp_start': asp_start,
                               'asp_end': asp_end,
                               'op_start': op_start,
                               'op_end': op_end})
            positions.append(token_ids)
        sid += 1

    dataset = SnippextDataset(samples, config['vocab'], config['name'])
    iterator = data.DataLoader(dataset=dataset,
                               batch_size=32,
                               shuffle=False,
                               num_workers=0,
                               collate_fn=SnippextDataset.pad)

    # prediction
    model.eval()
    Y_hat = []
    Y = []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            words, x, is_heads, tags, mask, y, seqlens, taskname = batch
            taskname = taskname[0]
            _, y, y_hat = model(x, y, task=taskname)  # y_hat: (N, T)
            Y_hat.extend(y_hat.cpu().numpy().tolist())
            Y.extend(y.cpu().numpy().tolist())

    results = []
    for tokens in all_tokens:
        results.append({'sentence': ' '.join(tokens), 'extractions': []})

    used = set()
    for i, yhat in enumerate(Y_hat):
        phrase = samples[i].split(' [SEP] ')[1]
        # print(phrase, yhat)
        if yhat == 1:
            # accept the pair only if none of its tokens is already used
            assigned = False
            for tid in positions[i]:
                if tid in used:
                    assigned = True
                    break
            if not assigned:
                results[sent_ids[i]]['extractions'].append(candidates[i])
                for tid in positions[i]:
                    used.add(tid)
                # drop the paired spans from all_spans
                sid = candidates[i]['sid']
                del all_spans[(sid, candidates[i]['asp_start'], candidates[i]['asp_end'])]
                del all_spans[(sid, candidates[i]['op_start'], candidates[i]['op_end'])]

    # add aspects/opinions that are not paired
    for sid, start, end in all_spans:
        results[sid]['extractions'].append(all_spans[(sid, start, end)])

    return results
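# Illustrative sketch (assumption): how the three inference steps defined above
# might be chained end to end. `tagging_config`, `pairing_config`, `asc_config`,
# and `model` are placeholders for configs.json entries and a loaded
# MultiTaskNet checkpoint; loading them is outside the scope of this snippet.
def _example_pipeline(text, tagging_config, pairing_config, asc_config, model):
    # 1. token-level BIO tagging of aspect and opinion spans
    tokens, token_pos, tags = do_tagging(text, tagging_config, model)
    # 2. pair up aspect and opinion spans within each sentence
    extractions = do_pairing(tokens, tags, pairing_config, model)
    # 3. attach a sentiment label to every extraction
    extractions = classify(extractions, asc_config, model, sents=tokens)
    return extractions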