Ejemplo n.º 1
0
    def __init__(self, model_file, glove_embs, glove_dict, glove_ver, negative,
                 **kwargs):
        """Build the lookup tables and load a trained QANet for inference.

        Args:
            model_file: Path handed to ``QANet.load_params``.
            glove_embs: Embedding array; ``shape[0]`` is used as the
                vocabulary size and the array is passed to QANet as
                ``emb_init``.
            glove_dict: Sequence of words; a word's position is its id.
                Must contain ``not_a_word_Str``.
            glove_ver: GloVe version tag, stored as ``self.glove_ver``.
            negative: Stored as ``self.negative`` and forwarded to QANet.
            **kwargs: Extra keyword arguments forwarded to ``QANet``.

        Raises:
            KeyError: If ``not_a_word_Str`` is not in ``glove_dict``.
        """
        self.negative = negative
        self.model_file = model_file
        self.glove_ver = glove_ver

        self.words = glove_dict
        # word -> id; if a word occurs twice the last position wins.
        self.w_to_i = {word: idx for idx, word in enumerate(self.words)}

        self.not_a_word_Word = self.w_to_i[not_a_word_Str]

        self.glove_embs = glove_embs
        self.voc_size = self.glove_embs.shape[0]

        # `unichr` only exists on Python 2; on Python 3 `chr` already
        # returns text characters, so fall back to it instead of crashing.
        try:
            char_of = unichr
        except NameError:
            char_of = chr
        self.chars = [char_of(code) for code in range(128)]
        self.c_to_i = {ch: idx for idx, ch in enumerate(self.chars)}

        self.qa_net = QANet(self.voc_size,
                            emb_init=self.glove_embs,
                            skip_train_fn=True,
                            negative=self.negative,
                            **kwargs)

        self.qa_net.load_params(self.model_file)
Ejemplo n.º 2
0
print("Loading data...")

# Vocabulary (word list) and the matching embedding array.
glove_words, glove_embs = load_glove(args.glove)
voc_size = glove_embs.shape[0]

# Position of the '<not_a_word>' sentinel inside the vocabulary.
NAW_token = glove_words.index('<not_a_word>')

# Both the dev loader and the network only need to know whether negative
# data was requested, not the paths themselves.
use_negative = bool(args.negative)

# Training set: load, drop examples with empty answers, then trim.
# NOTE(review): 300 is presumably a length limit - confirm in trim_data.
train_data = trim_data(
    filter_empty_answers(
        load_squad_train(
            squad_prep_path,
            negative_paths=args.negative,
            NAW_token=NAW_token)),
    300)

dev_data = load_squad_dev(
    args.squad,
    squad_prep_path,
    NAW_token=NAW_token,
    lower_raw=True,
    make_negative=use_negative)

net = QANet(
    voc_size=voc_size,
    emb_init=glove_embs,
    dev_data=dev_data,
    predictions_path=preds_path,
    train_unk=True,
    negative=use_negative,
    init_lrate=args.learning_rate,
    checkpoint_examples=args.checkpoint,
    conv='valid')

