コード例 #1
0
def encode_payload(idx):
    """Encode plain document pairs into per-chunk BERT inputs for the merger model.

    Reads ``../output/plain_pair_0{idx}.pickle``, an iterable of
    ``(q1, d1, q2, d2)`` tuples. Each document is split with ``split_text``.
    Every chunk is encoded against its query with a 200-token ``EncoderUnit``.
    The result is pickled to ``<output_path>/merger_train_{idx}.pickle``.

    Args:
        idx: Shard index substituted into the input and output file names.
            NOTE(review): the input name is ``"0{}"``-formatted, so idx >= 10
            yields e.g. ``plain_pair_010``; confirm that is intended.
    """
    # NOTE(review): pickle.load executes arbitrary code; only use on trusted files.
    with open("../output/plain_pair_0{}.pickle".format(idx), "rb") as f:
        doc_pairs = pickle.load(f)
    voca_path = os.path.join(data_path, "bert_voca.txt")
    max_sequence = 200
    encoder_unit = EncoderUnit(max_sequence, voca_path)

    def encode_run(query, doc):
        # One (input_ids, input_mask, segment_ids) triple per chunk of doc.
        run = []
        for text in split_text(doc):
            entry = encoder_unit.encode_pair(query, text)
            run.append((entry["input_ids"], entry["input_mask"],
                        entry["segment_ids"]))
        return run

    # Tuple elements evaluate left to right, so d1 is encoded before d2,
    # matching the original call order.
    result = [(encode_run(q1, d1), encode_run(q2, d2))
              for q1, d1, q2, d2 in doc_pairs]

    filename = os.path.join(output_path, "merger_train_{}.pickle".format(idx))
    with open(filename, "wb") as f:
        pickle.dump(result, f)
コード例 #2
0
ファイル: data_sampler.py プロジェクト: clover3/Chair
 def __init__(self, port, max_sequence):
     """Remember the port, load the robust04 sampler, and build the encoder.

     Args:
         port: Port number stored on the instance.
         max_sequence: Maximum sequence length handed to EncoderUnit.
     """
     self.port = port
     tprint("Loading data sampler")
     sampler = DataSampler.init_from_pickle("robust04")
     self.data_sampler = sampler
     bert_vocab = os.path.join(data_path, "bert_voca.txt")
     self.encoder_unit = EncoderUnit(max_sequence, bert_vocab)
     self.pair_generator = sampler.pair_generator()
コード例 #3
0
ファイル: mscore.py プロジェクト: clover3/Chair
    def __init__(self, max_sequence, vocab_filename, voca_size):
        """Load mscore topics, split them, and build encoders and samplers.

        Args:
            max_sequence: Maximum sequence length handed to EncoderUnit.
            vocab_filename: Vocabulary file name, resolved against data_path.
            voca_size: Vocabulary size, stored as ``self.voca_size``.
        """
        self.train_data = None
        self.dev_data = None
        self.test_data = None

        voca_path = os.path.join(data_path, vocab_filename)
        assert os.path.exists(voca_path)
        print(voca_path)

        self.mscore = read_mscore_valid()
        self.mscore_dict = dict(self.mscore)
        self.train_topics, self.dev_topics = self.held_out(left(self.mscore))

        self.lower_case = True
        self.sep_char = "#"
        self.encoder = FullTokenizerWarpper(voca_path)
        self.voca_size = voca_size
        self.dev_explain = None
        self.encoder_unit = EncoderUnit(max_sequence, voca_path)
        self.client = TextReaderClient()

        class BiasSampler:
            """Draw topic pairs from two distinct log-scale score buckets."""

            def __init__(self, topics, score_dict):
                # Bucket key is int(log_1.1(score + 1)); topics with similar
                # scores share a bucket.
                self.sample_group = {}
                for topic in topics:
                    key = int(math.log(score_dict[topic] + 1, 1.1))
                    self.sample_group.setdefault(key, []).append(topic)
                self.sample_space = list(self.sample_group.keys())

            def sample(self):
                """Return (t1, t2) drawn from two different buckets.

                Raises:
                    ValueError: if the topics fall into fewer than two buckets.
                """
                if len(self.sample_space) < 2:
                    raise ValueError(
                        "need topics in at least 2 score groups, got {}".format(
                            len(self.sample_space)))
                g1, g2 = random.sample(self.sample_space, 2)
                return (random.choice(self.sample_group[g1]),
                        random.choice(self.sample_group[g2]))

        self.train_sampler = BiasSampler(self.train_topics, self.mscore_dict)
        self.dev_sampler = BiasSampler(self.dev_topics, self.mscore_dict)
