コード例 #1
0
def process_squad_file(data, word_counter, char_counter):
    """Turn a SQuAD-format dataset into tokenized training/evaluation examples.

    Args:
        data: parsed SQuAD JSON. ``data["data"]`` is a list of articles, each
            with a ``"title"`` and ``"paragraphs"``; every paragraph has a
            ``"context"`` string and a list of ``"qas"`` (each with
            ``"question"``, ``"id"`` and ``"answers"`` carrying ``"text"`` and
            ``"answer_start"``).
        word_counter: token -> count mapping, updated in place. Context tokens
            are weighted by the number of questions on their paragraph;
            question tokens count once each.
        char_counter: character -> count mapping, updated in place with the
            same weighting as ``word_counter``.

    Returns:
        ``(examples, eval_examples, questions, paragraphs, question_to_paragraph)``:
          * examples: one dict per question with context/question tokens and
            per-token character lists, ``y1s``/``y2s`` (index of the first /
            last context token overlapping each answer) and a 1-based ``id``.
          * eval_examples: ``str(id)`` -> raw context, token ``spans``,
            question, answer texts, original qa id (``uuid``) and article title.
          * questions / paragraphs: quote-normalized texts in encounter order.
          * question_to_paragraph: for each question, its index in ``paragraphs``.

    Raises:
        IndexError: if an answer's character range overlaps no token span.
    """
    print("Generating examples...")
    examples = []
    eval_examples = {}
    questions = []
    paragraphs = []
    question_to_paragraph = []
    qa_count = 0
    para_index = 0

    def _normalize_quotes(text):
        # Rewrite '' and `` quote pairs as a plain double quote plus a space.
        return text.replace("''", '" ').replace("``", '" ')

    for article in tqdm(data["data"]):
        title = article["title"]
        for para in article["paragraphs"]:
            context = _normalize_quotes(para["context"])
            paragraphs.append(context)
            context_tokens = UTIL.word_tokenize(context)
            context_chars = [list(tok) for tok in context_tokens]
            # Assumed to be per-token (start, end) character offsets — only
            # [0] and [1] of each entry are used below.
            spans = convert_idx(context, context_tokens)

            n_questions = len(para["qas"])
            for tok in context_tokens:
                word_counter[tok] += n_questions
                for ch in tok:
                    char_counter[ch] += n_questions

            for qa in para["qas"]:
                qa_count += 1
                ques = _normalize_quotes(qa["question"])
                questions.append(ques)
                question_to_paragraph.append(para_index)
                ques_tokens = UTIL.word_tokenize(ques)
                ques_chars = [list(tok) for tok in ques_tokens]
                for tok in ques_tokens:
                    word_counter[tok] += 1
                    for ch in tok:
                        char_counter[ch] += 1

                y1s, y2s, answer_texts = [], [], []
                for answer in qa["answers"]:
                    answer_text = answer["text"]
                    char_start = answer['answer_start']
                    char_end = char_start + len(answer_text)
                    answer_texts.append(answer_text)
                    # Tokens whose span intersects [char_start, char_end).
                    overlapping = [
                        idx for idx, span in enumerate(spans)
                        if not (char_end <= span[0] or char_start >= span[1])
                    ]
                    first_tok, last_tok = overlapping[0], overlapping[-1]
                    y1s.append(first_tok)
                    y2s.append(last_tok)

                examples.append({
                    "context_tokens": context_tokens,
                    "context_chars": context_chars,
                    "ques_tokens": ques_tokens,
                    "ques_chars": ques_chars,
                    "y1s": y1s,
                    "y2s": y2s,
                    "id": qa_count,
                })
                eval_examples[str(qa_count)] = {
                    "context": context,
                    "spans": spans,
                    'ques': ques,
                    "answers": answer_texts,
                    "uuid": qa["id"],
                    'title': title,
                }
            para_index += 1
    print("{} questions in total".format(len(examples)))
    return examples, eval_examples, questions, paragraphs, question_to_paragraph