train_QANet(net, train_data, args.output_dir, batch_size=args.batch_size)
Ejemplo n.º 3
0
class AnswerBot:
    """Answer one tokenized question against several tokenized contexts.

    Words and characters are converted to integer ids before being handed
    to ``QANet._predict_spans``; unknown words and characters map to id 0.
    """

    def __init__(self, model_file, glove_embs, glove_dict, negative, **kwargs):
        """Build the lookup tables and load a trained QANet for inference.

        Args:
            model_file: Path handed to ``QANet.load_params``.
            glove_embs: Embedding array; ``shape[0]`` is used as the
                vocabulary size and the array is passed to QANet as
                ``emb_init``.
            glove_dict: Sequence of words; a word's position is its id.
                Must contain ``not_a_word_Str``.
            negative: When truthy, a "not a word" sentinel token is appended
                to every context in ``prepare_question``.
            **kwargs: Extra keyword arguments forwarded to ``QANet``.

        Raises:
            KeyError: If ``not_a_word_Str`` is not in ``glove_dict``.
        """
        self.negative = negative
        self.model_file = model_file

        self.words = glove_dict
        # word -> id; if a word occurs twice the last position wins.
        self.w_to_i = {word: idx for idx, word in enumerate(self.words)}

        self.not_a_word_Word = self.w_to_i[not_a_word_Str]

        self.glove_embs = glove_embs
        self.voc_size = self.glove_embs.shape[0]

        # `unichr` only exists on Python 2; on Python 3 `chr` already
        # returns text characters, so fall back to it instead of crashing.
        try:
            char_of = unichr
        except NameError:
            char_of = chr
        self.chars = [char_of(code) for code in range(128)]
        self.c_to_i = {ch: idx for idx, ch in enumerate(self.chars)}

        self.qa_net = QANet(self.voc_size,
                            emb_init=self.glove_embs,
                            skip_train_fn=True,
                            train_unk=True,
                            negative=self.negative,
                            **kwargs)

        self.qa_net.load_params(self.model_file)

    def to_nums(self, ws):
        """Return the word id of every token in ``ws`` (0 if unknown)."""
        return [self.w_to_i.get(w, 0) for w in ws]

    def to_chars(self, w):
        """Return the character ids of ``w`` wrapped between ids 1 and 2.

        Characters missing from ``self.c_to_i`` map to 0. The surrounding
        1 and 2 are presumably start/end-of-word markers - confirm against
        the model's character vocabulary.
        """
        return [1] + [self.c_to_i.get(c, 0) for c in w] + [2]

    def q_to_num(self, q):
        """Convert a tokenized question to ``(word_ids, char_ids)``.

        Both results are real lists: the character ids are reused for every
        context in ``get_answers``, so a lazy ``map`` object (Python 3)
        would be exhausted after the first context.
        """
        assert isinstance(q, list)
        return self.to_nums(q), [self.to_chars(w) for w in q]

    def prepare_question(self, q, x, q_num, q_char):
        """Assemble the model inputs for one (question, context) pair.

        Args:
            q: Question tokens (list).
            x: Context tokens (list); never modified in place.
            q_num: Word ids of ``q`` as returned by ``q_to_num``.
            q_char: Character ids of ``q`` as returned by ``q_to_num``.

        Returns:
            Tuple ``(x, words, chars, bin_feats)`` where ``words`` is
            ``[[], q_num, x_num]``, ``chars`` is ``[q_char, x_char]`` and
            ``bin_feats[i]`` tells whether context token ``i`` occurs in the
            question. With ``self.negative`` set, the returned ``x`` and the
            context features carry one extra "not a word" sentinel entry.
        """
        assert isinstance(q, list)
        assert isinstance(x, list)

        x_num = self.to_nums(x)
        # Must be a list (not a lazy map): it is appended to below.
        x_char = [self.to_chars(w) for w in x]

        question_tokens = set(q)
        bin_feats = [w in question_tokens for w in x]

        words = [[], q_num, x_num]
        chars = [q_char, x_char]

        if self.negative:
            x_num.append(self.not_a_word_Word)
            x_char.append([1, not_a_word_Char, 2])
            bin_feats.append(False)
            # Build a new list so the caller's context is left untouched.
            x = x + [not_a_word_Str]

        return x, words, chars, bin_feats

    def get_answers(self, question, contexts, contexts_cased, beam=1):
        """Predict one answer span per context for a single question.

        Args:
            question: Question tokens (list).
            contexts: List of tokenized contexts used for prediction.
            contexts_cased: Same number of contexts in their original
                casing; each is either a token list or a string.
            beam: Beam width forwarded to ``QANet._predict_spans``.

        Returns:
            List of ``(answer_tokens, score)`` tuples, one per context;
            empty when ``contexts`` is empty.
        """
        if not contexts:
            return []
        num_contexts = len(contexts)
        assert len(contexts_cased) == num_contexts

        q_words, q_chars = self.q_to_num(question)

        xs = []
        data = [[], [], []]

        for context in contexts:
            x, words, chars, bin_feats = self.prepare_question(
                question, context, q_words, q_chars)
            xs.append(x)
            data[0].append(words)
            data[1].append(chars)
            data[2].append(bin_feats)

        l, r, scr = self.qa_net._predict_spans(data,
                                               beam=beam,
                                               batch_size=num_contexts)

        # One long cased string plus its normalized twin; the predicted
        # answer is searched in the latter to recover the original casing.
        all_contexts = u' '.join(
            [u' '.join(c) if isinstance(c, list) else c
             for c in contexts_cased])
        all_contexts_lower = lower_if_needed([[all_contexts]])[0][0]

        answers = []
        for i, x in enumerate(xs):
            answer = x[l[i]:r[i] + 1]
            # Try to retrieve the answer in original case
            answer_str = u' '.join(answer)
            pos = all_contexts_lower.find(answer_str)
            if pos != -1:
                answer = all_contexts[pos:pos + len(answer_str)].split(' ')
            answers.append((answer, scr[i]))

        return answers
