def load_qianyan_data(file_path, task_def):
    """Load a QianYan-style TSV file of premise/hypothesis pairs."""
    data_format = task_def.data_type
    assert data_format == DataFormat.PremiseAndOneHypothesis
    task_obj = tasks.get_task_obj(task_def)
    rows = []
    with open(file_path, encoding="utf-8") as reader:
        for line in reader:
            fields = line.strip("\n").split("\t")
            if len(fields) == 4:
                row = {
                    "uid": fields[0],
                    "label": fields[1],
                    "premise": fields[2],
                    "hypothesis": fields[3]
                }
            elif len(fields) == 3:  # test split: no label column, use a dummy label
                row = {
                    "uid": fields[0],
                    "label": "0",
                    "premise": fields[1],
                    "hypothesis": fields[2]
                }
            else:
                raise ValueError(f"invalid line found: {line}")
            if task_obj is not None:
                row["label"] = task_obj.input_parse_label(row["label"])
            rows.append(row)
    return rows
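# Usage sketch for load_qianyan_data (illustrative, not from the repo): the
# config path, task name, and data path below are hypothetical placeholders;
# TaskDefs is the repo's task-definition loader, assumed importable from
# experiments.exp_def.
from experiments.exp_def import TaskDefs

task_defs = TaskDefs("experiments/qianyan/qianyan_task_def.yml")  # hypothetical config
task_def = task_defs.get_task_def("lcqmc")  # hypothetical task name
rows = load_qianyan_data("data/lcqmc/train.tsv", task_def)  # hypothetical path
print(len(rows), rows[0]["premise"], rows[0]["label"])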
def forward(self,
            input_ids,
            token_type_ids,
            attention_mask,
            premise_mask=None,
            hyp_mask=None,
            task_id=0,
            fwd_type=0,
            embed=None):
    # fwd_type: 0 = standard forward, 1 = embedding lookup only,
    # 2 = forward from precomputed embeddings
    if fwd_type == 2:
        assert embed is not None
        sequence_output, pooled_output = self.embed_forward(embed, attention_mask)
    elif fwd_type == 1:
        return self.embed_encode(input_ids, token_type_ids, attention_mask)
    else:
        sequence_output, pooled_output, _ = self.encode(input_ids, token_type_ids,
                                                        attention_mask)

    decoder_opt = self.decoder_opt[task_id]
    task_type = self.task_types[task_id]
    task_obj = tasks.get_task_obj(self.task_def_list[task_id])
    if task_obj is not None:
        logits = task_obj.train_forward(sequence_output, pooled_output,
                                        premise_mask, hyp_mask, decoder_opt,
                                        self.dropout_list[task_id],
                                        self.scoring_list[task_id])
        return logits
    elif task_type == TaskType.Span:
        assert decoder_opt != 1
        sequence_output = self.dropout_list[task_id](sequence_output)
        logits = self.scoring_list[task_id](sequence_output)
        start_scores, end_scores = logits.split(1, dim=-1)
        start_scores = start_scores.squeeze(-1)
        end_scores = end_scores.squeeze(-1)
        return start_scores, end_scores
    elif task_type == TaskType.SeqenceLabeling:
        pooled_output = sequence_output
        pooled_output = self.dropout_list[task_id](pooled_output)
        pooled_output = pooled_output.contiguous().view(-1, pooled_output.size(2))
        logits = self.scoring_list[task_id](pooled_output)
        return logits
    elif task_type == TaskType.MaskLM:
        sequence_output = self.dropout_list[task_id](sequence_output)
        logits = self.scoring_list[task_id](sequence_output)
        return logits
    else:
        if decoder_opt == 1:
            assert premise_mask is not None
            assert hyp_mask is not None
            max_query = hyp_mask.size(1)
            assert max_query > 0
            hyp_mem = sequence_output[:, :max_query, :]
            logits = self.scoring_list[task_id](sequence_output, hyp_mem,
                                                premise_mask, hyp_mask)
        else:
            pooled_output = self.dropout_list[task_id](pooled_output)
            logits = self.scoring_list[task_id](pooled_output)
        return logits
def load(
    path,
    is_train=True,
    maxlen=512,
    factor=1.0,
    task_def=None,
    bert_model="bert-base-uncased",
    do_lower_case=True,
    printable=True,
):
    task_type = task_def.task_type
    assert task_type is not None

    if task_type == TaskType.MaskLM:
        # Masked-LM data is loose JSON: one document per record, paragraphs
        # separated by blank lines; return tokenized paragraphs per document.
        def load_mlm_data(path):
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(bert_model, cache_dir=".cache")
            vocab_words = list(tokenizer.vocab.keys())
            data = load_loose_json(path)
            docs = []
            for doc in data:
                paras = doc["text"].split("\n\n")
                paras = [para.strip() for para in paras if len(para.strip()) > 0]
                tokens = [tokenizer.tokenize(para) for para in paras]
                docs.append(tokens)
            return docs, tokenizer

        return load_mlm_data(path)

    with open(path, "r", encoding="utf-8") as reader:
        data = []
        cnt = 0
        for line in reader:
            sample = json.loads(line)
            sample["factor"] = factor
            cnt += 1
            if is_train:
                # drop over-length samples at training time only
                task_obj = tasks.get_task_obj(task_def)
                if task_obj is not None and not task_obj.input_is_valid_sample(
                        sample, maxlen):
                    continue
                if task_type == TaskType.Ranking and (
                        len(sample["token_id"][0]) > maxlen
                        or len(sample["token_id"][1]) > maxlen):
                    continue
                if task_type != TaskType.Ranking and len(sample["token_id"]) > maxlen:
                    continue
            data.append(sample)
        if printable:
            print("Loaded {} samples out of {}".format(len(data), cnt))
        return data, None
def load(path,
         is_train=True,
         maxlen=512,
         factor=1.0,
         task_def=None,
         bert_model='bert-base-uncased',
         do_lower_case=True,
         printable=True):
    # Legacy variant of load() above; identical except that it tokenizes with
    # pytorch_pretrained_bert's BertTokenizer instead of transformers.
    task_type = task_def.task_type
    assert task_type is not None

    if task_type == TaskType.MaskLM:
        def load_mlm_data(path):
            from pytorch_pretrained_bert.tokenization import BertTokenizer
            tokenizer = BertTokenizer.from_pretrained(bert_model,
                                                      do_lower_case=do_lower_case)
            vocab_words = list(tokenizer.vocab.keys())
            data = load_loose_json(path)
            docs = []
            for doc in data:
                paras = doc['text'].split('\n\n')
                paras = [para.strip() for para in paras if len(para.strip()) > 0]
                tokens = [tokenizer.tokenize(para) for para in paras]
                docs.append(tokens)
            return docs, tokenizer

        return load_mlm_data(path)

    with open(path, 'r', encoding='utf-8') as reader:
        data = []
        cnt = 0
        for line in reader:
            sample = json.loads(line)
            sample['factor'] = factor
            cnt += 1
            if is_train:
                # drop over-length samples at training time only
                task_obj = tasks.get_task_obj(task_def)
                if task_obj is not None and not task_obj.input_is_valid_sample(
                        sample, maxlen):
                    continue
                if task_type == TaskType.Ranking and (
                        len(sample['token_id'][0]) > maxlen
                        or len(sample['token_id'][1]) > maxlen):
                    continue
                if task_type != TaskType.Ranking and len(sample['token_id']) > maxlen:
                    continue
            data.append(sample)
        if printable:
            print('Loaded {} samples out of {}'.format(len(data), cnt))
        return data, None
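# Usage sketch shared by both load() variants above (illustrative): the path
# points at a tokenized JSONL file produced by the repo's preprocessing step;
# the file name is a hypothetical placeholder. For non-MaskLM tasks the second
# return value is None.
data, _ = load("data/canonical_data/bert_base_uncased/mnli_train.json",  # hypothetical
               is_train=True,
               maxlen=512,
               task_def=task_def)  # task_def built as in the earlier sketch
print("kept {} samples".format(len(data)))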
def predict(self, batch_meta, batch_data):
    self.network.eval()
    task_id = batch_meta['task_id']
    task_def = TaskDef.from_dict(batch_meta['task_def'])
    task_type = task_def.task_type
    task_obj = tasks.get_task_obj(task_def)
    inputs = batch_data[:batch_meta['input_len']]
    if len(inputs) == 3:
        # pad the optional premise/hypothesis masks
        inputs.append(None)
        inputs.append(None)
    inputs.append(task_id)
    score = self.mnetwork(*inputs)
    if task_obj is not None:
        score, predict = task_obj.test_predict(score)
    elif task_type == TaskType.Ranking:
        score = score.contiguous().view(-1, batch_meta['pairwise_size'])
        score = F.softmax(score, dim=1)
        score = score.data.cpu()
        score = score.numpy()
        # one-hot prediction for the top-ranked hypothesis in each group
        predict = np.zeros(score.shape, dtype=int)
        positive = np.argmax(score, axis=1)
        for idx, pos in enumerate(positive):
            predict[idx, pos] = 1
        predict = predict.reshape(-1).tolist()
        score = score.reshape(-1).tolist()
        return score, predict, batch_meta['true_label']
    elif task_type == TaskType.SeqenceLabeling:
        mask = batch_data[batch_meta['mask']]
        score = score.contiguous()
        score = score.data.cpu()
        score = score.numpy()
        predict = np.argmax(score, axis=1).reshape(mask.size()).tolist()
        # trim padding positions using the per-sample valid lengths
        valid_length = mask.sum(1).tolist()
        final_predict = []
        for idx, p in enumerate(predict):
            final_predict.append(p[:valid_length[idx]])
        score = score.reshape(-1).tolist()
        return score, final_predict, batch_meta['label']
    elif task_type == TaskType.Span:
        start, end = score
        predictions = []
        if self.config['encoder_type'] == EncoderModelType.BERT:
            import experiments.squad.squad_utils as mrc_utils
            scores, predictions = mrc_utils.extract_answer(
                batch_meta, batch_data, start, end,
                self.config.get('max_answer_len', 5),
                do_lower_case=self.config.get('do_lower_case', False))
        return scores, predictions, batch_meta['answer']
    else:
        raise ValueError("Unknown task_type: %s" % task_type)
    return score, predict, batch_meta['label']
def load_clue_data(file_path, task_def):
    """Load a CLUE benchmark JSONL file into canonical rows."""
    data_format = task_def.data_type
    task_type = task_def.task_type
    label_dict = task_def.label_vocab
    if task_type == TaskType.Ranking:
        assert data_format == DataFormat.PremiseAndMultiHypothesis

    task_obj = tasks.get_task_obj(task_def)
    rows = []
    uid = 0
    for line in open(file_path, encoding="utf-8"):
        record = json.loads(line)
        # the gold-label key differs across CLUE tasks
        if task_def.name == "iflytek":
            label = record.get("label_des")
        elif task_def.name == "tnews":
            label = record.get("label_desc")
        else:
            label = record.get("label")
        if label is None:  # test split: no gold label, use a dummy
            label = "0"
        if data_format == DataFormat.PremiseOnly:
            if task_def.name == "wsc":
                premise = read_wsc(record)
            else:
                premise = record["sentence"]
            row = {"premise": premise}
        elif data_format == DataFormat.PremiseAndOneHypothesis:
            if task_def.name == "csl":
                # abstract as premise, joined keywords as hypothesis
                premise = record["abst"]
                hypothesis = ' '.join(record["keyword"])
            else:
                premise = record["sentence1"]
                hypothesis = record["sentence2"]
            row = {"premise": premise, "hypothesis": hypothesis}
        else:
            raise ValueError("not implemented yet")
        row["label"] = label
        if task_obj is not None:
            row["label"] = task_obj.input_parse_label(row["label"])
        if row["label"] is None:  # skip samples whose label fails to parse
            continue
        row["uid"] = f"{uid}"
        rows.append(row)
        uid += 1
    return rows
def __init__(self, opt, bert_config=None, initial_from_local=False):
    super(SANBertNetwork, self).__init__()
    if opt['encoder_type'] not in EncoderModelType._value2member_map_:
        raise ValueError("encoder_type is out of pre-defined types")
    self.encoder_type = opt['encoder_type']
    self.preloaded_config = None

    literal_encoder_type = EncoderModelType(self.encoder_type).name.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[literal_encoder_type]
    self.preloaded_config = config_class.from_dict(opt)  # load config from opt
    self.preloaded_config.output_hidden_states = True  # return all hidden states
    self.bert = model_class(self.preloaded_config)
    hidden_size = self.bert.config.hidden_size

    if opt.get('dump_feature', False):
        self.opt = opt
        return

    if opt['update_bert_opt'] > 0:
        # freeze the encoder weights
        for p in self.bert.parameters():
            p.requires_grad = False

    task_def_list = opt['task_def_list']
    self.task_def_list = task_def_list
    self.decoder_opt = []
    self.task_types = []
    for task_id, task_def in enumerate(task_def_list):
        self.decoder_opt.append(
            generate_decoder_opt(task_def.enable_san, opt['answer_opt']))
        self.task_types.append(task_def.task_type)

    # create one output header (dropout + scoring layer) per task
    self.scoring_list = nn.ModuleList()
    self.dropout_list = nn.ModuleList()
    for task_id in range(len(task_def_list)):
        task_def: TaskDef = task_def_list[task_id]
        lab = task_def.n_class
        decoder_opt = self.decoder_opt[task_id]
        task_type = self.task_types[task_id]
        task_dropout_p = (opt['dropout_p']
                          if task_def.dropout_p is None else task_def.dropout_p)
        dropout = DropoutWrapper(task_dropout_p, opt['vb_dropout'])
        self.dropout_list.append(dropout)

        task_obj = tasks.get_task_obj(task_def)
        if task_obj is not None:
            out_proj = task_obj.train_build_task_layer(decoder_opt,
                                                       hidden_size,
                                                       lab,
                                                       opt,
                                                       prefix='answer',
                                                       dropout=dropout)
        elif task_type == TaskType.Span:
            assert decoder_opt != 1
            out_proj = nn.Linear(hidden_size, 2)
        elif task_type == TaskType.SeqenceLabeling:
            out_proj = nn.Linear(hidden_size, lab)
        elif task_type == TaskType.MaskLM:
            # TODO (xiaodl): RoBERTa may eventually need its own head; BERT and
            # RoBERTa currently share the same tied-weight MaskLM header.
            out_proj = MaskLmHeader(self.bert.embeddings.word_embeddings.weight)
        else:
            if decoder_opt == 1:
                out_proj = SANClassifier(hidden_size,
                                         hidden_size,
                                         lab,
                                         opt,
                                         prefix='answer',
                                         dropout=dropout)
            else:
                out_proj = nn.Linear(hidden_size, lab)
        self.scoring_list.append(out_proj)

    self.opt = opt
    self._my_init()
    # if not loading from local, load pre-trained weights after initialization
    if not initial_from_local:
        config_class, model_class, tokenizer_class = MODEL_CLASSES[literal_encoder_type]
        self.bert = model_class.from_pretrained(opt['init_checkpoint'],
                                                config=self.preloaded_config)
def load_data(file_path, task_def):
    data_format = task_def.data_type
    task_type = task_def.task_type
    label_dict = task_def.label_vocab
    if task_type == TaskType.Ranking:
        assert data_format == DataFormat.PremiseAndMultiHypothesis

    task_obj = tasks.get_task_obj(task_def)
    rows = []
    for line in open(file_path, encoding="utf-8"):
        fields = line.strip("\n").split("\t")
        if data_format == DataFormat.PremiseOnly:
            assert len(fields) == 3
            row = {"uid": fields[0], "label": fields[1], "premise": fields[2]}
        elif data_format == DataFormat.PremiseAndOneHypothesis:
            assert len(fields) == 4
            row = {
                "uid": fields[0],
                "label": fields[1],
                "premise": fields[2],
                "hypothesis": fields[3]
            }
        elif data_format == DataFormat.PremiseAndMultiHypothesis:
            assert len(fields) > 5
            row = {
                "uid": fields[0],
                "ruid": fields[1].split(","),
                "label": fields[2],
                "premise": fields[3],
                "hypothesis": fields[4:]
            }
        elif data_format == DataFormat.Seqence:
            # label and premise are serialized Python lists; eval trusts the
            # preprocessing pipeline that wrote them
            row = {
                "uid": fields[0],
                "label": eval(fields[1]),
                "premise": eval(fields[2])
            }
        elif data_format == DataFormat.MRC:
            row = {
                "uid": fields[0],
                "label": fields[1],
                "premise": fields[2],
                "hypothesis": fields[3]
            }
        elif data_format == DataFormat.SimPair:
            if len(fields) < 4:  # skip malformed lines
                continue
            row = {
                "uid": fields[0],
                "label": fields[1],
                "text_a": fields[2],
                "text_b": fields[3]
            }
        else:
            raise ValueError(data_format)

        if task_obj is not None:
            row["label"] = task_obj.input_parse_label(row["label"])
        elif task_type == TaskType.Ranking:
            labels = row["label"].split(",")
            if label_dict is not None:
                labels = [label_dict[label] for label in labels]
            else:
                labels = [float(label) for label in labels]
            row["label"] = int(np.argmax(labels))
            row["olabel"] = labels
        elif task_type == TaskType.Span:
            pass  # keep the raw label; span positions are handled downstream
        elif task_type == TaskType.SeqenceLabeling:
            assert type(row["label"]) is list
            row["label"] = [label_dict[label] for label in row["label"]]
        rows.append(row)
    return rows
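# Usage sketch for load_data (illustrative): reads a tab-separated canonical
# file; the config path, task name, and data path are hypothetical placeholders.
task_def = TaskDefs("experiments/glue/glue_task_def.yml").get_task_def("mnli")  # hypothetical
rows = load_data("data/canonical_data/mnli_train.tsv", task_def)  # hypothetical path
print(rows[0]["uid"], rows[0]["label"])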
def collate_fn(self, batch):
    task_id = batch[0]["task"]["task_id"]
    task_def = batch[0]["task"]["task_def"]
    new_batch = []
    for sample in batch:
        # a batch must come from a single task
        assert sample["task"]["task_id"] == task_id
        assert sample["task"]["task_def"] == task_def
        new_batch.append(sample["sample"])
    task_type = task_def.task_type
    data_type = task_def.data_type

    batch = new_batch
    if task_type == TaskType.Ranking:
        batch = self.rebatch(batch)

    # prepare model input
    batch_info, batch_data = self._prepare_model_input(batch, data_type)
    batch_info['task_id'] = task_id  # used to select the correct decoding head
    batch_info['input_len'] = len(batch_data)  # used to select model inputs
    # select different loss functions and other differences in training and testing.
    # DataLoader converts unknown-type objects to dict, and that conversion also
    # turns an Enum into repr(Enum), a string, which is undesirable; converting
    # the object to dict in advance makes DataLoader leave it alone.
    batch_info['task_def'] = task_def.__dict__
    batch_info['pairwise_size'] = self.pairwise_size  # needed for ranking tasks

    # add label
    labels = [sample['label'] for sample in batch]
    task_obj = tasks.get_task_obj(task_def)
    if self.is_train:
        # in training mode the label is consumed by PyTorch, so it must be a tensor
        if task_obj is not None:
            batch_data.append(task_obj.train_prepare_label(labels))
            batch_info['label'] = len(batch_data) - 1
        elif task_type == TaskType.Ranking:
            batch_data.append(torch.LongTensor(labels))
            batch_info['label'] = len(batch_data) - 1
        elif task_type == TaskType.Span:
            start = [sample['start_position'] for sample in batch]
            end = [sample['end_position'] for sample in batch]
            batch_data.append((torch.LongTensor(start), torch.LongTensor(end)))
            # unify to one type of label
            batch_info['label'] = len(batch_data) - 1
        elif task_type == TaskType.SeqenceLabeling:
            batch_size = self._get_batch_size(batch)
            tok_len = self._get_max_len(batch, key='token_id')
            tlab = torch.LongTensor(batch_size, tok_len).fill_(-1)  # -1 = ignore index
            for i, label in enumerate(labels):
                ll = len(label)
                tlab[i, :ll] = torch.LongTensor(label)
            batch_data.append(tlab)
            batch_info['label'] = len(batch_data) - 1
        elif task_type == TaskType.MaskLM:
            batch_size = self._get_batch_size(batch)
            tok_len = self._get_max_len(batch, key='token_id')
            tlab = torch.LongTensor(batch_size, tok_len).fill_(-1)
            for i, label in enumerate(labels):
                ll = len(label)
                tlab[i, :ll] = torch.LongTensor(label)
            labels = torch.LongTensor([sample['nsp_lab'] for sample in batch])
            batch_data.append((tlab, labels))
            batch_info['label'] = len(batch_data) - 1

        # soft labels generated by ensemble models for knowledge distillation
        if self.soft_label_on and 'softlabel' in batch[0]:
            soft_labels = [sample['softlabel'] for sample in batch]
            soft_labels = task_obj.train_prepare_soft_labels(soft_labels)
            batch_info['soft_label'] = soft_labels
    else:
        # in test mode the label is only used for evaluation
        if task_obj is not None:
            task_obj.test_prepare_label(batch_info, labels)
        else:
            batch_info['label'] = labels
        if task_type == TaskType.Ranking:
            batch_info['true_label'] = [sample['true_label'] for sample in batch]
        if task_type == TaskType.Span:
            batch_info['token_to_orig_map'] = [
                sample['token_to_orig_map'] for sample in batch
            ]
            batch_info['token_is_max_context'] = [
                sample['token_is_max_context'] for sample in batch
            ]
            batch_info['doc_offset'] = [sample['doc_offset'] for sample in batch]
            batch_info['doc'] = [sample['doc'] for sample in batch]
            batch_info['tokens'] = [sample['tokens'] for sample in batch]
            batch_info['answer'] = [sample['answer'] for sample in batch]

    batch_info['uids'] = [sample['uid'] for sample in batch]  # used in scoring
    return batch_info, batch_data
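# Usage sketch: wiring collate_fn into a PyTorch DataLoader (illustrative;
# SingleTaskDataset and Collater mirror the loading code in
# load_model_for_viz_1 below, and the data path is a hypothetical placeholder).
from torch.utils.data import DataLoader

train_set = SingleTaskDataset(
    "data/canonical_data/bert_base_uncased/mnli_train.json",  # hypothetical path
    True, maxlen=512, task_id=0, task_def=task_def)
collater = Collater(is_train=True, encoder_type=EncoderModelType.BERT)
train_loader = DataLoader(train_set,
                          batch_size=32,
                          collate_fn=collater.collate_fn,
                          pin_memory=True)
for batch_info, batch_data in train_loader:
    # batch_info carries indices (e.g. batch_info['label']) into batch_data
    break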
def load_model_for_viz_1(task_def_path,
                         checkpoint_path,
                         input_path,
                         model_type='bert-base-cased',
                         do_lower_case=False,
                         use_cuda=True):
    # load task info
    task = os.path.splitext(os.path.basename(task_def_path))[0]
    task_defs = TaskDefs(task_def_path)
    assert task in task_defs._task_type_map
    assert task in task_defs._data_type_map
    assert task in task_defs._metric_meta_map
    prefix = task.split('_')[0]
    task_def = task_defs.get_task_def(prefix)
    data_type = task_defs._data_type_map[task]
    task_type = task_defs._task_type_map[task]
    metric_meta = task_defs._metric_meta_map[task]
    # load model
    assert os.path.exists(checkpoint_path)
    state_dict = torch.load(checkpoint_path)
    config = state_dict['config']
    config["cuda"] = use_cuda
    device = torch.device("cuda" if use_cuda else "cpu")
    task_def_list = [task_def]
    config['task_def_list'] = task_def_list
    # temp fix
    config['fp16'] = False
    config['answer_opt'] = 0
    config['adv_train'] = False
    config['output_attentions'] = True
    config['local_rank'] = -1
    model = MTDNNModel(config, device, state_dict=state_dict)
    encoder_type = config.get('encoder_type', EncoderModelType.BERT)
    root = os.path.basename(task_def_path)
    literal_model_type = model_type.split('-')[0].upper()
    encoder_model = EncoderModelType[literal_model_type]
    literal_model_type = literal_model_type.lower()
    mt_dnn_suffix = literal_model_type
    if 'base' in model_type:
        mt_dnn_suffix += "_base"
    elif 'large' in model_type:
        mt_dnn_suffix += "_large"
    # load tokenizer
    config_class, model_class, tokenizer_class = MODEL_CLASSES[literal_model_type]
    tokenizer = tokenizer_class.from_pretrained(model_type,
                                                do_lower_case=do_lower_case)
    # load data
    test_data_set = SingleTaskDataset(input_path,
                                      False,
                                      maxlen=512,
                                      task_id=0,
                                      task_def=task_def)
    collater = Collater(is_train=False, encoder_type=encoder_type)
    test_data = DataLoader(test_data_set,
                           batch_size=1,
                           collate_fn=collater.collate_fn,
                           pin_memory=True)
    idx = 0
    results = []
    for batch_meta, batch_data in tqdm(test_data):
        if idx < 360:  # hard-coded: skip the first 360 batches
            idx += 1
            continue
        batch_meta, batch_data = Collater.patch_data(device, batch_meta, batch_data)
        model.network.eval()
        task_id = batch_meta['task_id']
        task_def = TaskDef.from_dict(batch_meta['task_def'])
        task_type = task_def.task_type
        task_obj = tasks.get_task_obj(task_def)
        inputs = batch_data[:batch_meta['input_len']]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        input_ids = inputs[0]
        token_type_ids = inputs[1]
        # run the encoder directly to obtain per-layer attention maps
        attention = model.mnetwork.module.bert(input_ids,
                                               token_type_ids=token_type_ids)[-1]
        batch_size = batch_data[0].shape[0]
        for i in range(batch_size):
            attention = tuple([item[i:i + 1, :, :, :] for item in attention])
            input_id_list = input_ids[i].tolist()
            tokens = tokenizer.convert_ids_to_tokens(input_id_list)
            # truncate everything after the last [SEP]
            idx_sep = listRightIndex(tokens, '[SEP]') + 1
            tokens = tokens[:idx_sep]
            attention = tuple(
                [item[:, :, :idx_sep, :idx_sep] for item in attention])
            results.append((attention, tokens))
        idx += batch_size
    return results
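# Usage sketch for load_model_for_viz_1 (illustrative; all paths are
# hypothetical placeholders). Each returned pair holds a per-layer tuple of
# (1, num_heads, seq_len, seq_len) attention tensors truncated at the last
# [SEP], plus the matching tokens, ready for head-view style attention plots.
results = load_model_for_viz_1("experiments/glue/mnli_task_def.yml",  # hypothetical
                               "checkpoints/mnli/model_5.pt",         # hypothetical
                               "data/bert_base_cased/mnli_dev.json",  # hypothetical
                               model_type="bert-base-cased")
attention, tokens = results[0]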
def predict(self, batch_meta, batch_data):
    self.network.eval()
    task_id = batch_meta["task_id"]
    task_def = TaskDef.from_dict(batch_meta["task_def"])
    task_type = task_def.task_type
    task_obj = tasks.get_task_obj(task_def)
    inputs = batch_data[:batch_meta["input_len"]]
    if len(inputs) == 3:
        # pad the optional premise/hypothesis masks
        inputs.append(None)
        inputs.append(None)
    inputs.append(task_id)
    if task_type == TaskType.SeqenceGeneration:
        # y_input_ids=None, fwd_type=3 -> generation
        inputs.append(None)
        inputs.append(3)
    score = self.mnetwork(*inputs)
    if task_obj is not None:
        score, predict = task_obj.test_predict(score)
    elif task_type == TaskType.Ranking:
        score = score.contiguous().view(-1, batch_meta["pairwise_size"])
        score = F.softmax(score, dim=1)
        score = score.data.cpu()
        score = score.numpy()
        # one-hot prediction for the top-ranked hypothesis in each group
        predict = np.zeros(score.shape, dtype=int)
        positive = np.argmax(score, axis=1)
        for idx, pos in enumerate(positive):
            predict[idx, pos] = 1
        predict = predict.reshape(-1).tolist()
        score = score.reshape(-1).tolist()
        return score, predict, batch_meta["true_label"]
    elif task_type == TaskType.SeqenceLabeling:
        mask = batch_data[batch_meta["mask"]]
        score = score.contiguous()
        score = score.data.cpu()
        score = score.numpy()
        predict = np.argmax(score, axis=1).reshape(mask.size()).tolist()
        # trim padding positions using the per-sample valid lengths
        valid_length = mask.sum(1).tolist()
        final_predict = []
        for idx, p in enumerate(predict):
            final_predict.append(p[:valid_length[idx]])
        score = score.reshape(-1).tolist()
        return score, final_predict, batch_meta["label"]
    elif task_type == TaskType.Span or task_type == TaskType.SpanYN:
        predictions = []
        features = []
        for idx, offset in enumerate(batch_meta["offset_mapping"]):
            token_is_max_context = (batch_meta["token_is_max_context"][idx]
                                    if batch_meta.get("token_is_max_context",
                                                      None) else None)
            sample_id = batch_meta["uids"][idx]
            feature = {
                "offset_mapping": offset,
                "token_is_max_context": token_is_max_context,
                "uid": sample_id,
                "context": batch_meta["context"][idx],
                "answer": batch_meta["answer"][idx],
            }
            if "label" in batch_meta:
                feature["label"] = batch_meta["label"][idx]
            if "null_ans_index" in batch_meta:
                feature["null_ans_index"] = batch_meta["null_ans_index"]
            features.append(feature)
        start, end = score
        start = start.contiguous()
        start = start.data.cpu()
        start = start.numpy().tolist()
        end = end.contiguous()
        end = end.data.cpu()
        end = end.numpy().tolist()
        return (start, end), predictions, features
    elif task_type == TaskType.SeqenceGeneration:
        predicts = self.tokenizer.batch_decode(score, skip_special_tokens=True)
        predictions = {}
        golds = {}
        for idx, predict in enumerate(predicts):
            sample_id = batch_meta["uids"][idx]
            answer = batch_meta["answer"][idx]
            predict = predict.strip()
            if predict == DUMPY_STRING_FOR_EMPTY_ANS:
                predict = ""
            predictions[sample_id] = predict
            golds[sample_id] = answer
        score = score.contiguous()
        score = score.data.cpu()
        score = score.numpy().tolist()
        return score, predictions, golds
    elif task_type == TaskType.ClozeChoice:
        score = score.contiguous().view(-1)
        score = score.data.cpu()
        score = score.numpy()
        copy_score = score.tolist()
        answers = batch_meta["answer"]
        choices = batch_meta["choice"]
        chunks = batch_meta["pairwise_size"]
        uids = batch_meta["uids"]
        predictions = {}
        golds = {}
        for chunk in chunks:
            # each chunk holds the candidate scores for one sample; the
            # answer/choice fields are serialized Python lists
            uid = uids[0]
            answer = eval(answers[0])
            choice = eval(choices[0])
            positive = np.argmax(score[:chunk])
            predictions[uid] = choice[positive]
            golds[uid] = answer
            # advance the flattened per-candidate lists to the next sample
            uids = uids[chunk:]
            answers = answers[chunk:]
            choices = choices[chunk:]
            score = score[chunk:]
        return copy_score, predictions, golds
    else:
        raise ValueError("Unknown task_type: %s" % task_type)
    return score, predict, batch_meta["label"]
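# Usage sketch: a minimal evaluation loop around predict (illustrative;
# assumes `model` is an MTDNNModel, `device` a torch.device, and `test_data`
# a DataLoader built with Collater(is_train=False) as in the sketch above).
all_scores = []
for batch_meta, batch_data in test_data:
    batch_meta, batch_data = Collater.patch_data(device, batch_meta, batch_data)
    score, predictions, golds = model.predict(batch_meta, batch_data)
    all_scores.extend(score)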
def collate_fn(self, batch):
    task_id = batch[0]["task"]["task_id"]
    task_def = batch[0]["task"]["task_def"]
    new_batch = []
    for sample in batch:
        # a batch must come from a single task
        assert sample["task"]["task_id"] == task_id
        assert sample["task"]["task_def"] == task_def
        new_batch.append(sample["sample"])
    task_type = task_def.task_type
    data_type = task_def.data_type

    batch = new_batch
    if task_type == TaskType.Ranking or task_type == TaskType.ClozeChoice:
        batch, chunk_sizes = self.rebatch(batch)

    # prepare model input
    batch_info, batch_data = self._prepare_model_input(batch, data_type)
    batch_info["task_id"] = task_id  # used to select the correct decoding head
    batch_info["input_len"] = len(batch_data)  # used to select model inputs
    # select different loss functions and other differences in training and testing.
    # DataLoader converts unknown-type objects to dict, and that conversion also
    # turns an Enum into repr(Enum), a string, which is undesirable; converting
    # the object to dict in advance makes DataLoader leave it alone.
    batch_info["task_def"] = task_def.__dict__
    batch_info["pairwise_size"] = self.pairwise_size  # needed for ranking tasks

    # add label
    labels = [sample["label"] if "label" in sample else None for sample in batch]
    task_obj = tasks.get_task_obj(task_def)
    if self.is_train:
        # in training mode the label is consumed by PyTorch, so it must be a tensor
        if task_obj is not None:
            batch_data.append(task_obj.train_prepare_label(labels))
            batch_info["label"] = len(batch_data) - 1
        elif task_type == TaskType.Ranking or task_type == TaskType.ClozeChoice:
            batch_data.append(torch.LongTensor(labels))
            batch_info["label"] = len(batch_data) - 1
        elif task_type == TaskType.Span:
            # support multiple gold positions: sample one span at random
            start, end = [], []
            for sample in batch:
                if (type(sample["start_position"]) is list
                        and type(sample["end_position"]) is list):
                    idx = random.choice(range(0, len(sample["start_position"])))
                    start.append(sample["start_position"][idx])
                    end.append(sample["end_position"][idx])
                else:
                    start.append(sample["start_position"])
                    end.append(sample["end_position"])
            batch_data.append((torch.LongTensor(start), torch.LongTensor(end)))
            # unify to one type of label
            batch_info["label"] = len(batch_data) - 1
        elif task_type == TaskType.SpanYN:
            # same multi-position sampling as Span
            start, end = [], []
            for sample in batch:
                if (type(sample["start_position"]) is list
                        and type(sample["end_position"]) is list):
                    idx = random.choice(range(0, len(sample["start_position"])))
                    start.append(sample["start_position"][idx])
                    end.append(sample["end_position"][idx])
                else:
                    start.append(sample["start_position"])
                    end.append(sample["end_position"])
            # start, end, yes/no
            batch_data.append((
                torch.LongTensor(start),
                torch.LongTensor(end),
                torch.LongTensor(labels),
            ))
            batch_info["label"] = len(batch_data) - 1
        elif task_type == TaskType.SeqenceLabeling:
            batch_size = self._get_batch_size(batch)
            tok_len = self._get_max_len(batch, key="token_id")
            tlab = torch.LongTensor(batch_size, tok_len).fill_(-1)  # -1 = ignore index
            for i, label in enumerate(labels):
                ll = len(label)
                tlab[i, :ll] = torch.LongTensor(label)
            batch_data.append(tlab)
            batch_info["label"] = len(batch_data) - 1
        elif task_type == TaskType.MaskLM:
            batch_size = self._get_batch_size(batch)
            tok_len = self._get_max_len(batch, key="token_id")
            tlab = torch.LongTensor(batch_size, tok_len).fill_(-1)
            for i, label in enumerate(labels):
                ll = len(label)
                tlab[i, :ll] = torch.LongTensor(label)
            labels = torch.LongTensor([sample["nsp_lab"] for sample in batch])
            batch_data.append((tlab, labels))
            batch_info["label"] = len(batch_data) - 1
        elif task_type == TaskType.SeqenceGeneration:
            # shift for teacher forcing: decoder input drops the last token,
            # the target drops the first; padding (id 0) is masked to -1
            y_idxs = torch.LongTensor([sample["label"][:-1] for sample in batch])
            label = torch.LongTensor([sample["label"][1:] for sample in batch])
            label.masked_fill_(label == 0, -1)
            batch_data.append(y_idxs)
            batch_info["y_token_id"] = len(batch_data) - 1
            batch_data.append(label)
            batch_info["label"] = len(batch_data) - 1

        # soft labels generated by ensemble models for knowledge distillation
        if self.soft_label_on and "softlabel" in batch[0]:
            soft_labels = [sample["softlabel"] for sample in batch]
            soft_labels = task_obj.train_prepare_soft_labels(soft_labels)
            batch_info["soft_label"] = soft_labels
    else:
        # in test mode the label is only used for evaluation
        if task_obj is not None:
            task_obj.test_prepare_label(batch_info, labels)
        else:
            batch_info["label"] = labels
        if task_type == TaskType.Ranking:
            batch_info["true_label"] = [sample["true_label"] for sample in batch]
        if task_type == TaskType.ClozeChoice:
            batch_info["answer"] = [sample["answer"] for sample in batch]
            batch_info["choice"] = [sample["choice"] for sample in batch]
            batch_info["pairwise_size"] = chunk_sizes
        if task_type == TaskType.Span or task_type == TaskType.SpanYN:
            batch_info["offset_mapping"] = [
                sample["offset_mapping"] for sample in batch
            ]
            batch_info["token_is_max_context"] = [
                sample.get("token_is_max_context", None) for sample in batch
            ]
            batch_info["context"] = [sample["context"] for sample in batch]
            batch_info["answer"] = [sample["answer"] for sample in batch]
            batch_info["label"] = [
                sample["label"] if "label" in sample else None for sample in batch
            ]
        if task_type == TaskType.SeqenceGeneration:
            batch_info["answer"] = [sample["answer"] for sample in batch]

    batch_info["uids"] = [sample["uid"] for sample in batch]  # used in scoring
    return batch_info, batch_data
def __init__(self, opt, bert_config=None, initial_from_local=False):
    super(SANBertNetwork, self).__init__()
    if opt["encoder_type"] not in EncoderModelType._value2member_map_:
        raise ValueError("encoder_type is out of pre-defined types")
    self.encoder_type = opt["encoder_type"]
    self.preloaded_config = None

    literal_encoder_type = EncoderModelType(self.encoder_type).name.lower()
    config_class, model_class, _ = MODEL_CLASSES[literal_encoder_type]
    if not initial_from_local:
        # download (or read from cache) the pre-trained encoder weights
        self.bert = model_class.from_pretrained(
            opt["init_checkpoint"], cache_dir=opt["transformer_cache"])
    else:
        self.preloaded_config = config_class.from_dict(opt)  # load config from opt
        self.preloaded_config.output_hidden_states = True  # return all hidden states
        self.bert = model_class(self.preloaded_config)

    hidden_size = self.bert.config.hidden_size

    if opt.get("dump_feature", False):
        self.config = opt
        return

    if opt["update_bert_opt"] > 0:
        # freeze the encoder weights
        for p in self.bert.parameters():
            p.requires_grad = False

    task_def_list = opt["task_def_list"]
    self.task_def_list = task_def_list
    self.decoder_opt = []
    self.task_types = []
    for task_id, task_def in enumerate(task_def_list):
        self.decoder_opt.append(
            generate_decoder_opt(task_def.enable_san, opt["answer_opt"]))
        self.task_types.append(task_def.task_type)

    # create one output header (dropout + scoring layer) per task
    self.scoring_list = nn.ModuleList()
    self.dropout_list = nn.ModuleList()
    for task_id in range(len(task_def_list)):
        task_def: TaskDef = task_def_list[task_id]
        lab = task_def.n_class
        decoder_opt = self.decoder_opt[task_id]
        task_type = self.task_types[task_id]
        task_dropout_p = (opt["dropout_p"]
                          if task_def.dropout_p is None else task_def.dropout_p)
        dropout = DropoutWrapper(task_dropout_p, opt["vb_dropout"])
        self.dropout_list.append(dropout)

        task_obj = tasks.get_task_obj(task_def)
        if task_obj is not None:
            # TODO: move the pooler into task_obj
            self.pooler = Pooler(hidden_size,
                                 dropout_p=opt["dropout_p"],
                                 actf=opt["pooler_actf"])
            out_proj = task_obj.train_build_task_layer(decoder_opt,
                                                       hidden_size,
                                                       lab,
                                                       opt,
                                                       prefix="answer",
                                                       dropout=dropout)
        elif task_type == TaskType.Span:
            assert decoder_opt != 1
            out_proj = nn.Linear(hidden_size, 2)
        elif task_type == TaskType.SpanYN:
            assert decoder_opt != 1
            out_proj = nn.Linear(hidden_size, 2)
        elif task_type == TaskType.SeqenceLabeling:
            out_proj = nn.Linear(hidden_size, lab)
        # elif task_type == TaskType.MaskLM:
        #     out_proj = MaskLmHeader(self.bert.embeddings.word_embeddings.weight)
        elif task_type == TaskType.SeqenceGeneration:
            # the encoder's original LM header is used; no extra projection
            out_proj = None
        elif task_type == TaskType.ClozeChoice:
            self.pooler = Pooler(hidden_size,
                                 dropout_p=opt["dropout_p"],
                                 actf=opt["pooler_actf"])
            out_proj = nn.Linear(hidden_size, lab)
        else:
            if decoder_opt == 1:
                out_proj = SANClassifier(hidden_size,
                                         hidden_size,
                                         lab,
                                         opt,
                                         prefix="answer",
                                         dropout=dropout)
            else:
                out_proj = nn.Linear(hidden_size, lab)
        self.scoring_list.append(out_proj)

    self.config = opt
def forward(
    self,
    input_ids,
    token_type_ids,
    attention_mask,
    premise_mask=None,
    hyp_mask=None,
    task_id=0,
    y_input_ids=None,
    fwd_type=0,
    embed=None,
):
    # fwd_type: 0 = standard forward, 1 = embedding lookup only,
    # 2 = forward from precomputed embeddings, 3 = autoregressive generation
    if fwd_type == 3:
        generated = self.bert.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=self.config["max_answer_len"],
            num_beams=self.config["num_beams"],
            repetition_penalty=self.config["repetition_penalty"],
            length_penalty=self.config["length_penalty"],
            early_stopping=True,
        )
        return generated
    elif fwd_type == 2:
        assert embed is not None
        last_hidden_state, all_hidden_states = self.encode(
            None, token_type_ids, attention_mask, embed, y_input_ids)
    elif fwd_type == 1:
        return self.embed_encode(input_ids, token_type_ids, attention_mask)
    else:
        last_hidden_state, all_hidden_states = self.encode(
            input_ids, token_type_ids, attention_mask, y_input_ids=y_input_ids)

    decoder_opt = self.decoder_opt[task_id]
    task_type = self.task_types[task_id]
    task_obj = tasks.get_task_obj(self.task_def_list[task_id])
    if task_obj is not None:
        pooled_output = self.pooler(last_hidden_state)
        logits = task_obj.train_forward(
            last_hidden_state,
            pooled_output,
            premise_mask,
            hyp_mask,
            decoder_opt,
            self.dropout_list[task_id],
            self.scoring_list[task_id],
        )
        return logits
    elif task_type in (TaskType.Span, TaskType.SpanYN):
        # Span and SpanYN share the same start/end scoring head
        assert decoder_opt != 1
        last_hidden_state = self.dropout_list[task_id](last_hidden_state)
        logits = self.scoring_list[task_id](last_hidden_state)
        start_scores, end_scores = logits.split(1, dim=-1)
        start_scores = start_scores.squeeze(-1)
        end_scores = end_scores.squeeze(-1)
        return start_scores, end_scores
    elif task_type == TaskType.SeqenceLabeling:
        pooled_output = last_hidden_state
        pooled_output = self.dropout_list[task_id](pooled_output)
        pooled_output = pooled_output.contiguous().view(-1, pooled_output.size(2))
        logits = self.scoring_list[task_id](pooled_output)
        return logits
    elif task_type == TaskType.MaskLM:
        last_hidden_state = self.dropout_list[task_id](last_hidden_state)
        logits = self.scoring_list[task_id](last_hidden_state)
        return logits
    elif task_type == TaskType.SeqenceGeneration:
        # flatten to (batch * seq_len, vocab_size) token logits
        logits = last_hidden_state.view(-1, last_hidden_state.size(-1))
        return logits
    elif task_type == TaskType.ClozeChoice:
        pooled_output = self.pooler(last_hidden_state)
        pooled_output = self.dropout_list[task_id](pooled_output)
        logits = self.scoring_list[task_id](pooled_output)
        return logits
    else:
        if decoder_opt == 1:
            assert premise_mask is not None
            assert hyp_mask is not None
            max_query = hyp_mask.size(1)
            assert max_query > 0
            hyp_mem = last_hidden_state[:, :max_query, :]
            logits = self.scoring_list[task_id](last_hidden_state, hyp_mem,
                                                premise_mask, hyp_mask)
        else:
            # pool the sequence representation before the linear head
            pooled_output = self.pooler(last_hidden_state)
            pooled_output = self.dropout_list[task_id](pooled_output)
            logits = self.scoring_list[task_id](pooled_output)
        return logits
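# Illustrative call patterns for the fwd_type dispatch above (tensor shapes
# are toy examples; real inputs come from the collater, and `network` is
# assumed to be an instance of this class):
import torch

B, T = 2, 16  # toy batch size / sequence length
input_ids = torch.zeros(B, T, dtype=torch.long)
token_type_ids = torch.zeros(B, T, dtype=torch.long)
attention_mask = torch.ones(B, T, dtype=torch.long)

logits = network(input_ids, token_type_ids, attention_mask, task_id=0)  # fwd_type=0: encode + task head
embed = network(input_ids, token_type_ids, attention_mask, fwd_type=1)  # embedding lookup only
logits = network(input_ids, token_type_ids, attention_mask,
                 fwd_type=2, embed=embed)  # rerun the encoder from (possibly perturbed) embeddings
# fwd_type=3 runs self.bert.generate() for SeqenceGeneration tasks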
def forward(self,
            input_ids,
            token_type_ids,
            attention_mask,
            premise_mask=None,
            hyp_mask=None,
            task_id=0):
    sequence_output, pooled_output = self.encode(input_ids, token_type_ids,
                                                 attention_mask)
    decoder_opt = self.decoder_opt[task_id]
    task_type = self.task_types[task_id]
    task_obj = tasks.get_task_obj(self.task_def_list[task_id])
    if task_obj is not None:
        logits = task_obj.train_forward(sequence_output, pooled_output,
                                        premise_mask, hyp_mask, decoder_opt,
                                        self.dropout_list[task_id],
                                        self.scoring_list[task_id])
        return logits
    # Gengyu: label embedding
    elif self.opt['label_embedding']:
        # look up this task's label embeddings: (label_num, hidden_size)
        task_label_embeddings = self.label_embedding_layer(self.label_index[task_id])
        # (batch_size, 1, hidden_size)
        pooled_output = pooled_output.unsqueeze(1)
        # broadcast to (batch_size, label_num, hidden_size)
        pooled_output = pooled_output.expand(pooled_output.size()[0],
                                             task_label_embeddings.size()[0],
                                             pooled_output.size()[2])
        # score each label by the dot product of its embedding and the pooled output
        emb_cls_mul = torch.mul(task_label_embeddings, pooled_output)
        logits = torch.sum(emb_cls_mul, -1)
        return logits
    elif task_type == TaskType.Span:
        assert decoder_opt != 1
        sequence_output = self.dropout_list[task_id](sequence_output)
        logits = self.scoring_list[task_id](sequence_output)
        start_scores, end_scores = logits.split(1, dim=-1)
        start_scores = start_scores.squeeze(-1)
        end_scores = end_scores.squeeze(-1)
        return start_scores, end_scores
    elif task_type == TaskType.SeqenceLabeling:
        pooled_output = sequence_output
        pooled_output = self.dropout_list[task_id](pooled_output)
        pooled_output = pooled_output.contiguous().view(-1, pooled_output.size(2))
        logits = self.scoring_list[task_id](pooled_output)
        return logits
    elif task_type == TaskType.MaskLM:
        sequence_output = self.dropout_list[task_id](sequence_output)
        logits = self.scoring_list[task_id](sequence_output)
        return logits
    else:
        if decoder_opt == 1:
            assert premise_mask is not None
            assert hyp_mask is not None
            max_query = hyp_mask.size(1)
            assert max_query > 0
            hyp_mem = sequence_output[:, :max_query, :]
            logits = self.scoring_list[task_id](sequence_output, hyp_mem,
                                                premise_mask, hyp_mask)
        else:
            pooled_output = self.dropout_list[task_id](pooled_output)
            logits = self.scoring_list[task_id](pooled_output)
        return logits