コード例 #4
0
ファイル: data_sampler.py プロジェクト: clover3/Chair
 def __init__(self, max_sequence):
     """Prepare the robust04 sampler, its pair stream, and the BERT encoder.

     Args:
         max_sequence: Maximum sequence length handed to EncoderUnit.
     """
     tprint("Loading data sampler")
     sampler = DataSampler.init_from_pickle("robust04")
     self.data_sampler = sampler
     bert_vocab = os.path.join(data_path, "bert_voca.txt")
     self.encoder_unit = EncoderUnit(max_sequence, bert_vocab)
     self.pair_generator = sampler.pair_generator()
コード例 #5
0
ファイル: protest.py プロジェクト: clover3/Chair
    def __init__(self, max_sequence, vocab_filename, voca_size):
        """Resolve the vocabulary file and build the tokenizer and encoder.

        Args:
            max_sequence: Maximum sequence length handed to EncoderUnit.
            vocab_filename: Vocabulary file name, resolved against data_path.
            voca_size: Vocabulary size, stored as ``self.voca_size``.
        """
        self.train_data = self.dev_data = self.test_data = None

        vocab_file = os.path.join(data_path, vocab_filename)
        assert os.path.exists(vocab_file)
        print(vocab_file)

        self.lower_case, self.sep_char = True, "#"
        self.encoder = FullTokenizerWarpper(vocab_file)
        self.voca_size, self.dev_explain = voca_size, None
        self.encoder_unit = EncoderUnit(max_sequence, vocab_file)
コード例 #6
0
ファイル: rerank_encode.py プロジェクト: clover3/Chair
def encode(input_path, output_path):
    """Encode (query, sentence) pairs of a rerank payload with a 512-token encoder.

    Args:
        input_path: Pickle holding an iterable of
            ``(q_id, doc_id, [(query, sent), ...])`` tuples.
        output_path: Destination pickle. It receives a list of
            ``(doc_id, q_id, [entry, ...])`` tuples, where each entry is the
            object returned by ``encoder_unit.encode_pair``.
    """
    # NOTE(review): pickle.load executes arbitrary code; only use on trusted files.
    with open(input_path, "rb") as f:
        payload = pickle.load(f)
    voca_path = os.path.join(data_path, "bert_voca.txt")
    max_sequence = 512
    encoder_unit = EncoderUnit(max_sequence, voca_path)
    result = []
    print("Encode...")
    for q_id, doc_id, doc_query_list in payload:
        runs = [encoder_unit.encode_pair(query, sent)
                for query, sent in doc_query_list]
        result.append((doc_id, q_id, runs))
    with open(output_path, "wb") as f:
        pickle.dump(result, f)
    print("Done...")
コード例 #7
0
ファイル: data_sampler.py プロジェクト: clover3/Chair
class DataWriter:
    """Encode robust04 (query, text) pairs and pickle them to disk."""

    def __init__(self, max_sequence):
        """Load the robust04 sampler and build a BERT EncoderUnit.

        Args:
            max_sequence: Maximum sequence length handed to EncoderUnit.
        """
        tprint("Loading data sampler")
        self.data_sampler = DataSampler.init_from_pickle("robust04")
        voca_path = os.path.join(data_path, "bert_voca.txt")
        self.encoder_unit = EncoderUnit(max_sequence, voca_path)
        self.pair_generator = self.data_sampler.pair_generator()

    def encode_pair(self, instance):
        """Yield (input_ids, input_mask, segment_ids) for both cases of an instance.

        Args:
            instance: ``(query, (y1, sent1), (y2, sent2))``. The labels are not
                emitted by this writer.
        """
        query, case1, case2 = instance
        for _label, sent in (case1, case2):
            entry = self.encoder_unit.encode_pair(query, sent)
            yield entry["input_ids"], entry["input_mask"], entry["segment_ids"]

    def get_data(self, data_size):
        """Return ``data_size`` encoded entries, two per sampled instance."""
        assert data_size % 2 == 0
        result = []
        ticker = TimeEstimator(data_size, sample_size=100)
        while len(result) < data_size:
            raw_inst = next(self.pair_generator)
            result.extend(self.encode_pair(raw_inst))
            ticker.tick()
        return result

    def write(self, path, num_data):
        """Generate ``num_data`` entries and pickle them to ``path``."""
        assert num_data % 2 == 0
        # Generate first, then open: a failure during encoding leaves no
        # truncated file behind. The with-block closes the handle.
        data = self.get_data(num_data)
        with open(path, "wb") as f:
            pickle.dump(data, f)
