Example 1
def load_data(batch_size):
    train_data = load_pickle_from(
        os.path.join(data_path, "msmarco", "train.pickle"))
    dev_data = load_pickle_from(
        os.path.join(data_path, "msmarco", "dev.pickle"))

    train_batches = get_batches_ex(train_data, batch_size, 4)
    dev_batches = get_batches_ex(dev_data, batch_size, 4)
    return train_batches, dev_batches
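
Every example on this page funnels through get_batches_ex(data, batch_size, n). Judging from the call sites, it splits a list of n-tuples into batches and transposes each batch into n per-column arrays: callers unpack a batch as x0, x1, x2, y = batch, and Example 6 concatenates batches with +, so a batch behaves like a list of columns. A minimal sketch of a compatible helper, assuming NumPy columns and no padding (the real implementation may differ):

import numpy as np

def get_batches_ex_sketch(data, batch_size, n_inputs):
    # Hypothetical stand-in for get_batches_ex: chunk `data` (a list of
    # n_inputs-tuples) by batch_size, then transpose each chunk into a
    # list of n_inputs column arrays.
    batches = []
    for st in range(0, len(data), batch_size):
        chunk = data[st:st + batch_size]
        columns = [np.array([entry[i] for entry in chunk])
                   for i in range(n_inputs)]
        batches.append(columns)
    return batches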
Example 2
def batch_iter_from_entry_iter(batch_size, entry_iter):
    batch = []
    for entry in entry_iter:
        batch.append(entry)
        if len(batch) == batch_size:
            yield get_batches_ex(batch, batch_size, 4)[0]
            batch = []

    # Flush the final partial batch, if any.
    if batch:
        yield get_batches_ex(batch, batch_size, 4)[0]
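
Because batch_iter_from_entry_iter consumes entry_iter lazily and flushes the leftover partial batch at the end, it can batch a stream that never fits in memory at once. A hedged usage sketch (the entry generator below is hypothetical):

def dummy_entries():  # hypothetical stream of 4-tuples
    for i in range(10):
        yield ([i], [1], [0], 0)

for batch in batch_iter_from_entry_iter(4, dummy_entries()):
    x0, x1, x2, y = batch  # each batch is a list of four columns
    print(len(x0))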
Example 3
def fetch_hidden_vector(hparam, vocab_size, data, model_path):
    task = transformer_nli_hidden(hparam, vocab_size, 0, False)
    sess = init_session()
    sess.run(tf.global_variables_initializer())

    # Restore only the variables under the "bert" scope; the rest keep
    # their fresh initialization.
    load_model_w_scope(sess, model_path, ["bert"])
    batches = get_batches_ex(data, hparam.batch_size, 4)

    def batch2feed_dict(batch):
        x0, x1, x2, y = batch
        feed_dict = {
            task.x_list[0]: x0,
            task.x_list[1]: x1,
            task.x_list[2]: x2,
            task.y: y,
        }
        return feed_dict

    def pred_fn():
        outputs = []
        for batch in batches:
            x0, x1, x2, y = batch
            all_layers, emb_outputs = sess.run(
                [task.all_layers, task.embedding_output],
                feed_dict=batch2feed_dict(batch))
            outputs.append((all_layers, emb_outputs, x0))

        return outputs

    return pred_fn()
Example 4
def predict_for_view(hparam, nli_setting, data_loader, data_id, model_path,
                     run_name, modeling_option, tags):
    print("predict_nli_ex")
    print("Modeling option: ", modeling_option)
    enc_payload, plain_payload = data_loader.load_plain_text(data_id)
    batches = get_batches_ex(enc_payload, hparam.batch_size, 3)

    task = transformer_nli_pooled(hparam, nli_setting.vocab_size)

    explain_predictor = ExplainPredictor(len(tags),
                                         task.model.get_sequence_output(),
                                         modeling_option)
    sess = init_session()
    sess.run(tf.global_variables_initializer())
    load_model(sess, model_path)

    out_entries = []
    for batch in batches:
        x0, x1, x2 = batch
        logits, ex_logits = sess.run(
            [task.logits, explain_predictor.get_score()],
            feed_dict={
                task.x_list[0]: x0,
                task.x_list[1]: x1,
                task.x_list[2]: x2,
            })

        for i in range(len(x0)):
            e = x0[i], logits[i], tuple_list_select(ex_logits, i)
            out_entries.append(e)

    save_to_pickle(out_entries, "save_view_{}_{}".format(run_name, data_id))
Example 5
def get_payload(input_path, nli_setting, batch_size) -> Tuple[Any, List]:
    initial_text = load_p_h_pair_text(input_path)
    voca_path = os.path.join(data_path, nli_setting.vocab_filename)
    encoder_unit = EncoderUnitPlain(nli_setting.seq_length, voca_path)
    info_entries, raw_payload = deletion_gen(initial_text,
                                             encoder_unit.encode_pair)
    batches = get_batches_ex(raw_payload, batch_size, 4)
    return batches, info_entries
Example 6
    def _get_batch(self, a_instances, def_instances, batch_size):
        assert len(a_instances) <= batch_size
        assert len(def_instances) <= self.def_per_batch
        ab_map, batch_defs = self.pack_data(def_instances, self.max_def_length,
                                            self.def_per_batch)
        ab_map = np.expand_dims(ab_map, 1)
        a_part = get_batches_ex(a_instances, batch_size, 5)[0]
        b_part = get_batches_ex(batch_defs, self.def_per_batch, 3)[0]

        if not self.use_ab_mapping_mask:
            batch = a_part + b_part + [ab_map]
        else:
            ab_mapping_mask = self._get_ab_mapping_mask(
                ab_map, batch_size, self.def_per_batch)
            batch = a_part + b_part + [ab_map] + [ab_mapping_mask]

        return batch
Example 7
    def forward_run(inputs):
        batches = get_batches_ex(inputs, self.hp.batch_size, 2)
        logit_list = []
        for batch in batches:
            x, y = batch
            logits, = self.sess.run([self.sout],
                                    feed_dict={
                                        self.input_text: x,
                                        self.dropout_prob: 1.0,
                                    })
            logit_list.append(logits)
        return np.concatenate(logit_list)
Example 8
def predict_nli_ex(hparam, nli_setting, data_loader, explain_tag, data_id,
                   model_path, run_name, modeling_option):
    print("predict_nli_ex")
    print("Modeling option: ", modeling_option)
    enc_payload, plain_payload = data_loader.load_plain_text(data_id)
    batches = get_batches_ex(enc_payload, hparam.batch_size, 3)

    predictor = NLIExPredictor(hparam, nli_setting, model_path,
                               modeling_option)
    ex_logits = predictor.predict_ex(explain_tag, batches)
    pred_list = predict_translate(ex_logits, data_loader, enc_payload,
                                  plain_payload)
    save_to_pickle(pred_list, "pred_{}_{}".format(run_name, data_id))
Example 9
    def forward_run(inputs):
        batches = get_batches_ex(inputs, hparam.batch_size, 3)
        logit_list = []
        for batch in batches:
            x0, x1, x2 = batch
            soft_out, = sess.run([sout],
                                 feed_dict={
                                     task.x_list[0]: x0,
                                     task.x_list[1]: x1,
                                     task.x_list[2]: x2,
                                 })
            logit_list.append(soft_out)
        return np.concatenate(logit_list)
Example 10
    def get_data(self, batch_size, n_batches):
        st = self.cur_idx
        # Read half a window at first; after reloading the next data
        # chunk, read a full window.
        ed = self.cur_idx + (batch_size * n_batches) // 2
        if ed > len(self.cur_data):
            self.cur_data = self.load_next_data()
            st = self.cur_idx
            ed = self.cur_idx + batch_size * n_batches
        raw_data = self.cur_data[st:ed]

        enc_data = []
        for t in raw_data:
            enc_data += self.encode(t)
        self.cur_idx = ed
        batches = get_batches_ex(enc_data, batch_size, 3)
        return batches
Example 11
    def forward_run(inputs):
        batches = get_batches_ex(inputs, self.hp.batch_size, 3)
        logit_list = []
        for batch in batches:
            x0, x1, x2 = batch
            logits, = self.sess.run([self.model.sout],
                                    feed_dict={
                                        self.model.x_list[0]: x0,
                                        self.model.x_list[1]: x1,
                                        self.model.x_list[2]: x2,
                                    })
            logit_list.append(logits)
        return np.concatenate(logit_list)
Example 12
        def forward_runs(insts):
            alt_batches = get_batches_ex(insts, hparam.batch_size, 3)
            alt_logits = []
            for batch in alt_batches:
                # First pass: get the encoded embeddings and attention mask.
                enc, att = sess.run(emb_outputs,
                                    feed_dict=feed_end_input(batch))
                # Second pass: feed them back in to get the softmax output.
                logits, = sess.run([softmax_out],
                                   feed_dict={
                                       task.encoded_embedding_in: enc,
                                       task.attention_mask_in: att,
                                   })
                alt_logits.append(logits)
            return np.concatenate(alt_logits)
Example 13
    def get_random_batch(self, batch_size):
        data = []
        for _ in range(batch_size):
            data_idx = random.randint(0, len(self.data) - 1)
            input_ids, input_mask, segment_ids, y = self.data[data_idx]
            appeared_words = self.data_info[data_idx]
            if appeared_words:
                word = pick1(appeared_words)
                d_input_ids, d_input_mask = self.dict[word.word]
                d_location_ids = word.location
            else:
                d_input_ids = [0] * self.max_def_length
                d_input_mask = [0] * self.max_def_length
                d_location_ids = [0] * self.max_d_loc

            e = input_ids, input_mask, segment_ids, d_input_ids, d_input_mask, d_location_ids, y
            data.append(e)
        return get_batches_ex(data, batch_size, 7)[0]
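
The single batch returned above carries seven columns, in the same order the tuple e is built. A caller would unpack it like this (loader is a hypothetical instance of the class above):

batch = loader.get_random_batch(16)
(input_ids, input_mask, segment_ids,
 d_input_ids, d_input_mask, d_location_ids, y) = batch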
Example 14
    def work(self, job_id):
        input_path = os.path.join(self.input_dir, str(job_id))
        save_path = os.path.join(self.out_dir, str(job_id))
        all_feature_list = [
            "input_ids", "input_mask", "segment_ids", "data_id"
        ]
        data = tfrecord_to_old_stype(input_path, all_feature_list)
        data_wo_id = [(e[0], e[1], e[2]) for e in data]
        data_ids = [e[3][0] for e in data]
        batches = get_batches_ex(data_wo_id, self.hparam.batch_size, 3)
        tprint("predict")
        ex_logits = self.predictor.predict_ex(self.tag, batches)
        tprint("done")
        assert len(ex_logits) == len(data)
        assert len(ex_logits) == len(data_ids)
        output_dict = {}
        for data_id, data_entry, scores in zip(data_ids, data, ex_logits):
            data_entry.append(scores)
            output_dict[int(data_id)] = list(data_entry)
        with open(save_path, "wb") as f:
            pickle.dump(output_dict, f)
Example 15
    def eval_tag():
        if data_loader is None:
            return
        print("Eval")
        for label_idx, tag in enumerate(tags):
            print(tag)
            enc_explain_dev, explain_dev = explain_dev_data_list[tag]
            batches = get_batches_ex(enc_explain_dev, hparam.batch_size, 3)
            try:
                ex_logit_list = []
                for batch in batches:
                    ex_logits, = sess.run([ex_score_tensor[label_idx]],
                                          feed_dict=batch2feed_dict(batch))
                    print(ex_logits.shape)
                    ex_logit_list.append(ex_logits)

                ex_logit_list = np.concatenate(ex_logit_list, axis=0)
                print(ex_logit_list.shape)
                assert len(ex_logit_list) == len(explain_dev)

                scores = eval_explain(ex_logit_list, data_loader, tag)

                for metric in scores.keys():
                    print("{}\t{}".format(metric, scores[metric]))

                p_at_1, MAP_score = scores["P@1"], scores["MAP"]
                summary = tf.Summary()
                summary.value.add(tag='{}_P@1'.format(tag),
                                  simple_value=p_at_1)
                summary.value.add(tag='{}_MAP'.format(tag),
                                  simple_value=MAP_score)
                train_writer.add_summary(summary, fetch_global_step())
                train_writer.flush()
            except ValueError as e:
                print(e)
                for ex_logits in ex_logit_list:
                    print(ex_logits.shape)
Example 16
def get_batches(file_path, nli_setting: BertNLI, batch_size):
    voca_path = os.path.join(data_path, nli_setting.vocab_filename)
    data_loader = LightDataLoader(nli_setting.seq_length, voca_path)
    data = list(data_loader.example_generator(file_path))
    return get_batches_ex(data, batch_size, 4)
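
A short usage sketch; the file path is a placeholder and the BertNLI construction is assumed to match the fields used above (vocab_filename, seq_length):

nli_setting = BertNLI()  # assumed to expose vocab_filename and seq_length
for x0, x1, x2, y in get_batches("dev_data.txt", nli_setting, 32):
    print(len(x0))  # each batch is four parallel columns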
Example 17
    def get_dev_data(self, batch_size):
        result = self.get_data(batch_size * 10)
        batches = get_batches_ex(result, batch_size, 3)
        return batches
Example 18
    def get_train_batch(self, batch_size):
        result = self.get_data(batch_size)
        batches = get_batches_ex(result, batch_size, 3)
        return batches[0]
Example 19
    def predict_both_from_insts(self, explain_tag, insts):
        batches = get_batches_ex(insts, self.batch_size, 3)
        return self.predict_both(explain_tag, batches)
Example 20
def get_nli_batches_from_data_loader(data_loader, batch_size):
    train_batches = get_batches_ex(data_loader.get_train_data(), batch_size, 4)
    dev_batches = get_batches_ex(data_loader.get_dev_data(), batch_size, 4)
    return train_batches, dev_batches
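
The returned lists are plain in-memory batch lists, so a training loop can iterate them directly. A minimal, hypothetical loop in the style of Example 3 (task, sess, and batch2feed_dict follow that example's pattern; train_op is purely an assumption):

train_batches, dev_batches = get_nli_batches_from_data_loader(data_loader, 16)
for step, batch in enumerate(train_batches):
    # batch2feed_dict maps the four columns onto the model placeholders.
    loss_val, _ = sess.run([task.loss, train_op],
                           feed_dict=batch2feed_dict(batch))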