Ejemplo n.º 4
0
def train_entry():
    """Train QANet with periodic evaluation, early stopping and saving.

    Reads the embedding matrices and evaluation files named in ``Config``,
    builds the model, optimizer and warm-up scheduler, then alternates
    training rounds of ``Config.checkpoint`` steps with evaluation. After
    every round that does not trigger early stopping the whole model is
    saved to ``<Config.save_dir>/model.pt``.

    :return: None
    """
    def read_json(path):
        # Parse one JSON file and close it again right away.
        with open(path, "r") as handle:
            return json.load(handle)

    word_mat = np.array(read_json(Config.word_emb_file), dtype=np.float32)
    char_mat = np.array(read_json(Config.char_emb_file), dtype=np.float32)
    train_eval_file = read_json(Config.train_eval_file)
    dev_eval_file = read_json(Config.dev_eval_file)

    print("Building model...")

    train_dataset = SQuADDataset(
        Config.train_record_file, Config.num_steps, Config.batch_size)
    dev_dataset = SQuADDataset(
        Config.dev_record_file, Config.test_num_batches, Config.batch_size)

    target_lr = Config.learning_rate
    warmup_steps = Config.lr_warm_up_num

    model = QANet(word_mat, char_mat).to(Config.device)

    # Exponential moving average over every trainable parameter.
    ema = EMA(Config.ema_decay)
    for param_name, param in model.named_parameters():
        if param.requires_grad:
            ema.set(param_name, param)

    # Collect the trainable parameters and build the optimizer on them.
    # The optimizer's own lr is 1.0, so the scheduler's multiplier below is
    # the effective learning rate.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(
        params=trainable,
        lr=1.0,
        betas=(Config.beta1, Config.beta2),
        eps=1e-7,
        weight_decay=3e-7)

    # Learning-rate schedule: logarithmic warm-up that reaches `target_lr`
    # after `warmup_steps` scheduler steps, constant afterwards.
    warmup_scale = target_lr / log2(warmup_steps)

    def lr_at(step):
        if step < warmup_steps:
            return warmup_scale * log2(step + 1)
        return target_lr

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_at)

    steps_per_round = Config.checkpoint
    total_steps = Config.num_steps

    # Two evaluation metrics are tracked: F1 and exact match.
    best_f1 = 0
    best_em = 0
    patience = 0
    for start_step in range(0, total_steps, steps_per_round):
        train(model, optimizer, scheduler, ema, train_dataset,
              start_step, steps_per_round)
        valid(model, train_dataset, train_eval_file)
        # Evaluate on the dev data.
        metrics = test(model, dev_dataset, dev_eval_file)
        print("Learning rate: {}".format(scheduler.get_lr()))
        dev_f1, dev_em = metrics["f1"], metrics["exact_match"]

        # Only a round where BOTH metrics fall below their best counts
        # against the patience budget.
        got_worse = dev_f1 < best_f1 and dev_em < best_em
        if not got_worse:
            patience = 0
            best_f1 = max(best_f1, dev_f1)
            best_em = max(best_em, dev_em)
        else:
            patience += 1
            if patience > Config.early_stop:
                break

        torch.save(model, os.path.join(Config.save_dir, "model.pt"))
