Example #1
File: data.py Project: cjliux/fsnlu
def get_dataloader_for_support(args, tokenizer, sep_dom=False):
    # Build support-set loader(s): one shuffled loader over all domains,
    # or a dict of per-domain loaders when sep_dom=True.
    data_path, fin_data_path = args.data_path, args.fin_data_path
    batch_size = args.batch_size
    if args.load_userdict:
        jieba.load_userdict(args.userdict)

    domain_map = Vocab.from_file(os.path.join(data_path, "domains.txt"))
    intent_map = Vocab.from_file(os.path.join(data_path, "intents.txt"))
    slots_map = Vocab.from_file(os.path.join(data_path, "slots.txt"))
    label_vocab = Vocab.from_file(os.path.join(data_path, "label_vocab.txt"))
    bin_label_vocab = Vocab.from_file(os.path.join(data_path, "bin_label_vocab.txt"))

    sup_dom_data = read_support_data(
        os.path.join(fin_data_path, "support"), 
        tokenizer, domain_map, intent_map, slots_map, 
        label_vocab, bin_label_vocab)

    if not sep_dom:
        sup_data = []
        for dom_data in sup_dom_data.values():
            sup_data.extend(dom_data)
        
        suploader = thdata.DataLoader(
                            dataset=Dataset(sup_data), 
                            batch_size=batch_size, shuffle=True, 
                            collate_fn=collate_fn)
        return suploader
    else:
        suploaders = {}
        for dom, dom_data in sup_dom_data.items():
            suploaders[dom] = thdata.DataLoader(
                            dataset=Dataset(dom_data), 
                            batch_size=batch_size, shuffle=True,
                            collate_fn=collate_fn)
        return suploaders
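A minimal usage sketch for this function, assuming the module's own imports and a tokenizer compatible with read_support_data; every value below is a placeholder rather than the project's actual configuration:

import argparse

args = argparse.Namespace(
    data_path="data/",            # directory holding domains.txt, intents.txt, label_vocab.txt, ...
    fin_data_path="data/final/",  # directory containing the "support" split
    batch_size=32,
    load_userdict=False,          # skip loading a jieba user dictionary
    userdict=None)
suploaders = get_dataloader_for_support(args, tokenizer, sep_dom=True)
for dom, loader in suploaders.items():
    print(dom, len(loader.dataset))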
Example #2
File: data.py Project: cjliux/fsnlu
def get_dataloader_for_train(args, tokenizer):
    # Build a single shuffled training loader from all source-domain data
    # plus the dev-split support sets.
    data_path, raw_data_path = args.data_path, args.raw_data_path
    batch_size = args.batch_size
    if args.load_userdict:
        jieba.load_userdict(args.userdict)

    domain_map = Vocab.from_file(os.path.join(data_path, "domains.txt"))
    intent_map = Vocab.from_file(os.path.join(data_path, "intents.txt"))
    slots_map = Vocab.from_file(os.path.join(data_path, "slots.txt"))
    label_vocab = Vocab.from_file(os.path.join(data_path, "label_vocab.txt"))
    bin_label_vocab = Vocab.from_file(os.path.join(data_path, "bin_label_vocab.txt"))

    # train 
    all_train_data = [] 
    
    train_dom_data = read_all_train_data(
        os.path.join(raw_data_path, "source.json"), tokenizer,
        domain_map, intent_map, slots_map, label_vocab, bin_label_vocab)

    for dom_data in train_dom_data.values():
        all_train_data.extend(dom_data)

    dev_sup_dom_data = read_support_data(
        os.path.join(raw_data_path, "dev", "support"), 
        tokenizer, domain_map, intent_map, slots_map, 
        label_vocab, bin_label_vocab)

    for dom_data in dev_sup_dom_data.values():
        all_train_data.extend(dom_data)

    dataloader = thdata.DataLoader(dataset=Dataset(all_train_data), 
        batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    return dataloader
Example #3
File: data.py Project: cjliux/fsnlu
def get_dataloader_for_fs_eval(data_path,
                               raw_data_path,
                               eval_domains: list,
                               batch_size,
                               max_sup_ratio,
                               max_sup_size,
                               n_shots,
                               tokenizer,
                               return_suploader=False):
    # Split each evaluation domain's training data into support and query parts,
    # then build a query dataloader (and optionally a shuffled support dataloader)
    # for few-shot evaluation.
    domain_map = Vocab.from_file(os.path.join(data_path, "domains.txt"))
    intent_map = Vocab.from_file(os.path.join(data_path, "intents.txt"))
    slots_map = Vocab.from_file(os.path.join(data_path, "slots.txt"))
    label_vocab = Vocab.from_file(os.path.join(data_path, "label_vocab.txt"))
    bin_label_vocab = Vocab.from_file(
        os.path.join(data_path, "bin_label_vocab.txt"))

    # train
    all_train_data = read_all_train_data(
        os.path.join(raw_data_path, "source.json"), tokenizer, domain_map,
        intent_map, slots_map, label_vocab, bin_label_vocab)
    data = {k: v for k, v in all_train_data.items() if k in eval_domains}

    # eval support & query
    fs_data = []
    fs_sup_data = []
    for dom, dom_data in data.items():
        sup_size = max(min(int(max_sup_ratio * len(dom_data)), max_sup_size),
                       n_shots)
        sup_data, qry_data = separate_data_to_support_and_query(
            dom_data, sup_size)
        dom_data = collect_support_instances(sup_data, qry_data, int(n_shots))
        fs_data.extend(dom_data)
        if return_suploader:
            fs_sup_data.extend(sup_data)

    dataloader = thdata.DataLoader(dataset=Dataset(fs_data),
                                   batch_size=batch_size,
                                   shuffle=False,
                                   collate_fn=collate_fn)
    if return_suploader:
        suploader = thdata.DataLoader(dataset=Dataset(fs_sup_data),
                                      batch_size=batch_size,
                                      shuffle=True,
                                      collate_fn=collate_fn)
        return dataloader, suploader
    else:
        return dataloader
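A hedged construction example; domain names, paths, and the few-shot sizes are placeholders, and tokenizer is assumed to be defined as in the other examples:

eval_loader = get_dataloader_for_fs_eval(
    data_path="data/",                  # vocabulary files (domains.txt, label_vocab.txt, ...)
    raw_data_path="raw/",               # directory containing source.json
    eval_domains=["music", "weather"],  # placeholder domain names
    batch_size=32,
    max_sup_ratio=0.2,                  # cap support at 20% of a domain's data ...
    max_sup_size=100,                   # ... and at 100 instances, but never below n_shots
    n_shots=5,
    tokenizer=tokenizer)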
Example #4
def main():
    # assumes a module-level `args` namespace (parsed elsewhere)
    data_path = args.data
    vocab_path = args.vocab
    max_length = args.max_length
    out_path = args.out

    token_vocab = Vocab.from_file(path=vocab_path, add_pad=True, add_unk=True)
    trans_vocab = TransVocab()
    label_vocab = LabelVocab()
    data_reader = SNLIReader(data_path=data_path,
                             token_vocab=token_vocab,
                             trans_vocab=trans_vocab,
                             label_vocab=label_vocab,
                             max_length=max_length)
    with open(out_path, 'wb') as f:
        pickle.dump(data_reader, f)
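Because the reader is serialized with pickle, it can be restored the same way later (a sketch; SNLIReader and its vocabularies must be importable when unpickling, and the file name is a placeholder for the output path used above):

import pickle

with open("snli_reader.pkl", "rb") as f:
    data_reader = pickle.load(f)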
Example #5
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', required=True)
    parser.add_argument('--vocab', required=True)
    parser.add_argument('--vocab-size', required=True, type=int)
    parser.add_argument('--max-length', default=None, type=int)
    parser.add_argument('--binary', default=False, action='store_true')
    parser.add_argument('--out', required=True)
    args = parser.parse_args()

    word_vocab = Vocab.from_file(path=args.vocab, add_pad=True, add_unk=True,
                                 max_size=args.vocab_size)
    data_reader = SSTDataset(
        data_path=args.data, word_vocab=word_vocab, max_length=args.max_length,
        binary=args.binary)
    with open(args.out, 'wb') as f:
        pickle.dump(data_reader, f)
Example #6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', required=True)
    parser.add_argument('--vocab', required=True)
    parser.add_argument('--vocab-size', required=True, type=int)
    parser.add_argument('--max-length', required=True, type=int)
    parser.add_argument('--out', required=True)
    args = parser.parse_args()

    word_vocab = Vocab.from_file(path=args.vocab, add_pad=True, add_unk=True,
                                 max_size=args.vocab_size)
    label_dict = {'neutral': 0, 'entailment': 1, 'contradiction': 2}
    label_vocab = Vocab(vocab_dict=label_dict, add_pad=False, add_unk=False)
    data_reader = SNLIDataset(
        data_path=args.data, word_vocab=word_vocab, label_vocab=label_vocab,
        max_length=args.max_length)
    with open(args.out, 'wb') as f:
        pickle.dump(data_reader, f)
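Judging from Example #8, Vocab exposes word2index, index2word, and n_words; a hedged illustration of the label vocabulary constructed here (the index check assumes word2index mirrors vocab_dict, which seems plausible since padding and unknown tokens are disabled):

label_dict = {'neutral': 0, 'entailment': 1, 'contradiction': 2}
label_vocab = Vocab(vocab_dict=label_dict, add_pad=False, add_unk=False)
print(label_vocab.word2index['entailment'])   # expected: 1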
Example #7
File: data.py Project: cjliux/fsnlu
def get_dataloader_for_fs_test(data_path,
                               raw_data_path,
                               batch_size,
                               n_shots,
                               tokenizer,
                               sep_dom=False,
                               return_suploader=False):
    # Build few-shot test loaders from the dev support/query split,
    # pooled across domains or per domain when sep_dom=True.
    domain_map = Vocab.from_file(os.path.join(data_path, "domains.txt"))
    intent_map = Vocab.from_file(os.path.join(data_path, "intents.txt"))
    slots_map = Vocab.from_file(os.path.join(data_path, "slots.txt"))
    label_vocab = Vocab.from_file(os.path.join(data_path, "label_vocab.txt"))
    bin_label_vocab = Vocab.from_file(
        os.path.join(data_path, "bin_label_vocab.txt"))

    ## dev support & query
    dev_sup_dom_data, dev_qry_dom_data = read_dev_support_and_query_data(
        os.path.join(raw_data_path, "dev"), tokenizer, domain_map, intent_map,
        slots_map, label_vocab, bin_label_vocab)

    if not sep_dom:
        fs_data = []
        fs_sup_data = []
        for dom in dev_sup_dom_data.keys():
            dom_data = collect_support_instances(dev_sup_dom_data[dom],
                                                 dev_qry_dom_data[dom],
                                                 int(n_shots))
            fs_data.extend(dom_data)
            if return_suploader:
                fs_sup_data.extend(dev_sup_dom_data[dom])

        dataloader = thdata.DataLoader(dataset=Dataset(fs_data),
                                       batch_size=batch_size,
                                       shuffle=False,
                                       collate_fn=collate_fn)
        if return_suploader:
            suploader = thdata.DataLoader(dataset=Dataset(fs_sup_data),
                                          batch_size=batch_size,
                                          shuffle=True,
                                          collate_fn=collate_fn)
            return dataloader, suploader
        else:
            return dataloader
    else:
        dataloaders = {}
        suploaders = {}
        for dom in dev_sup_dom_data.keys():
            dom_data = collect_support_instances(dev_sup_dom_data[dom],
                                                 dev_qry_dom_data[dom],
                                                 int(n_shots))

            dataloaders[dom] = thdata.DataLoader(dataset=Dataset(dom_data),
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 collate_fn=collate_fn)
            if return_suploader:
                suploaders[dom] = thdata.DataLoader(dataset=Dataset(
                    dev_sup_dom_data[dom]),
                                                    batch_size=batch_size,
                                                    shuffle=True,
                                                    collate_fn=collate_fn)
        if return_suploader:
            return dataloaders, suploaders
        else:
            return dataloaders
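A hedged sketch of consuming the per-domain loaders returned when sep_dom=True; paths are placeholders and tokenizer is assumed to exist:

loaders, suploaders = get_dataloader_for_fs_test(
    data_path="data/", raw_data_path="raw/",
    batch_size=16, n_shots=5, tokenizer=tokenizer,
    sep_dom=True, return_suploader=True)
for dom, loader in loaders.items():
    n_query_batches = len(loader)             # query batches in deterministic order (shuffle=False)
    n_support_items = len(suploaders[dom].dataset)   # size of that domain's support set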
Example #8
    def __init__(self, args, tokenizer=None):
        # Load vocabularies, build per-domain intent/label masks, and set up a
        # BERT encoder with domain, intent, and slot-tagging heads plus a CRF layer.
        super().__init__()
        self.args = args

        self.domain_map = Vocab.from_file(
            os.path.join(args.data_path, "domains.txt"))
        self.intent_map = Vocab.from_file(
            os.path.join(args.data_path, "intents.txt"))
        self.slots_map = Vocab.from_file(
            os.path.join(args.data_path, "slots.txt"))
        self.label_vocab = Vocab.from_file(
            os.path.join(args.data_path, "label_vocab.txt"))
        self.bin_label_vocab = Vocab.from_file(
            os.path.join(args.data_path, "bin_label_vocab.txt"))
        with open(os.path.join(args.data_path, "dom2intents.json"),
                  'r',
                  encoding='utf8') as fd:
            self.dom2intents = json.load(fd)
        with open(os.path.join(args.data_path, "dom2slots.json"),
                  'r',
                  encoding='utf8') as fd:
            self.dom2slots = json.load(fd)

        # For each domain, mark the intents that do not belong to it (1 = masked out).
        dom_int_mask = {}
        for i_dom, dom in enumerate(self.domain_map._vocab):
            dom_int_mask[i_dom] = torch.ByteTensor([
                0 if self.intent_map.index2word[i] in self.dom2intents[dom]
                else 1 for i in range(self.intent_map.n_words)
            ]).cuda()
        self.dom_int_mask = dom_int_mask

        # Similarly, mark slot labels outside the domain's slot set; only 'O' and the
        # domain's own B-/I- labels remain unmasked.
        dom_label_mask = {}
        for i_dom, dom in enumerate(self.domain_map._vocab):
            cand_labels = [self.label_vocab.word2index['O']]
            for sl in self.dom2slots[dom]:
                cand_labels.extend([
                    self.label_vocab.word2index['B-' + sl],
                    self.label_vocab.word2index['I-' + sl]
                ])
            dom_label_mask[i_dom] = torch.LongTensor([
                0 if i in cand_labels else 1
                for i in range(self.label_vocab.n_words)
            ]).byte().cuda()
        self.dom_label_mask = dom_label_mask

        self.tokenizer = tokenizer

        self.bert_enc = BertModel.from_pretrained(
            pretrained_model_path=os.path.join(args.bert_dir,
                                               "pytorch_model.bin"),
            config_path=os.path.join(args.bert_dir, "bert_config.json"))

        self.domain_outputs = nn.Linear(self.bert_enc.config.hidden_size,
                                        self.domain_map.n_words)
        self.intent_outputs = nn.Linear(self.bert_enc.config.hidden_size,
                                        self.intent_map.n_words)

        self.sltype_outputs = nn.Linear(self.bert_enc.config.hidden_size,
                                        self.slots_map.n_words + 1)
        self.bio_outputs = nn.Linear(self.bert_enc.config.hidden_size, 3)

        # Decompose each label ('O', 'B-xxx', 'I-xxx') into a slot-type index
        # (0 for 'O') and a BIO tag index.
        self.sltype_map = torch.LongTensor([
            self.slots_map.word2index[lbl[2:]] + 1 if lbl != 'O' else 0
            for lbl in self.label_vocab._vocab
        ]).cuda()
        m = {'O': 0, 'B': 1, 'I': 2}
        self.bio_map = torch.LongTensor(
            [m[lbl[0]] for lbl in self.label_vocab._vocab]).cuda()

        self.dropout = nn.Dropout(p=0.1)

        self.loss_fct = nn.CrossEntropyLoss(reduction='none')

        self.crf_layer = CRF(self.label_vocab)
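The per-domain masks built above are presumably applied to block out-of-domain classes at prediction time; a self-contained, hypothetical illustration of that kind of masking (not the project's actual forward pass):

import torch

# hypothetical: 6 intents, of which only indices {0, 2, 5} belong to the current domain
dom_int_mask = torch.tensor([0, 1, 0, 1, 1, 0], dtype=torch.bool)  # 1 = masked out
logits = torch.randn(4, 6)                        # (batch, n_intents) intent scores
masked = logits.masked_fill(dom_int_mask, float('-inf'))
intent_pred = masked.argmax(dim=-1)               # predictions restricted to in-domain intents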