def get_dataloader_for_support(args, tokenizer, sep_dom=False):
    """Build DataLoader(s) over the support set, merged or per-domain."""
    data_path, fin_data_path = args.data_path, args.fin_data_path
    batch_size = args.batch_size
    if args.load_userdict:
        jieba.load_userdict(args.userdict)
    domain_map = Vocab.from_file(os.path.join(data_path, "domains.txt"))
    intent_map = Vocab.from_file(os.path.join(data_path, "intents.txt"))
    slots_map = Vocab.from_file(os.path.join(data_path, "slots.txt"))
    label_vocab = Vocab.from_file(os.path.join(data_path, "label_vocab.txt"))
    bin_label_vocab = Vocab.from_file(
        os.path.join(data_path, "bin_label_vocab.txt"))
    sup_dom_data = read_support_data(
        os.path.join(fin_data_path, "support"), tokenizer, domain_map,
        intent_map, slots_map, label_vocab, bin_label_vocab)
    if not sep_dom:
        # Merge the support data of all domains into a single loader.
        sup_data = []
        for dom_data in sup_dom_data.values():
            sup_data.extend(dom_data)
        suploader = thdata.DataLoader(dataset=Dataset(sup_data),
                                      batch_size=batch_size,
                                      shuffle=True,
                                      collate_fn=collate_fn)
        return suploader
    else:
        # One loader per domain, keyed by domain name.
        suploaders = {}
        for dom, dom_data in sup_dom_data.items():
            suploaders[dom] = thdata.DataLoader(dataset=Dataset(dom_data),
                                                batch_size=batch_size,
                                                shuffle=True,
                                                collate_fn=collate_fn)
        return suploaders
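# Usage sketch for the loader above (hedged): the Namespace fields mirror
# exactly the attributes read by get_dataloader_for_support; the paths and
# batch size are hypothetical placeholders, and `tokenizer` is assumed to be
# built elsewhere.
def _example_support_loaders(tokenizer):
    from argparse import Namespace
    args = Namespace(data_path="data/", fin_data_path="data/fin/",
                     batch_size=32, load_userdict=False, userdict=None)
    # Single loader mixing all domains' support instances.
    merged_loader = get_dataloader_for_support(args, tokenizer)
    # Or a {domain_name: DataLoader} dict, one loader per domain.
    per_dom_loaders = get_dataloader_for_support(args, tokenizer,
                                                 sep_dom=True)
    return merged_loader, per_dom_loaders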
def get_dataloader_for_train(args, tokenizer):
    """Build one training DataLoader over all source domains plus the dev support sets."""
    data_path, raw_data_path = args.data_path, args.raw_data_path
    batch_size = args.batch_size
    if args.load_userdict:
        jieba.load_userdict(args.userdict)
    domain_map = Vocab.from_file(os.path.join(data_path, "domains.txt"))
    intent_map = Vocab.from_file(os.path.join(data_path, "intents.txt"))
    slots_map = Vocab.from_file(os.path.join(data_path, "slots.txt"))
    label_vocab = Vocab.from_file(os.path.join(data_path, "label_vocab.txt"))
    bin_label_vocab = Vocab.from_file(
        os.path.join(data_path, "bin_label_vocab.txt"))
    # Source-domain training data.
    all_train_data = []
    train_dom_data = read_all_train_data(
        os.path.join(raw_data_path, "source.json"), tokenizer, domain_map,
        intent_map, slots_map, label_vocab, bin_label_vocab)
    for dom_data in train_dom_data.values():
        all_train_data.extend(dom_data)
    # The dev support sets are folded into the training data as well.
    dev_sup_dom_data = read_support_data(
        os.path.join(raw_data_path, "dev", "support"), tokenizer, domain_map,
        intent_map, slots_map, label_vocab, bin_label_vocab)
    for dom_data in dev_sup_dom_data.values():
        all_train_data.extend(dom_data)
    dataloader = thdata.DataLoader(dataset=Dataset(all_train_data),
                                   batch_size=batch_size,
                                   shuffle=True,
                                   collate_fn=collate_fn)
    return dataloader
def get_dataloader_for_fs_eval(data_path, raw_data_path, eval_domains: list,
                               batch_size, max_sup_ratio, max_sup_size,
                               n_shots, tokenizer, return_suploader=False):
    """Build few-shot eval loader(s) by splitting each eval domain into support and query."""
    domain_map = Vocab.from_file(os.path.join(data_path, "domains.txt"))
    intent_map = Vocab.from_file(os.path.join(data_path, "intents.txt"))
    slots_map = Vocab.from_file(os.path.join(data_path, "slots.txt"))
    label_vocab = Vocab.from_file(os.path.join(data_path, "label_vocab.txt"))
    bin_label_vocab = Vocab.from_file(
        os.path.join(data_path, "bin_label_vocab.txt"))
    # Training data, restricted to the evaluation domains.
    all_train_data = read_all_train_data(
        os.path.join(raw_data_path, "source.json"), tokenizer, domain_map,
        intent_map, slots_map, label_vocab, bin_label_vocab)
    data = {k: v for k, v in all_train_data.items() if k in eval_domains}
    # Split each domain into eval support & query sets.
    fs_data = []
    fs_sup_data = []
    for dom, dom_data in data.items():
        # Support size is capped both by a ratio of the domain data and by an
        # absolute maximum, but never falls below n_shots.
        sup_size = max(min(int(max_sup_ratio * len(dom_data)), max_sup_size),
                       n_shots)
        sup_data, qry_data = separate_data_to_support_and_query(
            dom_data, sup_size)
        dom_data = collect_support_instances(sup_data, qry_data, int(n_shots))
        fs_data.extend(dom_data)
        if return_suploader:
            fs_sup_data.extend(sup_data)
    dataloader = thdata.DataLoader(dataset=Dataset(fs_data),
                                   batch_size=batch_size,
                                   shuffle=False,
                                   collate_fn=collate_fn)
    if return_suploader:
        suploader = thdata.DataLoader(dataset=Dataset(fs_sup_data),
                                      batch_size=batch_size,
                                      shuffle=True,
                                      collate_fn=collate_fn)
        return dataloader, suploader
    else:
        return dataloader
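# Usage sketch for few-shot evaluation (hedged): the domain names, paths, and
# size caps below are hypothetical. Per the logic above, each domain's support
# size is min(max_sup_ratio * |domain data|, max_sup_size), floored at n_shots.
def _example_fs_eval(tokenizer):
    qry_loader, sup_loader = get_dataloader_for_fs_eval(
        data_path="data/", raw_data_path="raw/",
        eval_domains=["weather", "music"],
        batch_size=32, max_sup_ratio=0.1, max_sup_size=200, n_shots=5,
        tokenizer=tokenizer, return_suploader=True)
    return qry_loader, sup_loader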
def main():
    # Relies on a module-level `args` parsed elsewhere in this script.
    data_path = args.data
    vocab_path = args.vocab
    max_length = args.max_length
    out_path = args.out
    token_vocab = Vocab.from_file(path=vocab_path, add_pad=True, add_unk=True)
    trans_vocab = TransVocab()
    label_vocab = LabelVocab()
    data_reader = SNLIReader(data_path=data_path,
                             token_vocab=token_vocab,
                             trans_vocab=trans_vocab,
                             label_vocab=label_vocab,
                             max_length=max_length)
    # Serialize the fully constructed reader for later reuse.
    with open(out_path, 'wb') as f:
        pickle.dump(data_reader, f)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', required=True)
    parser.add_argument('--vocab', required=True)
    parser.add_argument('--vocab-size', required=True, type=int)
    parser.add_argument('--max-length', default=None, type=int)
    parser.add_argument('--binary', default=False, action='store_true')
    parser.add_argument('--out', required=True)
    args = parser.parse_args()
    word_vocab = Vocab.from_file(path=args.vocab, add_pad=True, add_unk=True,
                                 max_size=args.vocab_size)
    data_reader = SSTDataset(data_path=args.data,
                             word_vocab=word_vocab,
                             max_length=args.max_length,
                             binary=args.binary)
    with open(args.out, 'wb') as f:
        pickle.dump(data_reader, f)
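# Example invocation (the script name and paths are hypothetical; the flags
# match the argparse definitions above):
#   python preprocess_sst.py --data sst/train.txt --vocab vocab.txt \
#       --vocab-size 100000 --max-length 64 --binary --out sst_train.pkl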
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', required=True)
    parser.add_argument('--vocab', required=True)
    parser.add_argument('--vocab-size', required=True, type=int)
    parser.add_argument('--max-length', required=True, type=int)
    parser.add_argument('--out', required=True)
    args = parser.parse_args()
    word_vocab = Vocab.from_file(path=args.vocab, add_pad=True, add_unk=True,
                                 max_size=args.vocab_size)
    label_dict = {'neutral': 0, 'entailment': 1, 'contradiction': 2}
    label_vocab = Vocab(vocab_dict=label_dict, add_pad=False, add_unk=False)
    data_reader = SNLIDataset(data_path=args.data,
                              word_vocab=word_vocab,
                              label_vocab=label_vocab,
                              max_length=args.max_length)
    with open(args.out, 'wb') as f:
        pickle.dump(data_reader, f)
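# The entry points above all pickle the fully built dataset object, so a
# later training script can restore it directly (sketch; filename is a
# hypothetical placeholder):
#   with open('snli_train.pkl', 'rb') as f:
#       data_reader = pickle.load(f)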
def get_dataloader_for_fs_test(data_path, raw_data_path, batch_size, n_shots,
                               tokenizer, sep_dom=False,
                               return_suploader=False):
    """Build few-shot test loader(s) from the dev support and query splits."""
    domain_map = Vocab.from_file(os.path.join(data_path, "domains.txt"))
    intent_map = Vocab.from_file(os.path.join(data_path, "intents.txt"))
    slots_map = Vocab.from_file(os.path.join(data_path, "slots.txt"))
    label_vocab = Vocab.from_file(os.path.join(data_path, "label_vocab.txt"))
    bin_label_vocab = Vocab.from_file(
        os.path.join(data_path, "bin_label_vocab.txt"))
    # Dev support & query sets, grouped by domain.
    dev_sup_dom_data, dev_qry_dom_data = read_dev_support_and_query_data(
        os.path.join(raw_data_path, "dev"), tokenizer, domain_map, intent_map,
        slots_map, label_vocab, bin_label_vocab)
    if not sep_dom:
        # Merge all domains into one query loader (and one support loader).
        fs_data = []
        fs_sup_data = []
        for dom in dev_sup_dom_data.keys():
            dom_data = collect_support_instances(dev_sup_dom_data[dom],
                                                 dev_qry_dom_data[dom],
                                                 int(n_shots))
            fs_data.extend(dom_data)
            if return_suploader:
                fs_sup_data.extend(dev_sup_dom_data[dom])
        dataloader = thdata.DataLoader(dataset=Dataset(fs_data),
                                       batch_size=batch_size,
                                       shuffle=False,
                                       collate_fn=collate_fn)
        if return_suploader:
            suploader = thdata.DataLoader(dataset=Dataset(fs_sup_data),
                                          batch_size=batch_size,
                                          shuffle=True,
                                          collate_fn=collate_fn)
            return dataloader, suploader
        return dataloader
    else:
        # One query loader (and optional support loader) per domain.
        dataloaders = {}
        suploaders = {}
        for dom in dev_sup_dom_data.keys():
            dom_data = collect_support_instances(dev_sup_dom_data[dom],
                                                 dev_qry_dom_data[dom],
                                                 int(n_shots))
            dataloaders[dom] = thdata.DataLoader(dataset=Dataset(dom_data),
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 collate_fn=collate_fn)
            if return_suploader:
                suploaders[dom] = thdata.DataLoader(
                    dataset=Dataset(dev_sup_dom_data[dom]),
                    batch_size=batch_size,
                    shuffle=True,
                    collate_fn=collate_fn)
        if return_suploader:
            return dataloaders, suploaders
        return dataloaders
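# Usage sketch for the per-domain test mode (hedged; paths are hypothetical):
# with sep_dom=True and return_suploader=True, both return values are dicts
# keyed by domain name, as built above.
def _example_fs_test(tokenizer):
    qry_loaders, sup_loaders = get_dataloader_for_fs_test(
        data_path="data/", raw_data_path="raw/", batch_size=32, n_shots=5,
        tokenizer=tokenizer, sep_dom=True, return_suploader=True)
    for dom in qry_loaders:
        print(dom, len(qry_loaders[dom]), len(sup_loaders[dom]))
    return qry_loaders, sup_loaders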
def __init__(self, args, tokenizer=None):
    super().__init__()
    self.args = args
    self.domain_map = Vocab.from_file(
        os.path.join(args.data_path, "domains.txt"))
    self.intent_map = Vocab.from_file(
        os.path.join(args.data_path, "intents.txt"))
    self.slots_map = Vocab.from_file(
        os.path.join(args.data_path, "slots.txt"))
    self.label_vocab = Vocab.from_file(
        os.path.join(args.data_path, "label_vocab.txt"))
    self.bin_label_vocab = Vocab.from_file(
        os.path.join(args.data_path, "bin_label_vocab.txt"))
    with open(os.path.join(args.data_path, "dom2intents.json"), 'r',
              encoding='utf8') as fd:
        self.dom2intents = json.load(fd)
    with open(os.path.join(args.data_path, "dom2slots.json"), 'r',
              encoding='utf8') as fd:
        self.dom2slots = json.load(fd)
    # Per-domain intent mask: 1 marks intents that do NOT belong to the domain.
    dom_int_mask = {}
    for i_dom, dom in enumerate(self.domain_map._vocab):
        dom_int_mask[i_dom] = torch.ByteTensor([
            0 if self.intent_map.index2word[i] in self.dom2intents[dom] else 1
            for i in range(self.intent_map.n_words)
        ]).cuda()
    self.dom_int_mask = dom_int_mask
    # Per-domain label mask: 'O' plus the domain's B-/I- slot labels are kept.
    dom_label_mask = {}
    for i_dom, dom in enumerate(self.domain_map._vocab):
        cand_labels = [self.label_vocab.word2index['O']]
        for sl in self.dom2slots[dom]:
            cand_labels.extend([
                self.label_vocab.word2index['B-' + sl],
                self.label_vocab.word2index['I-' + sl]
            ])
        dom_label_mask[i_dom] = torch.LongTensor([
            0 if i in cand_labels else 1
            for i in range(self.label_vocab.n_words)
        ]).byte().cuda()
    self.dom_label_mask = dom_label_mask
    self.tokenizer = tokenizer
    self.bert_enc = BertModel.from_pretrained(
        pretrained_model_path=os.path.join(args.bert_dir,
                                           "pytorch_model.bin"),
        config_path=os.path.join(args.bert_dir, "bert_config.json"))
    # Classification heads over the shared BERT encoding.
    self.domain_outputs = nn.Linear(self.bert_enc.config.hidden_size,
                                    self.domain_map.n_words)
    self.intent_outputs = nn.Linear(self.bert_enc.config.hidden_size,
                                    self.intent_map.n_words)
    self.sltype_outputs = nn.Linear(self.bert_enc.config.hidden_size,
                                    self.slots_map.n_words + 1)
    self.bio_outputs = nn.Linear(self.bert_enc.config.hidden_size, 3)
    # Map each BIO label to its slot-type id (0 is reserved for 'O').
    self.sltype_map = torch.LongTensor([
        self.slots_map.word2index[lbl[2:]] + 1 if lbl != 'O' else 0
        for lbl in self.label_vocab._vocab
    ]).cuda()
    # Map each label to its BIO tag id.
    m = {'O': 0, 'B': 1, 'I': 2}
    self.bio_map = torch.LongTensor(
        [m[lbl[0]] for lbl in self.label_vocab._vocab]).cuda()
    self.dropout = nn.Dropout(p=0.1)
    self.loss_fct = nn.CrossEntropyLoss(reduction='none')
    self.crf_layer = CRF(self.label_vocab)
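# The per-domain masks built above hold 1 for out-of-domain intents/labels and
# 0 otherwise. A typical way to apply such a mask (a hedged sketch; this may
# not be exactly how this model's forward pass uses it) is to suppress
# out-of-domain logits before the softmax/CRF:
#
#   intent_logits = intent_logits.masked_fill(
#       self.dom_int_mask[i_dom].bool(), float('-inf'))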