コード例 #8
0
    def __init__(self, max_sequence, vocab_filename, voca_size, is_span):
        """Build the tokenizer/encoder and load the annotated controversy data.

        Args:
            max_sequence: Maximum sequence length handed to EncoderUnit.
            vocab_filename: Vocabulary file name, resolved against data_path.
            voca_size: Vocabulary size, stored as ``self.voca_size``.
            is_span: Truthy selects question 1 ("What is the controversy
                about?"); falsy selects question 0 (the title question).
        """
        self.train_data = None
        self.dev_data = None
        self.test_data = None

        voca_path = os.path.join(data_path, vocab_filename)

        self.lower_case = True
        self.sep_char = "#"
        self.encoder = FullTokenizerWarpper(voca_path)
        self.voca_size = voca_size
        self.dev_explain = None
        self.encoder_unit = EncoderUnit(max_sequence, voca_path)
        self.max_seq = max_sequence

        self.question = [
            "What is title of the controversy?",
            "What is the controversy about?"
        ]
        self.q_id = 1 if is_span else 0
        self.is_span = is_span
        # Question token count plus 2; presumably the 2 extra positions are
        # the special tokens around the question (e.g. [CLS]/[SEP]) — TODO confirm.
        self.text_offset = len(self.encoder.encode(
            self.question[self.q_id])) + 2

        data = load_annotation()
        self.all_data = self.generate_data(data)
        self.train_data, self.dev_data = self.held_out(self.all_data)
コード例 #9
0
def encode_payload(idx):
    """Encode each (query, document) of a plain-pair shard with a 512-token encoder.

    Reads ``../output/plain_pair_0{idx}.pickle``, an iterable of
    ``(q1, d1, q2, d2)`` tuples. Writes a flat list of
    ``(input_ids, input_mask, segment_ids)`` to
    ``<output_path>/fad_train_{idx}.pickle``: two entries per pair, q1/d1
    first and then q2/d2.

    Args:
        idx: Shard index substituted into the input and output file names.
    """
    # NOTE(review): pickle.load executes arbitrary code; only use on trusted files.
    with open("../output/plain_pair_0{}.pickle".format(idx), "rb") as f:
        doc_pairs = pickle.load(f)
    voca_path = os.path.join(data_path, "bert_voca.txt")
    max_sequence = 512
    encoder_unit = EncoderUnit(max_sequence, voca_path)

    result = []
    for q1, d1, q2, d2 in doc_pairs:
        for query, text in ((q1, d1), (q2, d2)):
            entry = encoder_unit.encode_pair(query, text)
            result.append((entry["input_ids"], entry["input_mask"],
                           entry["segment_ids"]))

    filename = os.path.join(output_path, "fad_train_{}.pickle".format(idx))
    with open(filename, "wb") as f:
        pickle.dump(result, f)
コード例 #10
0
    def __init__(self, max_sequence, vocab_filename, voca_size):
        """Start the trainable-instance stream and build the tokenizer/encoder.

        Args:
            max_sequence: Maximum sequence length handed to EncoderUnit.
            vocab_filename: Vocabulary file name, resolved against data_path.
            voca_size: Vocabulary size, stored as ``self.voca_size``.
        """
        self.train_data = self.dev_data = self.test_data = None

        per_query = 30  # passed to gen_trainable_iterator as inst_per_query
        self.generator = gen_trainable_iterator(per_query)
        self.iter = iter(self.generator)

        vocab_file = os.path.join(data_path, vocab_filename)
        assert os.path.exists(vocab_file)
        print(vocab_file)

        self.lower_case, self.sep_char = True, "#"
        self.encoder = FullTokenizerWarpper(vocab_file)
        self.voca_size, self.dev_explain = voca_size, None
        self.encoder_unit = EncoderUnit(max_sequence, vocab_file)
