Example #1
    def sample(self):
        print("sampling...")
        # Build (query, doc_id, doc_id) training triples from the ranked lists.
        sampled = []
        for q_id in self.q_group:
            query = self.query[q_id]
            ranked_list = self.q_group[q_id]
            if len(ranked_list) < 20:
                continue

            sample_space = []
            for span_list in debiased_sampling(ranked_list):
                for score, span in span_list:
                    sample_space.append((score, span))
            # Sample self.n_sample_ranked pairs from within the ranked list.
            # Sample self.n_sample_not_ranked pairs that pair a ranked doc with one
            # drawn from the full doc_id pool.
            print(sample_space)

            for i in range(self.n_sample_ranked):
                # sample_space holds (score, span) pairs, so unpack the score first.
                (score_1, doc_id_1), (score_2, doc_id_2) = \
                    random.sample(sample_space, 2)
                print((score_1, score_2))

                if score_1 < score_2:
                    sampled.append((query, doc_id_1, doc_id_2))
                else:
                    sampled.append((query, doc_id_2, doc_id_1))

            for i in range(self.n_sample_not_ranked):
                doc_id_1 = pick1(self.doc_ids)
                doc_id_2, _, _ = pick1(ranked_list)
                sampled.append((query, doc_id_1, doc_id_2))

        return sampled
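
Every example in this listing depends on a pick1 helper (and several on flatten) that is not shown. A minimal sketch of the assumed behavior, not the project's actual definitions:

import random

def pick1(l):
    # Assumed behavior: return one element of l chosen uniformly at random.
    return random.choice(l)

def flatten(list_of_lists):
    # Assumed behavior: concatenate a list of token lists into a single flat list.
    return [item for sub in list_of_lists for item in sub]

Under these assumptions, the pick1(pick1(docs_as_chunks)) pattern in the later examples reads as: pick a random document, then pick a random chunk within it.
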
Example #2
    def create_instances_from_documents(self, documents):
        documents = [doc for doc in documents if doc]
        max_num_tokens = self.max_seq_length - 3
        target_seq_length = max_num_tokens

        docs_as_chunks, target_inst_num = self.pool_chunks_from_docs(
            documents, target_seq_length)

        instances = []
        for _ in range(target_inst_num):
            chunk_1 = pick1(pick1(docs_as_chunks))

            m = self.rng.randint(1, len(chunk_1))
            tokens_a = flatten(chunk_1[:m])
            b_length = target_seq_length - len(tokens_a)
            if self.rng.random() < 0.5:
                # Half the time, take the second segment from a different random chunk.
                chunk_2 = pick1(pick1(docs_as_chunks))
                tokens_b = flatten(chunk_2)[:b_length]
            else:
                # Otherwise continue with the remainder of the same chunk.
                tokens_b = flatten(chunk_1[m:])[:b_length]
            truncate_seq_pair(tokens_a, tokens_b, target_seq_length, self.rng)

            tokens, segment_ids = format_tokens_pair_n_segid(
                tokens_a, tokens_b)
            instance = SegmentInstance(tokens=tokens, segment_ids=segment_ids)
            instances.append(instance)

        return instances
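
truncate_seq_pair and format_tokens_pair_n_segid are also not shown. The max_seq_length - 3 budget suggests format_tokens_pair_n_segid adds the [CLS] and two [SEP] markers with segment ids. If truncate_seq_pair follows the usual BERT pretraining-data convention (same name and signature), it trims the longer of the two token lists, randomly from the front or the back, until the pair fits the budget; a sketch under that assumption:

def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
    # Trim the pair in place until the combined length fits the budget.
    while len(tokens_a) + len(tokens_b) > max_num_tokens:
        trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        assert len(trunc_tokens) >= 1
        # Drop from the front or the back at random to avoid positional bias.
        if rng.random() < 0.5:
            del trunc_tokens[0]
        else:
            trunc_tokens.pop()
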
Example #3
def generate_pairwise_combinations(neg_inst_list,
                                   pos_inst_list,
                                   verbose=False):
    insts = []
    if verbose:
        print("pos_insts", len(pos_inst_list))
        print("neg_insts", len(neg_inst_list))

    if not neg_inst_list or not pos_inst_list:
        return insts
    if len(pos_inst_list) > len(neg_inst_list):
        major_inst = pos_inst_list
        minor_inst = neg_inst_list
        pos_idx = 0
    else:
        major_inst = neg_inst_list
        minor_inst = pos_inst_list
        pos_idx = 1
    # Pair every entry of the larger list with a random entry of the smaller one,
    # always emitting (positive, negative) order.
    for entry in major_inst:
        entry2 = pick1(minor_inst)

        pos_entry = [entry, entry2][pos_idx]
        neg_entry = [entry, entry2][1 - pos_idx]
        insts.append((pos_entry, neg_entry))
    return insts
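
A quick usage sketch for generate_pairwise_combinations with hypothetical instance values, assuming pick1 draws uniformly at random:

pos_inst_list = ["pos_a", "pos_b"]
neg_inst_list = ["neg_x", "neg_y", "neg_z"]
pairs = generate_pairwise_combinations(neg_inst_list, pos_inst_list)
# One pair per entry of the larger (here: negative) list, always ordered
# (positive, negative), e.g. [("pos_a", "neg_x"), ("pos_b", "neg_y"), ("pos_a", "neg_z")].
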
Example #4
    def create_instances_from_sent_list(self, sent_list):
        max_num_tokens = self.max_seq_length - 3
        target_seq_length = max_num_tokens
        print("pooling chunks")
        chunks: List[List[Token]] = pool_tokens(self.rng, sent_list,
                                                target_seq_length)

        target_inst_num = len(chunks)
        instances = []
        for _ in range(target_inst_num):
            chunk_1: List[Token] = pick1(chunks)
            if len(chunk_1) < 3:
                continue

            m = self.rng.randint(1, len(chunk_1))
            tokens_a = chunk_1[:m]
            b_length = target_seq_length - len(tokens_a)
            tokens_b = chunk_1[m:][:b_length]
            truncate_seq_pair(tokens_a, tokens_b, target_seq_length, self.rng)

            tokens, segment_ids = format_tokens_pair_n_segid(
                tokens_a, tokens_b)
            instance = SegmentInstance(tokens=tokens, segment_ids=segment_ids)
            instances.append(instance)

        return instances
Example #5
    def generate(self, data_id_manager, qids):
        missing_cnt = 0
        success_docs = 0
        missing_doc_qid = []
        counter = Counter()
        cut_list = [50, 100, 200, 300, 500, 1000, 999999999]
        for qid in qids:
            if qid not in self.resource.get_doc_for_query_d():
                assert not self.resource.query_in_qrel(qid)
                continue

            tokens_d = self.resource.get_doc_tokens_d(qid)
            q_tokens = self.resource.get_q_tokens(qid)

            pos_doc_id_list = []
            neg_doc_id_list = []
            for doc_id in self.resource.get_doc_for_query_d()[qid]:
                label = self.resource.get_label(qid, doc_id)
                if label:
                    pos_doc_id_list.append(doc_id)
                else:
                    neg_doc_id_list.append(doc_id)

            try:
                for pos_doc_id in pos_doc_id_list:
                    sampled_neg_doc_id = pick1(neg_doc_id_list)
                    pos_doc_tokens = tokens_d[pos_doc_id]
                    for cut in cut_list:
                        if len(pos_doc_tokens) < cut:
                            counter["pos_under_{}".format(cut)] += 1
                        else:
                            counter["pos_over_{}".format(cut)] += 1
                        neg_doc_tokens = tokens_d[sampled_neg_doc_id]
                        if len(neg_doc_tokens) < cut:
                            counter["neg_under_{}".format(cut)] += 1
                        else:
                            counter["neg_over_{}".format(cut)] += 1

                    # Placeholder instance with empty fields; only the length counters above carry information.
                    inst = PairedInstance([], [], [], [], 0)
                    yield inst
                    success_docs += 1
            except KeyError:
                missing_cnt += 1
                missing_doc_qid.append(qid)
                if missing_cnt > 10 * 40:
                    print(missing_doc_qid)
                    raise

        for cut in cut_list:
            n_pos_short = counter["pos_under_{}".format(cut)]
            n_short = n_pos_short + counter["neg_under_{}".format(cut)]
            if n_short > 0:
                p_pos_if_short = n_pos_short / n_short
                print("P(Pos|Len<{})={}".format(cut, p_pos_if_short))
            else:
                print("P(Pos|Len<{})={}".format(cut, "div 0"))


        print_dict_tab(counter)
Example #6
def get_target_indices_random_over_09(passage_scores) -> List[int]:
    candidate = []
    for idx, s in enumerate(passage_scores):
        if s > 0.9:
            candidate.append(idx)

    if random.random() < 0.1:
        output = [0]
        if candidate:
            output.append(pick1(candidate))
        return output
    else:
        return [0]
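
A usage sketch with hypothetical passage scores; here indices 1 and 3 are the only candidates above the 0.9 threshold:

passage_scores = [0.50, 0.95, 0.20, 0.93]
# Always contains index 0; roughly 10% of the time it also contains one
# randomly picked index whose score exceeds 0.9.
indices = get_target_indices_random_over_09(passage_scores)
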
Example #7
    def create_instances(self, topic, raw_docs, labeled_data):
        # Format: [CLS] [Abortion] [LABEL_FAVOR] ...(ukp text)...[SEP] [ABORTION] [LABEL_UNK] ..(clue text).. [SEP]
        topic_tokens = self.tokenizer.tokenize(topic.replace("_", " "))
        # TODO iterate docs, pool chunk
        # randomly draw and sometimes insert labeled one
        # encode and add to instances
        max_num_tokens = self.max_seq_length - 3 - 2 - 2 * len(topic_tokens)
        target_seq_length = max_num_tokens
        docs_as_chunks, target_inst_num = self.pool_chunks_from_docs(
            raw_docs, target_seq_length)

        instances = []
        for _ in range(target_inst_num):
            chunk_1 = pick1(pick1(docs_as_chunks))

            m = self.rng.randint(1, len(chunk_1))
            tokens_a = flatten(chunk_1[:m])
            b_length = target_seq_length - len(tokens_a)
            if self.rng.random() < self.ratio_labeled and labeled_data:
                label, tokens_b = pick1(labeled_data)
            else:
                if self.rng.random() < 0.5:
                    chunk_2 = pick1(pick1(docs_as_chunks))
                    tokens_b = flatten(chunk_2)[:b_length]
                else:
                    tokens_b = flatten(chunk_1[m:])[:b_length]
                label = -1
            truncate_seq_pair(tokens_a, tokens_b, target_seq_length, self.rng)

            swap = self.rng.random() < 0.5

            tokens, segment_ids = encode_label_and_token_pair(
                topic_tokens, label, tokens_b, tokens_a, swap)
            instance = SegmentInstance(tokens=tokens, segment_ids=segment_ids)
            instances.append(instance)

        return instances
Example #8
    def generate(self, data_id_manager, qids):
        missing_cnt = 0
        success_docs = 0
        missing_doc_qid = []
        ticker = TimeEstimator(len(qids))
        for qid in qids:
            if qid not in self.resource.get_doc_for_query_d():
                continue
            ticker.tick()
            docs: List[MSMarcoDoc] = load_per_query_docs(qid, None)
            docs_d = {d.doc_id: d for d in docs}

            q_tokens = self.resource.get_q_tokens(qid)
            pos_doc_id_list, neg_doc_id_list \
                = get_pos_neg_doc_ids_for_qid(self.resource, qid)

            def iter_passages(doc_id):
                doc = docs_d[doc_id]
                insts: List[Tuple[List, List]] = self.encoder.encode(q_tokens, doc.title, doc.body)

                for passage_idx, passage in enumerate(insts):
                    yield passage

            for pos_doc_id in pos_doc_id_list:
                sampled_neg_doc_id = pick1(neg_doc_id_list)
                try:
                    for passage_idx1, passage1 in enumerate(iter_passages(pos_doc_id)):
                        for passage_idx2, passage2 in enumerate(iter_passages(sampled_neg_doc_id)):
                            tokens_seg1, seg_ids1 = passage1
                            tokens_seg2, seg_ids2 = passage2

                            data_id = data_id_manager.assign({
                                'doc_id1': pos_doc_id,
                                'passage_idx1': passage_idx1,
                                'doc_id2': sampled_neg_doc_id,
                                'passage_idx2': passage_idx2,
                            })
                            inst = PairedInstance(tokens_seg1, seg_ids1, tokens_seg2, seg_ids2, data_id)
                            yield inst
                    success_docs += 1
                except KeyError:
                    missing_cnt += 1
                    missing_doc_qid.append(qid)
                    if missing_cnt > 10:
                        print(missing_doc_qid)
                        print("success: ", success_docs)
                        raise  # re-raise the original KeyError with its traceback
Example #9
    def get_random_batch(self, batch_size):
        data = []
        for _ in range(batch_size):
            data_idx = random.randint(0, len(self.data) - 1)
            input_ids, input_mask, segment_ids, y = self.data[data_idx]
            appeared_words = self.data_info[data_idx]
            if appeared_words:
                word = pick1(appeared_words)
                d_input_ids, d_input_mask = self.dict[word.word]
                d_location_ids = word.location
            else:
                d_input_ids = [0] * self.max_def_length
                d_input_mask = [0] * self.max_def_length
                d_location_ids = [0] * self.max_d_loc

            e = input_ids, input_mask, segment_ids, d_input_ids, d_input_mask, d_location_ids, y
            data.append(e)
        return get_batches_ex(data, batch_size, 7)[0]
Example #10
    def generate(self, data_id_manager, qids):
        missing_cnt = 0
        success_docs = 0
        missing_doc_qid = []
        for qid in qids:
            if qid not in self.resource.get_doc_for_query_d():
                assert not self.resource.query_in_qrel(qid)
                continue

            tokens_d = self.resource.get_doc_tokens_d(qid)
            q_tokens = self.resource.get_q_tokens(qid)

            pos_doc_id_list, neg_doc_id_list \
                = get_pos_neg_doc_ids_for_qid(self.resource, qid)

            def iter_passages(doc_id):
                doc_tokens = tokens_d[doc_id]
                insts: List[Tuple[List, List]] = self.encoder.encode(q_tokens, doc_tokens)

                for passage_idx, passage in enumerate(insts):
                    yield passage
            try:
                for pos_doc_id in pos_doc_id_list:
                    sampled_neg_doc_id = pick1(neg_doc_id_list)
                    for passage_idx1, passage1 in enumerate(iter_passages(pos_doc_id)):
                        for passage_idx2, passage2 in enumerate(iter_passages(sampled_neg_doc_id)):
                            tokens_seg1, seg_ids1 = passage1
                            tokens_seg2, seg_ids2 = passage2

                            data_id = data_id_manager.assign({
                                'doc_id1': pos_doc_id,
                                'passage_idx1': passage_idx1,
                                'doc_id2': sampled_neg_doc_id,
                                'passage_idx2': passage_idx2,
                            })
                            inst = PairedInstance(tokens_seg1, seg_ids1, tokens_seg2, seg_ids2, data_id)
                            yield inst
                    success_docs += 1
            except KeyError:
                missing_cnt += 1
                missing_doc_qid.append(qid)
                if missing_cnt > 10:
                    print(missing_doc_qid)
                    raise
Example #11
    def get_random_batch(self, batch_size):
        problem_data = []
        def_entries_list = []
        for data_idx in self.get_random_data_indices(batch_size):
            appeared_words = self.data_info[data_idx]
            if appeared_words:
                word = pick1(appeared_words)
                def_entries_list.append(self.dict[word.word])
                d_location_ids = word.location
            else:
                def_entries_list.append([])
                d_location_ids = [0] * self.max_d_loc

            e = self.add_location(self.data[data_idx], d_location_ids)
            problem_data.append(e)

        self.drop_definitions(def_entries_list, self.def_per_batch)
        batch = self._get_batch(problem_data, def_entries_list, batch_size)
        return batch
Example #12
        def convert(
            target_pair: Tuple[QCKQueryWToken, List[KDPWToken]],
            other_pairs: List[Tuple[QCKQueryWToken, List[KDPWToken]]]
        ) -> Iterable[Payload]:
            target_query, target_kdp_list = target_pair
            candidates = self.candidates_dict[target_query.query_id]
            candidates_w_tokens = [
                QCKCandidateWToken.from_qck_candidate(self.tokenizer, c)
                for c in candidates
            ]
            num_inst_expectation = len(target_kdp_list) * len(candidates)
            if num_inst_expectation > 1000 * 1000:
                print(target_query)
                print(len(target_kdp_list))
                print(len(candidates))

            def get_insts_per_candidate(candidate: QCKCandidateWToken,
                                        query: QCKQueryWToken,
                                        kdp_list: List[KDPWToken]) -> Payload:
                kdp_list = kdp_list[:self.k_group_size]

                kdp_token_list = []
                for p_idx, kdp in enumerate(kdp_list):
                    kdp_token_list.append(kdp.sub_tokens)

                info = {
                    'query': get_light_qckquery(query),
                    'candidate': get_light_qckcandidate(candidate),
                    'kdpl': lmap(get_light_kdp, kdp_list)
                }
                inst = Payload(kdp_list=kdp_token_list,
                               text1=query.tokens,
                               text2=candidate.tokens,
                               data_id=data_id_manager.assign(info),
                               is_correct=self._is_correct(query, candidate))
                return inst

            for c_w_token in candidates_w_tokens:
                yield get_insts_per_candidate(c_w_token, target_query,
                                              target_kdp_list)
                other_query, other_kdp_list = pick1(other_pairs)
                yield get_insts_per_candidate(c_w_token, other_query,
                                              other_kdp_list)
Example #13
    def get_all_batches(self, batch_size, f_return_indices=False):
        problem_data = []
        def_entries_list = []
        all_indices = []

        for data_idx in self.data_info.keys():
            all_indices.append(data_idx)
            appeared_words = self.data_info[data_idx]
            if appeared_words:
                word = pick1(appeared_words)
                def_entries_list.append(self.dict[word.word])
                d_location_ids = word.location
            else:
                def_entries_list.append([])
                d_location_ids = [0] * self.max_d_loc

            e = self.add_location(self.data[data_idx], d_location_ids)
            problem_data.append(e)

        n_insts = len(def_entries_list)
        assert n_insts == len(problem_data)

        batches = []
        for i in range(0, n_insts, batch_size):
            local_batch_len = min(batch_size, n_insts - i)
            if local_batch_len < batch_size:
                break
            current_problems = problem_data[i:i + local_batch_len]
            current_entries = def_entries_list[i:i + local_batch_len]
            self.drop_definitions(current_entries, self.def_per_batch)
            batch = self._get_batch(current_problems, current_entries,
                                    batch_size)
            if not f_return_indices:
                batches.append(batch)
            else:
                batches.append((all_indices[i:i + local_batch_len], batch))

        return batches
Example #14
    def get_dev_batch_from(self, task_idx):
        batches = self.dev_batch_list[task_idx]
        batch = pick1(batches)
        return task_idx, batch
Example #15
    def sample_batch(self, source_batch_list):
        task_idx = self.sample_task()
        batches = source_batch_list[task_idx]
        batch = pick1(batches)
        return task_idx, batch
Example #16
    def sample_leaf_node(self):
        return pick1(self.get_leaf_nodes())
Example #17
    def sample_any_node(self):
        return pick1(self.get_all_nodes())
Example #18
def generate_training_data(data_id):
    with open(os.path.join(output_path, "lookup_n", data_id), "r") as f:
        num_samples_list = f.readlines()
    p = os.path.join(output_path, "example_loss{}.pickle".format(data_id))
    with open(p, "rb") as f:
        loss_outputs_list = pickle.load(f)
    loss_outputs = []
    for e in loss_outputs_list:
        loss_outputs.extend(e["masked_lm_example_loss"])
    print("Total of {} loss outputs".format(len(loss_outputs)))
    feature_itr = load_record_v1(
        os.path.join(output_path, "lookup_example", data_id))

    instance_idx = 0
    writer = tf.python_io.TFRecordWriter(
        os.path.join(output_path, "lookup_train", data_id))

    n = len(num_samples_list)
    for i in range(n):
        f_feed_dictionary = random.random() < 0.5
        n_sample = int(num_samples_list[i])
        assert n_sample > 0
        first_inst = feature_itr.__next__()
        max_seq_len = len(take(first_inst["input_ids"]))

        if instance_idx + n_sample >= len(loss_outputs):
            break

        if n_sample == 1:
            # Nothing to compare; advance instance_idx so loss_outputs stays aligned with feature_itr.
            instance_idx += 1
            continue

        no_dict_loss = loss_outputs[instance_idx]
        instance_idx += 1
        all_samples = []
        good_locations = []
        for j in range(1, n_sample):
            feature = feature_itr.__next__()
            d_location_ids = take(feature["d_location_ids"])
            loss = loss_outputs[instance_idx]
            # A lookup counts as helpful if it cut the loss by more than 10% vs. the no-dictionary baseline.
            if loss < no_dict_loss * 0.9:
                good_locations.extend(
                    [idx for idx in d_location_ids if idx > 0])
            all_samples.append(feature)
            instance_idx += 1

        lookup_idx = [0] * max_seq_len
        for loc in good_locations:
            lookup_idx[loc] = 1

        # Half the time use one of the dictionary-augmented samples as the base feature.
        if f_feed_dictionary:
            base_feature = pick1(all_samples)
        else:
            base_feature = first_inst

        new_features = collections.OrderedDict()
        for key in base_feature:
            new_features[key] = btd.create_int_feature(take(base_feature[key]))

        new_features["lookup_idx"] = btd.create_int_feature(lookup_idx)

        example = tf.train.Example(features=tf.train.Features(
            feature=new_features))
        writer.write(example.SerializeToString())

    writer.close()
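
generate_training_data also relies on a take helper that is not shown. Assuming the features come from parsed tf.train.Example records (as load_record_v1 suggests), a plausible sketch; this is an assumption, not the repository's definition:

def take(feature):
    # Assumed behavior: pull the raw int64 values out of a tf.train.Feature.
    return list(feature.int64_list.value)
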
Example #19
    def pick_short_sent():
        # Rejection-sample until a sentence of at most 100 tokens is drawn.
        tokens = pick1(sents)
        while len(tokens) > 100:
            tokens = pick1(sents)
        return tokens