Ejemplo n.º 5
0
class AnswerBot:
    """Answer (question, context) pairs with a trained QANet.

    Questions and contexts may be raw strings (tokenized here) or token
    lists. Words and characters are converted to integer ids before being
    handed to ``QANet._predict_spans``; unknown ones map to id 0.
    """

    def __init__(self, model_file, glove_embs, glove_dict, glove_ver, negative,
                 **kwargs):
        """Build the lookup tables and load a trained QANet for inference.

        Args:
            model_file: Path handed to ``QANet.load_params``.
            glove_embs: Embedding array; ``shape[0]`` is used as the
                vocabulary size and the array is passed to QANet as
                ``emb_init``.
            glove_dict: Sequence of words; a word's position is its id.
                Must contain ``not_a_word_Str``.
            glove_ver: GloVe version tag; ``'6B'`` makes raw-string inputs
                get lower-cased after tokenization.
            negative: When truthy, a "not a word" sentinel token is appended
                to every context that does not already end with it.
            **kwargs: Extra keyword arguments forwarded to ``QANet``.

        Raises:
            KeyError: If ``not_a_word_Str`` is not in ``glove_dict``.
        """
        self.negative = negative
        self.model_file = model_file
        self.glove_ver = glove_ver

        self.words = glove_dict
        # word -> id; if a word occurs twice the last position wins.
        self.w_to_i = {word: idx for idx, word in enumerate(self.words)}

        self.not_a_word_Word = self.w_to_i[not_a_word_Str]

        self.glove_embs = glove_embs
        self.voc_size = self.glove_embs.shape[0]

        # `unichr` only exists on Python 2; on Python 3 `chr` already
        # returns text characters, so fall back to it instead of crashing.
        try:
            char_of = unichr
        except NameError:
            char_of = chr
        self.chars = [char_of(code) for code in range(128)]
        self.c_to_i = {ch: idx for idx, ch in enumerate(self.chars)}

        self.qa_net = QANet(self.voc_size,
                            emb_init=self.glove_embs,
                            skip_train_fn=True,
                            negative=self.negative,
                            **kwargs)

        self.qa_net.load_params(self.model_file)

    def _to_nums(self, ws):
        """Return the word id of every token in ``ws`` (0 if unknown)."""
        return [self.w_to_i.get(w, 0) for w in ws]

    def _to_chars(self, w):
        """Return the character ids of ``w`` wrapped between ids 1 and 2."""
        return [1] + [self.c_to_i.get(c, 0) for c in w] + [2]

    def _lower_if_needed(self, tokens):
        """Lower-case ``tokens`` when the GloVe version is ``'6B'``."""
        if self.glove_ver == '6B':
            return [w.lower() for w in tokens]
        return tokens

    def prepare_question(self, q, x):
        """Assemble the model inputs for one (question, context) pair.

        Args:
            q: Question, either a raw string or a token list.
            x: Context, of the same type as ``q``. A token list passed in
                is never modified in place.

        Returns:
            Tuple ``(q, x, words, chars, bin_feats)`` with the tokenized
            question and context, ``words = [[], q_ids, x_ids]``,
            ``chars = [q_char_ids, x_char_ids]`` and ``bin_feats[i]``
            telling whether context token ``i`` occurs in the question.
            With ``self.negative`` set, the returned context and its
            features end with the "not a word" sentinel.
        """
        # `unicode` only exists on Python 2.
        try:
            text_types = (str, unicode)
        except NameError:
            text_types = (str,)

        assert type(q) is type(x)
        assert isinstance(q, (list,) + text_types)

        if not isinstance(q, list):
            q = self._lower_if_needed(tokenize(q))
            x = self._lower_if_needed(tokenize(x))

        words = [[], self._to_nums(q), self._to_nums(x)]
        # Real lists (not lazy map objects): the context entry may be
        # appended to below and the model needs to index into them.
        chars = [[self._to_chars(w) for w in q],
                 [self._to_chars(w) for w in x]]
        question_tokens = set(q)
        bin_feats = [w in question_tokens for w in x]

        # Add the sentinel unless the context already ends with it; the
        # emptiness check keeps an empty context from raising IndexError.
        if self.negative and not (x and x[-1] == not_a_word_Str):
            words[2].append(self.not_a_word_Word)
            chars[1].append([1, not_a_word_Char, 2])
            bin_feats.append(False)
            # Build a new list so the caller's context is left untouched.
            x = list(x) + [not_a_word_Str]

        return q, x, words, chars, bin_feats

    def get_answers(self, questions, contexts, beam=1):
        """Predict one answer span for every (question, context) pair.

        Args:
            questions: Sequence of questions (strings or token lists).
            contexts: Sequence of contexts, same length as ``questions``.
            beam: Beam width forwarded to ``QANet._predict_spans``.

        Returns:
            List of ``(answer_tokens, score)`` tuples, one per pair; empty
            when no pairs are given.
        """
        num_contexts = len(contexts)
        assert len(questions) == num_contexts
        if num_contexts == 0:
            # Nothing to predict; do not call the model with an empty batch.
            return []

        sample = []
        data = [[], [], []]

        for question, context in zip(questions, contexts):
            q, x, words, chars, bin_feats = self.prepare_question(
                question, context)
            sample.append([q, x])
            data[0].append(words)
            data[1].append(chars)
            data[2].append(bin_feats)

        l, r, scr = self.qa_net._predict_spans(data,
                                               beam=beam,
                                               batch_size=num_contexts)

        answers = []
        for i, (_, x) in enumerate(sample):
            answers.append((x[l[i]:r[i] + 1], scr[i]))

        return answers