コード例 #11
0
 def example_generator(self, split_name):
     """Load one amsterdam split and encode it, in parallel for large splits.

     Splits with at most 200 entries are encoded directly by
     ``self.encode_entries``. Larger ones are handed to ``parallel_run``.
     """
     split_entries = amsterdam.load_data_split(split_name)
     print(split_name, "{} items".format(len(split_entries)))
     unit = EncoderUnit(self.max_sequence, self.voca_path)
     if len(split_entries) <= 200:
         return self.encode_entries(split_entries, unit)

     def encode_chunk(chunk):
         return self.encode_entries(chunk, unit)

     # NOTE(review): 10 is presumably parallel_run's worker/split count — confirm.
     return parallel_run(split_entries, encode_chunk, 10)
コード例 #12
0
    def encode_docs(self, docs):
        """Encode each document as a single segment.

        Returns ``lmap`` applied over ``docs``. Each item is an
        ``(input_ids, input_mask, segment_ids)`` tuple from
        ``EncoderUnit.encode_text_single``.
        """
        unit = EncoderUnit(self.max_sequence, self.voca_path)
        feature_keys = ("input_ids", "input_mask", "segment_ids")

        def to_features(doc_text):
            encoded = unit.encode_text_single(doc_text)
            return tuple(encoded[key] for key in feature_keys)

        return lmap(to_features, docs)
コード例 #13
0
class DataLoader:
    """Serve encoded (query, text, label) pairs drawn from gen_trainable_iterator."""

    def __init__(self, max_sequence, vocab_filename, voca_size):
        """Start the instance stream and build the tokenizer/encoder.

        Args:
            max_sequence: Maximum sequence length handed to EncoderUnit.
            vocab_filename: Vocabulary file name, resolved against data_path.
            voca_size: Vocabulary size, stored as ``self.voca_size``.
        """
        self.train_data = None
        self.dev_data = None
        self.test_data = None

        inst_per_query = 30
        self.generator = gen_trainable_iterator(inst_per_query)
        self.iter = iter(self.generator)
        voca_path = os.path.join(data_path, vocab_filename)
        assert os.path.exists(voca_path)
        print(voca_path)

        self.lower_case = True
        self.sep_char = "#"
        self.encoder = FullTokenizerWarpper(voca_path)
        self.voca_size = voca_size
        self.dev_explain = None
        self.encoder_unit = EncoderUnit(max_sequence, voca_path)

    def get_train_data(self, data_size):
        """Return at least ``data_size`` encoded entries (two per instance)."""
        assert data_size % 2 == 0
        result = []
        while len(result) < data_size:
            result.extend(self.encode_pair(next(self.iter)))
        return result

    def get_dev_data(self, num_instances=160):
        """Return encoded entries for the next ``num_instances`` instances.

        NOTE(review): this consumes the same stream as get_train_data, so the
        dev set is whatever instances come next and is not reproducible.
        """
        result = []
        for _ in range(num_instances):
            result.extend(self.encode_pair(next(self.iter)))
        return result

    def encode_pair(self, instance):
        """Yield (input_ids, input_mask, segment_ids, y) for both cases.

        Args:
            instance: ``(query, (y1, sent1), (y2, sent2))``.
        """
        query, case1, case2 = instance
        for y, sent in (case1, case2):
            entry = self.encode(query, sent)
            yield entry["input_ids"], entry["input_mask"], entry[
                "segment_ids"], y

    def encode(self, query, text):
        """Tokenize query and text and encode them as a two-segment input."""
        tokens_a = self.encoder.encode(query)
        tokens_b = self.encoder.encode(text)
        return self.encoder_unit.encode_inner(tokens_a, tokens_b)