コード例 #2
0
    def train(self, real_utts, fake_utts, lfs):
        """Train ``self.implementation`` on utterance -> logical-form pairs.

        Each real and fake utterance is paired with the logical form at the
        same position. The three sequences are zipped, so the shortest bounds
        the data. Every distinct utterance string is used once; if it recurs,
        the logical form from its first occurrence is kept.

        Optimizes ``self.loss`` with Adam (lr 3e-4) for ``FLAGS.train_iters``
        steps on random batches of ``FLAGS.batch_size`` examples. The learning
        rate decays by 0.1 every ``FLAGS.train_iters // 2`` steps (at least 1).
        The mean loss of every 10 steps is printed to stderr.
        """
        utts_raw = []
        utts_indexed = []
        lfs_indexed = []
        # A set gives O(1) de-duplication; scanning utts_raw would be quadratic.
        seen_utts = set()
        for real, fake, lf in zip(real_utts, fake_utts, lfs):
            real_indexed = self.vocab.encode(util.word_tokenize(real))
            fake_indexed = self.vocab.encode(util.word_tokenize(fake))
            lf_indexed = self.vocab.encode(util.lf_tokenize(lf))
            for utt, utt_indexed in ((real, real_indexed), (fake, fake_indexed)):
                if utt in seen_utts:
                    continue
                seen_utts.add(utt)
                utts_raw.append(utt)
                utts_indexed.append(utt_indexed)
                lfs_indexed.append(lf_indexed)

        opt = optim.Adam(self.parameters(), lr=0.0003)
        # step_size must be positive: StepLR's schedule takes the step count
        # modulo step_size, so train_iters < 2 would divide by zero.
        opt_sched = optim.lr_scheduler.StepLR(
            opt, step_size=max(1, FLAGS.train_iters // 2), gamma=0.1)

        log_every = 10
        total_loss = 0
        self.implementation.train()
        for step in range(FLAGS.train_iters):
            indices = np.random.randint(len(utts_indexed), size=FLAGS.batch_size)
            batch_utts_raw = [utts_raw[j] for j in indices]
            batch_utt_data = batch_seqs(
                [utts_indexed[j] for j in indices]).to(self.device)

            lf_data = batch_seqs([lfs_indexed[j] for j in indices]).to(self.device)
            # Shift by one along dim 0. Dim 0 is assumed to be the time axis of
            # batch_seqs' output (TODO confirm). The model is fed tokens [:-1]
            # and scored against tokens [1:].
            lf_ctx = lf_data[:-1, :]
            lf_tgt = lf_data[1:, :].view(-1)

            logits = self.implementation(batch_utts_raw, batch_utt_data, lf_ctx)
            logits = logits.view(-1, logits.shape[-1])
            loss = self.loss(logits, lf_tgt)

            opt.zero_grad()
            loss.backward()
            opt.step()
            opt_sched.step()
            total_loss += loss.item()

            # Report after accumulating, so every line (including the first)
            # averages exactly `log_every` losses.
            if (step + 1) % log_every == 0:
                print("{:.3f}".format(total_loss / log_every), file=sys.stderr)
                sys.stderr.flush()
                total_loss = 0
コード例 #3
0
def prepare_covariates(df,
    stopwords=None,
    vocab_size=2000,
    use_counts=False):
    """Build a bag-of-words covariate matrix from ``df['text']``.

    The vocabulary is the ``vocab_size`` most frequent lower-cased tokens
    (per ``util.word_tokenize``) that are not in ``stopwords``. Ties keep
    first-seen order.

    Args:
        df: table with a ``'text'`` column of strings (presumably a pandas
            DataFrame — confirm with callers).
        stopwords: optional collection of tokens to exclude from the vocabulary.
        vocab_size: maximum number of vocabulary entries.
        use_counts: if True, cells hold token counts; otherwise 0/1 presence.

    Returns:
        ``(X, vocab, vectorizer)``:
          * X: the dense document-term matrix (``.todense()`` of the vectorizer output).
          * vocab: tuple of vocabulary tokens, most frequent first.
          * vectorizer: the fitted ``CountVectorizer``.

    Raises:
        ValueError: if no admissible token occurs in the corpus.
    """
    # Build the exclusion set once: hash lookup instead of a per-token list scan.
    excluded = frozenset(stopwords) if stopwords is not None else frozenset()

    # Most common tokens not in the stopword list.
    counts = Counter(
        w for s in df['text'] for w in util.word_tokenize(s.lower())
        if w not in excluded)
    if not counts:
        raise ValueError(
            "prepare_covariates: no admissible tokens found in df['text']")
    vocab = tuple(w for w, _ in counts.most_common(vocab_size))

    # Vectorize inputs. lowercase=True matches the lower-casing used above.
    vectorizer = feature_extraction.text.CountVectorizer(
        lowercase=True,
        tokenizer=util.word_tokenize,
        vocabulary=vocab,
        binary=(not use_counts),
        ngram_range=(1, 1))
    corpus = list(df['text'])
    vectorizer.fit(corpus)
    X = vectorizer.transform(corpus).todense()
    return X, vocab, vectorizer
コード例 #4
0
    def represent(utt):
        """Return per-position features for utterance ``utt``.

        The feature parts are concatenated along dim 2.

        If ``FLAGS.bert_features`` is set, the part is the first and the last
        entries of the hidden states returned by ``representer`` (presumably
        embedding and final layers, each (1, seq_len, hidden) — TODO confirm).

        If ``FLAGS.lex_features`` is set, the part is a (1, seq_len, len(vocab))
        one-hot tensor. At the first sub-token of each word from
        ``util.word_tokenize``, it marks that word's ``vocab`` index when the
        word is in ``vocab``.

        Returns a detached tensor.
        """
        out = []
        utt_words = util.word_tokenize(utt)
        utt_enc = torch.tensor([tokenizer.encode(utt)]).to(_device())
        if FLAGS.bert_features:
            with torch.no_grad():
                _, _, hiddens = representer(utt_enc)
            out.append(hiddens[0])
            out.append(hiddens[-1])

        if FLAGS.lex_features:
            seq_len = utt_enc.shape[1]
            one_hot = torch.zeros(1, seq_len, len(vocab))
            # Special tokens (e.g. [CLS]/[SEP]) have no source word; skipping
            # them keeps words aligned with their sub-tokens.
            special_ids = set(getattr(tokenizer, "all_special_ids", ()))
            j = 0
            # utt_enc has a leading batch dimension of size 1, so walk dim 1
            # (sequence positions), not len(utt_enc).
            for i in range(seq_len):
                tok_id = utt_enc[0, i].item()
                if tok_id in special_ids:
                    continue
                dec = tokenizer.decode([tok_id])
                # WordPiece continuation pieces start with "##"; only a word's
                # first piece consumes the next word.
                if not dec.startswith("##"):
                    if j >= len(utt_words):
                        # Tokenizations disagree; stop rather than index past the end.
                        break
                    word = utt_words[j]
                    if word in vocab:
                        one_hot[0, i, vocab[word]] = 1
                    j += 1
            one_hot = one_hot.to(_device())
            out.append(one_hot)

        if len(out) == 1:
            return out[0].detach()
        else:
            return torch.cat(out, dim=2).detach()
コード例 #5
0
    def preprocess(self, path, draft):
        """Convert a SQuAD-style file into token-indexed JSON lines.

        Each line of ``path`` is read as a JSON record, but only the first
        record's ``'data'`` list is used. For every answer, one JSON object is
        written per line to ``path + 'l'``, with keys qid, context, question,
        answer, start_idx and end_idx. The indices are positions in
        ``word_tokenize(context)``.

        Args:
            path: input file path (UTF-8).
            draft: if truthy, keep only the first topic (quick debugging runs).
        """
        output = []
        # Whitespace characters that the tokenizer drops; skipped while
        # re-aligning tokens with character offsets in the raw context.
        stopwords = [' ', '\n', '\u3000', '\u202f', '\u2009']

        with open(path, 'r', encoding='utf-8') as fin:
            data = [json.loads(line) for line in fin]

        if draft:
            data[0]['data'] = data[0]['data'][:1]

        for topic in data[0]['data']:
            for paragraph in topic['paragraphs']:
                context = paragraph['context']
                tokens = word_tokenize(context)
                for qa in paragraph['qas']:
                    qid = qa['id']
                    question = qa['question']
                    for ans in qa['answers']:
                        answer = ans['text']
                        s_idx = ans['answer_start']
                        e_idx = s_idx + len(answer)

                        # `l` tracks the character offset in `context` just past
                        # the current token. s_idx/e_idx start as character
                        # offsets and are overwritten with token indices once
                        # the walk reaches them.
                        # NOTE(review): if the walk ends before reaching them,
                        # they stay character offsets — confirm callers accept this.
                        l = 0
                        s_found = False
                        for i, tok in enumerate(tokens):
                            while l < len(context) and context[l] in stopwords:
                                l += 1
                            # The tokenizer may emit '"' where the raw text has
                            # two single quotes; restore the two-character form
                            # so the offset advances correctly.
                            if tok[0] == '"' and context[l:l + 2] == '\'\'':
                                tok = '\'\'' + tok[1:]

                            l += len(tok)
                            if l > s_idx and not s_found:
                                s_idx = i
                                s_found = True
                            if l >= e_idx:
                                e_idx = i
                                break

                        output.append({'qid': qid,
                                       'context': context,
                                       'question': question,
                                       'answer': answer,
                                       'start_idx': s_idx,
                                       'end_idx': e_idx})

        with open('{}l'.format(path), 'w', encoding='utf-8') as fout:
            for record in output:
                json.dump(record, fout)
                fout.write('\n')
コード例 #6
0
def main(argv):
    """Build a vocabulary and cached utterance representations for a dataset.

    The vocabulary (word -> index, in first-seen order) covers:
      * the real utterances in the dataset's paraphrase training examples;
      * the canonical utterances in its ``utterances_formula.tsv``.

    Every canonical utterance is then encoded with the sentence- and
    word-level representers. The vocabulary, both representation arrays, the
    utterances and the logical forms are written to the ``FLAGS.write_*``
    paths. ``argv`` is unused (presumably an absl-style entry point).
    """
    canonical_utt_file = os.path.join(
        FLAGS.data_dir, "genovernight.out", FLAGS.dataset, "utterances_formula.tsv")
    train_file = os.path.join(
        FLAGS.data_dir, "data", "{}.paraphrases.train.examples".format(FLAGS.dataset))

    vocab = {}

    def add_to_vocab(text):
        # Give each unseen token the next free index.
        for word in util.word_tokenize(text):
            vocab.setdefault(word, len(vocab))

    with open(train_file) as f:
        # Wrap the file's s-expressions in parentheses so they parse as one list.
        for datum in sexpdata.loads("({})".format(f.read())):
            add_to_vocab(datum[1][1])
    with open(canonical_utt_file) as f:
        for line in f:
            canonical_utt, _ = line.strip().split("\t")
            add_to_vocab(canonical_utt)

    sent_representer = _sent_representer(vocab)
    word_representer = _word_representer(vocab)

    sent_reps, word_reps, utts, lfs = [], [], [], []
    with open(canonical_utt_file) as f:
        for line in tqdm(f):
            utt, lf = line.strip().split("\t")
            sent_reps.append(sent_representer(utt).squeeze(0).detach().cpu().numpy())
            word_reps.append(word_representer(utt).detach().cpu().numpy())
            utts.append(utt)
            lfs.append(lf)

    with open(FLAGS.write_vocab, "w") as f:
        json.dump(vocab, f)
    np.save(FLAGS.write_utt_reps, sent_reps)
    np.save(FLAGS.write_word_reps, _pad_cat(word_reps))
    with open(FLAGS.write_utts, "w") as f:
        json.dump(utts, f)
    with open(FLAGS.write_lfs, "w") as f:
        json.dump(lfs, f)
コード例 #7
0
def anwer_range_to_span_index(context, ranges):
    """Convert a token-index answer range into character offsets.

    NewsQA stores answer ranges as indices over the tokenized context, whereas
    SQuAD stores character offsets. The offsets returned here are positions in
    the string formed by joining ``UTIL.word_tokenize(context)`` with single
    spaces. That string is not necessarily identical to the raw ``context``.

    :param context: The story text containing the answer.
    :param ranges: ``(start, end)`` token indices of the answer, with ``end``
        exclusive (slice semantics). Must already be parsed into integers.
    :return: ``(span_start, span_end)`` character offsets, ``span_end`` exclusive.
    """
    context_tokens = UTIL.word_tokenize(context)
    span_text = ' '.join(context_tokens[ranges[0]:ranges[1]])
    preceding = context_tokens[:ranges[0]]
    # One separator space follows the preceding tokens. With no preceding
    # tokens, the answer starts at offset 0, not 1.
    span_start = len(' '.join(preceding)) + 1 if preceding else 0
    span_end = span_start + len(span_text)
    return span_start, span_end
コード例 #8
0
 def predict(self, utt, gold_lf):
     """Predict a logical form for utterance ``utt``.

     Decodes the candidates from ``self.implementation.predict`` and returns
     the first one that appears in ``self.lfs``. If none does, it returns a
     uniformly random element of ``self.lfs``. If the model produces no
     candidates, it returns None. The top candidate is echoed to stderr.
     ``gold_lf`` is unused.
     """
     self.implementation.eval()
     encoded = self.vocab.encode(util.word_tokenize(utt), unk=True)
     utt_data = batch_seqs([encoded]).to(self.device)
     preds = self.implementation.predict([utt], utt_data)
     if len(preds) == 0:
         return None
     candidates = [util.lf_detokenize(self.vocab.decode(p)) for p in preds]
     print("best guess", candidates[0], file=sys.stderr)
     for candidate in candidates:
         if candidate in self.lfs:
             return candidate
     return self.lfs[np.random.randint(len(self.lfs))]
コード例 #9
0
    def represent(utt):
        """Return an utterance-level feature vector for ``utt``.

        The parts are each L2-normalized along dim 1 and concatenated on dim 1.

        If ``FLAGS.bert_features`` is set, the parts are the dim-1 means of the
        first and of the last entries of ``representer``'s hidden states.

        If ``FLAGS.lex_features`` is set, the part is a (1, len(vocab)) 0/1
        bag-of-words over ``util.word_tokenize(utt)``.

        Returns a detached tensor.
        """
        features = []
        if FLAGS.bert_features:
            token_ids = torch.tensor([tokenizer.encode(utt)]).to(_device())
            with torch.no_grad():
                _, _, hiddens = representer(token_ids)
                pooled = [hiddens[0].mean(dim=1), hiddens[-1].mean(dim=1)]
            features.extend(F.normalize(rep, dim=1) for rep in pooled)

        if FLAGS.lex_features:
            bag = np.zeros((1, len(vocab)), dtype=np.float32)
            for word in util.word_tokenize(utt):
                if word in vocab:
                    bag[0, vocab[word]] = 1
            features.append(F.normalize(torch.tensor(bag).to(_device()), dim=1))

        combined = features[0] if len(features) == 1 else torch.cat(features, dim=1)
        return combined.detach()
コード例 #10
0
def tokenize_contexts(contexts: list, max_tokens=-1):
    """Word-tokenize each context after stripping surrounding whitespace.

    :param contexts: context strings.
    :param max_tokens: ``-1`` keeps every token. Any other value slices each
        token list as ``[0:max_tokens]``.
    :return: list of token lists, in the same order as ``contexts``.
    """
    truncate = max_tokens != -1
    tokenized = []
    for context in contexts:
        tokens = UTIL.word_tokenize(context.strip())
        tokenized.append(tokens[:max_tokens] if truncate else tokens)
    return tokenized
コード例 #11
0
 def proxy_treatment_from_review(text):
     """Return 1 if any token of the lower-cased ``text`` is in ``lex``, else 0."""
     review_tokens = set(util.word_tokenize(text.lower()))
     return 1 if review_tokens & lex else 0