コード例 #14
0
ファイル: robust_rerank.py プロジェクト: clover3/Chair
def encode_pred_set(top_k, max_sequence):
    """Encode the top-k ranked robust04 documents of every query for prediction.

    Each document is cut into character windows of ``max_sequence * 3``.
    Each window is encoded against the query description in a process pool.
    The function writes ``payload_B_{top_k}.pickle`` in the working directory,
    holding a list of ``(doc_id, q_id, [encoded window, ...])`` tuples in
    ranked order.

    Args:
        top_k: Number of top-ranked documents to encode per query.
        max_sequence: Maximum sequence length handed to EncoderUnit.
    """
    voca_path = os.path.join(data_path, "bert_voca.txt")
    encoder_unit = EncoderUnit(max_sequence, voca_path)
    collection = trec.load_robust(trec.robust_path)
    print("Collection has #docs :", len(collection))
    queries = load_robust04_desc()
    ranked_lists = load_2k_rank()
    # Window size in characters; presumably ~3 chars per token — TODO confirm.
    window_size = max_sequence * 3

    with ProcessPoolExecutor(max_workers=8) as executor:
        future_list = []
        for q_id in ranked_lists:
            ranked = ranked_lists[q_id]
            ranked.sort(key=lambda x: x[1])
            assert ranked[0][1] == 1
            print(q_id)
            for doc_id, _rank, _score in ranked[:top_k]:
                raw_query = queries[q_id]
                content = collection[doc_id]
                # NOTE(review): submitting a bound method pickles encoder_unit
                # with every task; consider a pool initializer if this is slow.
                f_list = [
                    executor.submit(encoder_unit.encode_pair, raw_query,
                                    content[idx:idx + window_size])
                    for idx in range(0, len(content), window_size)
                ]
                future_list.append((doc_id, q_id, f_list))

        payload = [(doc_id, q_id, [f.result() for f in f_list])
                   for doc_id, q_id, f_list in future_list]

    with open("payload_B_{}.pickle".format(top_k), "wb") as f:
        pickle.dump(payload, f)
コード例 #15
0
ファイル: data_sampler.py プロジェクト: clover3/Chair
class Server:
    """XML-RPC server that hands out encoded robust04 (query, text, label) pairs."""

    def __init__(self, port, max_sequence):
        """Load the robust04 sampler and build a BERT EncoderUnit.

        Args:
            port: Port the XML-RPC server listens on.
            max_sequence: Maximum sequence length handed to EncoderUnit.
        """
        self.port = port
        tprint("Loading data sampler")
        self.data_sampler = DataSampler.init_from_pickle("robust04")
        voca_path = os.path.join(data_path, "bert_voca.txt")
        self.encoder_unit = EncoderUnit(max_sequence, voca_path)
        self.pair_generator = self.data_sampler.pair_generator()

    def encode_pair(self, instance):
        """Yield (input_ids, input_mask, segment_ids, y) for both cases.

        Args:
            instance: ``(query, (y1, sent1), (y2, sent2))``.
        """
        query, case1, case2 = instance
        for y, sent in (case1, case2):
            entry = self.encoder_unit.encode_pair(query, sent)
            yield entry["input_ids"], entry["input_mask"], entry[
                "segment_ids"], y

    def get_data(self, data_size):
        """Return at least ``data_size`` encoded entries (two per instance)."""
        assert data_size % 2 == 0
        result = []
        while len(result) < data_size:
            result.extend(self.encode_pair(next(self.pair_generator)))
        return result

    def start(self):
        """Serve ``get_data`` over XML-RPC until the process is stopped."""
        class RequestHandler(SimpleXMLRPCRequestHandler):
            rpc_paths = ('/RPC2', )

        tprint("Preparing server")
        # NOTE(review): binds all interfaces with no authentication.
        # The with-block closes the listening socket when serve_forever exits
        # (e.g. on KeyboardInterrupt).
        with SimpleXMLRPCServer(
                ("0.0.0.0", self.port),
                requestHandler=RequestHandler,
                allow_none=True,
        ) as server:
            server.register_introspection_functions()
            server.register_function(self.get_data, 'get_data')
            tprint("Waiting")
            server.serve_forever()
コード例 #16
0
ファイル: data_sampler_g.py プロジェクト: clover3/Chair
class DataEncoder:
    """XML-RPC service that encodes (query, text1, text2) triples with a 200-token encoder."""

    # NOTE(review): these run at class-definition (import) time, so importing
    # the module builds the EncoderUnit. Kept as class attributes because
    # callers may read DataEncoder.encoder_unit directly.
    vocab_filename = "bert_voca.txt"
    voca_path = os.path.join(data_path, vocab_filename)
    max_sequence = 200
    encoder_unit = EncoderUnit(max_sequence, voca_path)

    def __init__(self, port):
        self.port = port

    def encode(self, payload):
        """Encode both texts of every triple against its query.

        Args:
            payload: Iterable of ``(query, text1, text2)`` triples.

        Returns:
            List of ``(input_ids, input_mask, segment_ids)``, two per triple
            in input order (text1, then text2).
        """
        print("Encode...")
        queries = []
        sents = []
        for query, text1, text2 in payload:
            for sent in (text1, text2):
                queries.append(query)
                sents.append(sent)
        # Each task pickles the bound method, and with it self.encoder_unit.
        # Chunking amortizes that cost; map() still yields results in input order.
        chunksize = max(1, len(queries) // 32)
        result = []
        with ProcessPoolExecutor(max_workers=8) as executor:
            for entry in executor.map(self.encoder_unit.encode_pair, queries,
                                      sents, chunksize=chunksize):
                result.append((entry["input_ids"], entry["input_mask"],
                               entry["segment_ids"]))
        print("Done...")
        return result

    def start_server(self):
        """Serve ``encode`` over XML-RPC until the process is stopped."""
        class RequestHandler(SimpleXMLRPCRequestHandler):
            rpc_paths = ('/RPC2', )

        print("Preparing server")
        # NOTE(review): binds all interfaces with no authentication.
        with SimpleXMLRPCServer(
                ("0.0.0.0", self.port),
                requestHandler=RequestHandler,
                allow_none=True,
        ) as server:
            server.register_introspection_functions()
            server.register_function(self.encode, 'encode')
            server.serve_forever()
コード例 #17
0
ファイル: protest.py プロジェクト: clover3/Chair
class DataLoader:
    def __init__(self, max_sequence, vocab_filename, voca_size):
        self.train_data = None
        self.dev_data = None
        self.test_data = None

        voca_path = os.path.join(data_path, vocab_filename)
        assert os.path.exists(voca_path)
        print(voca_path)

        self.lower_case = True
        self.sep_char = "#"
        self.encoder = FullTokenizerWarpper(voca_path)
        self.voca_size = voca_size
        self.dev_explain = None
        self.encoder_unit = EncoderUnit(max_sequence, voca_path)

    def get_train_data(self):
        if self.train_data is None:
            self.train_data = list(self.example_generator("train"))
        return self.train_data

    def get_dev_data(self):
        if self.dev_data is None:
            self.dev_data = list(self.example_generator("dev"))
        return self.dev_data

    def example_generator(self, split_name):
        X, Y = load_protest.load_data(split_name)
        for idx, x in enumerate(X):
            name, text = x
            l = Y[name]
            entry = self.encode(text)
            yield entry["input_ids"], entry["input_mask"], entry[
                "segment_ids"], l

    def encode(self, text):
        tokens_a = self.encoder.encode(text)
        return self.encoder_unit.encode_inner(tokens_a, [])
コード例 #18
0
ファイル: mscore.py プロジェクト: clover3/Chair
class DataLoader:
    """Build mscore training/dev examples from pairs of retrieved topic texts.

    Topics from ``read_mscore_valid()`` are split into train/dev sets by
    ``held_out``. Each sampled pair is ordered so that the first topic's
    mscore is strictly lower than the second's, or equal to it (see
    ``sample_pairs``).
    """

    class _BiasSampler:
        """Draw topic pairs from two distinct log-scale score buckets."""

        def __init__(self, topics, score_dict):
            # Bucket key is int(log_1.1(score + 1)); topics with similar
            # scores share a bucket.
            self.sample_group = {}
            for topic in topics:
                key = int(math.log(score_dict[topic] + 1, 1.1))
                self.sample_group.setdefault(key, []).append(topic)
            self.sample_space = list(self.sample_group.keys())

        def sample(self):
            """Return (t1, t2) drawn from two different buckets.

            Raises:
                ValueError: if the topics fall into fewer than two buckets.
            """
            if len(self.sample_space) < 2:
                raise ValueError(
                    "need topics in at least 2 score groups, got {}".format(
                        len(self.sample_space)))
            g1, g2 = random.sample(self.sample_space, 2)
            return (random.choice(self.sample_group[g1]),
                    random.choice(self.sample_group[g2]))

    def __init__(self, max_sequence, vocab_filename, voca_size):
        """Load mscore topics, split them, and build encoders and samplers.

        Args:
            max_sequence: Maximum sequence length handed to EncoderUnit.
            vocab_filename: Vocabulary file name, resolved against data_path.
            voca_size: Vocabulary size, stored as ``self.voca_size``.
        """
        self.train_data = None
        self.dev_data = None
        self.test_data = None

        voca_path = os.path.join(data_path, vocab_filename)
        assert os.path.exists(voca_path)
        print(voca_path)

        self.mscore = read_mscore_valid()
        self.mscore_dict = dict(self.mscore)
        self.train_topics, self.dev_topics = self.held_out(left(self.mscore))

        self.lower_case = True
        self.sep_char = "#"
        self.encoder = FullTokenizerWarpper(voca_path)
        self.voca_size = voca_size
        self.dev_explain = None
        self.encoder_unit = EncoderUnit(max_sequence, voca_path)
        self.client = TextReaderClient()

        self.train_sampler = self._BiasSampler(self.train_topics, self.mscore_dict)
        self.dev_sampler = self._BiasSampler(self.dev_topics, self.mscore_dict)

    def get_train_data(self, size):
        """Return ``size`` encoded texts (``size // 2`` pairs) from train topics."""
        return self.generate_data(self.train_sampler.sample, size)

    def get_dev_data(self, size):
        """Return ``size`` encoded texts (``size // 2`` pairs) from dev topics."""
        return self.generate_data(self.dev_sampler.sample, size)

    def generate_data(self, sample_fn, size):
        """Sample ``size / 2`` topic pairs and encode both retrieved texts.

        Raises:
            ValueError: if ``size`` is odd.
        """
        pair_n = int(size / 2)
        if pair_n * 2 != size:
            raise ValueError("size must be even, got {}".format(size))
        topic_pairs = self.sample_pairs(sample_fn, pair_n)
        result = []
        for t1, t2 in topic_pairs:
            inst = (self.retrieve(t1), self.retrieve(t2))
            result.extend(self.encode_pair(inst))
        return result

    def retrieve(self, topic):
        """Fetch the text for ``topic``; print the topic if nothing came back."""
        r = self.client.retrieve(topic)
        if not r:
            print(topic)
        return r

    def encode_pair(self, sent_pair):
        """Yield (input_ids, input_mask, segment_ids) for each text of the pair."""
        for sent in sent_pair:
            entry = self.encode(sent)
            yield entry["input_ids"], entry["input_mask"], entry["segment_ids"]

    def encode(self, text):
        """Tokenize ``text`` and encode it with an empty second segment."""
        tokens_a = self.encoder.encode(text)
        return self.encoder_unit.encode_inner(tokens_a, [])

    def sample_pairs(self, sample_fn, n_pairs):
        """Draw ``n_pairs`` pairs, each ordered lower-score first (ties: swapped)."""
        result = []
        for _ in range(n_pairs):
            selected = sample_fn()
            t1 = selected[0]
            t2 = selected[1]
            if self.mscore_dict[t1] < self.mscore_dict[t2]:
                result.append((t1, t2))
            else:
                result.append((t2, t1))
        return result

    def held_out(self, topics):
        """Split topics into (train_set, dev_set) with 10% held out for dev."""
        # random.sample requires a sequence (sets raise TypeError on 3.11+).
        topics = list(topics)
        heldout_size = int(len(topics) * 0.1)
        dev_topics = set(random.sample(topics, heldout_size))
        train_topics = set(topics) - dev_topics
        return train_topics, dev_topics