Example #1
def label_predict(hparam, data, model_path) -> np.ndarray:
    tprint("building model")
    voca_size = 30522  # BERT-base uncased vocabulary size
    task = transformer_logit(hparam, 2, voca_size, False)
    enc_payload: List[Tuple[np.ndarray, np.ndarray, np.ndarray]] = data

    sout = tf.nn.softmax(task.logits, axis=-1)
    sess = init_session()
    sess.run(tf.global_variables_initializer())

    tprint("loading model")
    load_model(sess, model_path)

    def forward_run(inputs):
        batches = get_batches_ex(inputs, hparam.batch_size, 3)
        logit_list = []
        ticker = TimeEstimator(len(batches))
        for batch in batches:
            x0, x1, x2 = batch
            soft_out, = sess.run(
                [sout],
                feed_dict={
                    task.x_list[0]: x0,
                    task.x_list[1]: x1,
                    task.x_list[2]: x2,
                })
            logit_list.append(soft_out)
            ticker.tick()
        return np.concatenate(logit_list)

    # forward_run returns softmax probabilities rather than raw logits
    probs = forward_run(enc_payload)
    return probs
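A minimal usage sketch. `encode_triples` and the checkpoint path are hypothetical stand-ins; `label_predict` only needs a list of `(input_ids, input_mask, segment_ids)` array triples and a checkpoint path:

enc_data = encode_triples(raw_texts)  # hypothetical encoder producing (ids, mask, segment) triples
probs = label_predict(hparam, enc_data, "/path/to/checkpoint")  # path is an assumption
pred_labels = np.argmax(probs, axis=-1)  # class predictions from the softmax scores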
Example #2
def predict_for_view(hparam, nli_setting, data_loader, data_id, model_path,
                     run_name, modeling_option, tags):
    print("predict_nli_ex")
    print("Modeling option: ", modeling_option)
    enc_payload, plain_payload = data_loader.load_plain_text(data_id)
    batches = get_batches_ex(enc_payload, hparam.batch_size, 3)

    task = transformer_nli_pooled(hparam, nli_setting.vocab_size)

    explain_predictor = ExplainPredictor(len(tags),
                                         task.model.get_sequence_output(),
                                         modeling_option)
    sess = init_session()
    sess.run(tf.global_variables_initializer())
    load_model(sess, model_path)

    out_entries = []
    for batch in batches:
        x0, x1, x2 = batch
        logits, ex_logits = sess.run(
            [task.logits, explain_predictor.get_score()],
            feed_dict={
                task.x_list[0]: x0,
                task.x_list[1]: x1,
                task.x_list[2]: x2,
            })

        for i in range(len(x0)):
            e = x0[i], logits[i], tuple_list_select(ex_logits, i)
            out_entries.append(e)

    save_to_pickle(out_entries, "save_view_{}_{}".format(run_name, data_id))
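Each saved entry is an `(input_ids, logits, per_tag_explain_scores)` triple. A downstream viewer might unpack the pickle like this (a sketch; `load_from_pickle` is assumed to be the counterpart of `save_to_pickle`):

entries = load_from_pickle("save_view_{}_{}".format(run_name, data_id))  # assumed loader
for input_ids, logits, ex_scores in entries:
    pred_label = int(np.argmax(logits))  # predicted NLI class for this instance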
Example #3
def nli_attribution_predict(hparam, nli_setting, data_loader,
                            explain_tag, method_name, data_id, sub_range, model_path):
    enc_payload, plain_payload = data_loader.load_plain_text(data_id)
    if sub_range is not None:
        raise Exception("Sub_range is not supported")

    from attribution.gradient import explain_by_gradient
    from attribution.deepexplain.tensorflow import DeepExplain

    sess = init_session()

    with DeepExplain(session=sess, graph=sess.graph) as de:
        task = transformer_nli_pooled_embedding_in(hparam, nli_setting.vocab_size, False)
        softmax_out = tf.nn.softmax(task.logits, axis=-1)
        sess.run(tf.global_variables_initializer())
        load_model(sess, model_path)
        emb_outputs = task.encoded_embedding_out, task.attention_mask_out
        emb_input = task.encoded_embedding_in, task.attention_mask_in

        def feed_end_input(batch):
            x0, x1, x2 = batch
            return {task.x_list[0]:x0,
                    task.x_list[1]:x1,
                    task.x_list[2]:x2,
                    }

        explains = explain_by_gradient(enc_payload, method_name, explain_tag, sess, de,
                                       feed_end_input, emb_outputs, emb_input, softmax_out)

        pred_list = predict_translate(explains, data_loader, enc_payload, plain_payload)
        save_to_pickle(pred_list, "pred_{}_{}".format(method_name, data_id))
Example #4
def eval_fidelity_gradient(hparam, nli_setting, flat_dev_batches, explain_tag,
                           method_name, model_path):

    from attribution.gradient import explain_by_gradient
    from attribution.deepexplain.tensorflow import DeepExplain

    sess = init_session()

    with DeepExplain(session=sess, graph=sess.graph) as de:

        task = transformer_nli_pooled_embedding_in(hparam,
                                                   nli_setting.vocab_size,
                                                   False)
        softmax_out = tf.nn.softmax(task.logits, axis=-1)
        sess.run(tf.global_variables_initializer())
        load_model(sess, model_path)
        emb_outputs = task.encoded_embedding_out, task.attention_mask_out
        emb_input = task.encoded_embedding_in, task.attention_mask_in

        def feed_end_input(batch):
            x0, x1, x2 = batch
            return {
                task.x_list[0]: x0,
                task.x_list[1]: x1,
                task.x_list[2]: x2,
            }

        def forward_runs(insts):
            alt_batches = get_batches_ex(insts, hparam.batch_size, 3)
            alt_logits = []
            for batch in alt_batches:
                enc, att = sess.run(emb_outputs,
                                    feed_dict=feed_end_input(batch))
                logits, = sess.run(
                    [softmax_out],
                    feed_dict={
                        task.encoded_embedding_in: enc,
                        task.attention_mask_in: att,
                    })

                alt_logits.append(logits)
            alt_logits = np.concatenate(alt_logits)
            return alt_logits

        contrib_score = explain_by_gradient(flat_dev_batches, method_name,
                                            explain_tag, sess, de,
                                            feed_end_input, emb_outputs,
                                            emb_input, softmax_out)
        print("contrib_score", len(contrib_score))
        print("flat_dev_batches", len(flat_dev_batches))

        acc_list = eval_fidelity(contrib_score, flat_dev_batches, forward_runs,
                                 explain_tag)

        return acc_list
Example #5
def baseline_predict(hparam, nli_setting, data, method_name,
                     model_path) -> List[np.ndarray]:
    tprint("building model")
    voca_size = 30522  # BERT-base uncased vocabulary size
    task = transformer_logit(hparam, 2, voca_size, False)
    enc_payload: List[Tuple[np.ndarray, np.ndarray, np.ndarray]] = data

    sout = tf.nn.softmax(task.logits, axis=-1)
    sess = init_session()
    sess.run(tf.global_variables_initializer())

    tprint("loading model")
    load_model(sess, model_path)

    def forward_run(inputs):
        batches = get_batches_ex(inputs, hparam.batch_size, 3)
        logit_list = []
        ticker = TimeEstimator(len(batches))
        for batch in batches:
            x0, x1, x2 = batch
            soft_out, = sess.run(
                [sout],
                feed_dict={
                    task.x_list[0]: x0,
                    task.x_list[1]: x1,
                    task.x_list[2]: x2,
                })
            logit_list.append(soft_out)
            ticker.tick()
        return np.concatenate(logit_list)

    # train_batches, dev_batches = self.load_nli_data(data_loader)
    def idf_explain(enc_payload, explain_tag, forward_run):
        train_batches, dev_batches = get_nli_data(hparam, nli_setting)
        idf_scorer = IdfScorer(train_batches)
        return idf_scorer.explain(enc_payload, explain_tag, forward_run)

    todo_list = [
        ('deletion_seq', explain_by_seq_deletion),
        ('replace_token', explain_by_replace),
        ('term_deletion', explain_by_term_deletion),
        ('term_replace', explain_by_term_replace),
        ('random', explain_by_random),
        ('idf', idf_explain),
        ('deletion', explain_by_deletion),
        ('LIME', explain_by_lime),
    ]
    method_dict = dict(todo_list)
    method = method_dict[method_name]
    explain_tag = "mismatch"
    explains: List[np.ndarray] = method(enc_payload, explain_tag, forward_run)
    # pred_list = predict_translate(explains, data_loader, enc_payload, plain_payload)
    return explains
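Every function in `todo_list` shares the `(enc_payload, explain_tag, forward_run)` signature, so picking a baseline is a plain dictionary lookup. A hypothetical call (data and path are assumptions):

scores = baseline_predict(hparam, nli_setting, enc_data, "LIME",
                          "/path/to/checkpoint")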
Example #6
def nli_baseline_predict(hparam, nli_setting, data_loader, explain_tag,
                         method_name, data_id, sub_range, model_path):
    enc_payload, plain_payload = data_loader.load_plain_text(data_id)
    assert enc_payload is not None
    assert plain_payload is not None

    name_format = "pred_{}_" + data_id
    if sub_range is not None:
        st, ed = [int(t) for t in sub_range.split(",")]
        enc_payload = enc_payload[st:ed]
        plain_payload = plain_payload[st:ed]
        name_format = "pred_{}_" + data_id + "__{}_{}".format(st, ed)
        print(name_format)
    task = transformer_nli_pooled(hparam, nli_setting.vocab_size)
    sout = tf.nn.softmax(task.logits, axis=-1)
    sess = init_session()
    sess.run(tf.global_variables_initializer())
    load_model(sess, model_path)

    def forward_run(inputs):
        batches = get_batches_ex(inputs, hparam.batch_size, 3)
        logit_list = []
        for batch in batches:
            x0, x1, x2 = batch
            soft_out, = sess.run(
                [sout],
                feed_dict={
                    task.x_list[0]: x0,
                    task.x_list[1]: x1,
                    task.x_list[2]: x2,
                })
            logit_list.append(soft_out)
        return np.concatenate(logit_list)

    # train_batches, dev_batches = self.load_nli_data(data_loader)
    def idf_explain(enc_payload, explain_tag, forward_run):
        train_batches, dev_batches = get_nli_data(hparam, nli_setting)
        idf_scorer = IdfScorer(train_batches)
        return idf_scorer.explain(enc_payload, explain_tag, forward_run)

    todo_list = [
        ('deletion_seq', explain_by_seq_deletion),
        ('random', explain_by_random),
        ('idf', idf_explain),
        ('deletion', explain_by_deletion),
        ('LIME', explain_by_lime),
    ]
    method_dict = dict(todo_list)
    method = method_dict[method_name]
    explains = method(enc_payload, explain_tag, forward_run)
    pred_list = predict_translate(explains, data_loader, enc_payload,
                                  plain_payload)
    save_to_pickle(pred_list, name_format.format(method_name))
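`sub_range` is a `"start,end"` string that slices both payloads before scoring, which allows sharding a long run. Restricting a run to the first 100 instances might look like this (argument values are hypothetical):

nli_baseline_predict(hparam, nli_setting, data_loader, "mismatch",
                     "deletion", data_id, "0,100", model_path)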
Example #7
    def __init__(self,
                 hparam,
                 nli_setting,
                 model_path,
                 modeling_option,
                 tags_list=nli_info.tags):
        self.num_tags = len(tags_list)
        self.tags = tags_list
        self.define_graph(hparam, nli_setting, modeling_option)
        self.sess = init_session()
        self.sess.run(tf.global_variables_initializer())
        self.batch_size = hparam.batch_size
        load_model(self.sess, model_path)
Example #8
    def __init__(self, hparam, nli_setting, model_path, method_name):
        self.task = transformer_nli_pooled(hparam, nli_setting.vocab_size,
                                           False)
        self.sout = tf.nn.softmax(self.task.logits, axis=-1)
        self.sess = init_session()
        self.sess.run(tf.global_variables_initializer())
        self.batch_size = hparam.batch_size
        load_model(self.sess, model_path)
        todo_list = [
            ('deletion_seq', explain_by_seq_deletion),
            ('random', explain_by_random),
            ('deletion', explain_by_deletion),
            ('LIME', explain_by_lime),
        ]
        method_dict = dict(todo_list)
        self.method = method_dict[method_name]
Example #9
def init_fn_generic(sess, start_type, start_model_path):
    if start_type == "cls":
        load_model_with_blacklist(sess, start_model_path,
                                  ["explain", "explain_optimizer"])
    elif start_type == "cls_new":
        load_model_with_blacklist(
            sess, start_model_path,
            ["explain", "explain_optimizer", "optimizer"])
    elif start_type == "cls_ex":
        load_model(sess, start_model_path)
    elif start_type == "as_is":
        load_model(sess, start_model_path)
    elif start_type == "cls_ex_for_pairing":
        load_model_with_blacklist(sess, start_model_path,
                                  ["match_predictor", "match_optimizer"])
    elif start_type == "bert":
        load_model_w_scope(sess, start_model_path, ["bert"])
    elif start_type == "cold":
        pass
    else:
        assert False, "Unknown start_type: {}".format(start_type)
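Because `init_fn_generic` takes the start type and checkpoint path alongside the session, it can be curried into the one-argument `model_init_fn(sess)` shape that trainers such as `train_nli_w_dict` (Example #12) expect. A sketch with `functools.partial` (the path is an assumption):

from functools import partial

# Initialize only the "bert" scope from a pretrained checkpoint.
model_init_fn = partial(init_fn_generic,
                        start_type="bert",
                        start_model_path="/path/to/bert/checkpoint")  # assumed path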
Example #10
    def load_fn(sess, model_path):
        return load_model(sess, model_path)
Example #11
    def load_last_saved_model(self, model_path):
        last_saved = get_latest_model_path_from_dir_path(model_path)
        tf_logging.info("Loading previous model from {}".format(last_saved))
        load_model(self.sess, last_saved)
Example #12
def train_nli_w_dict(run_name, model: DictReaderInterface, model_path,
                     model_config, data_feeder_loader, model_init_fn):
    print("Train nil :", run_name)
    batch_size = FLAGS.train_batch_size
    f_train_lookup = "lookup" in FLAGS.train_op
    tf_logging.debug("Building graph")

    with tf.compat.v1.variable_scope("optimizer"):
        lr = FLAGS.learning_rate
        lr2 = lr * 0.1
        if model_config.compare_attrib_value_safe("use_two_lr", True):
            tf_logging.info("Using two lr for each parts")
            train_cls, global_step = get_train_op_sep_lr(
                model.get_cls_loss(), lr, 5, "dict")
        else:
            train_cls, global_step = train_module.get_train_op(
                model.get_cls_loss(), lr)
        train_lookup_op, global_step = train_module.get_train_op(
            model.get_lookup_loss(), lr2, global_step)

    sess = train_module.init_session()
    sess.run(tf.compat.v1.global_variables_initializer())

    train_writer, test_writer = setup_summary_writer(run_name)

    last_saved = get_latest_model_path_from_dir_path(model_path)
    if last_saved:
        tf_logging.info("Loading previous model from {}".format(last_saved))
        load_model(sess, last_saved)
    elif model_init_fn is not None:
        model_init_fn(sess)

    log = log_module.train_logger()
    train_data_feeder = data_feeder_loader.get_train_feeder()
    dev_data_feeder = data_feeder_loader.get_dev_feeder()
    lookup_train_feeder = train_data_feeder
    valid_runner = WSSDRRunner(model, dev_data_feeder.augment_dict_info, sess)

    dev_batches = []
    n_dev_batch = 100
    dev_batches_w_dict = dev_data_feeder.get_all_batches(batch_size,
                                                         True)[:n_dev_batch]
    for _ in range(n_dev_batch):
        dev_batches.append(dev_data_feeder.get_random_batch(batch_size))
        dev_batches_w_dict.append(dev_data_feeder.get_lookup_batch(batch_size))

    def get_summary_obj(loss, acc):
        summary = tf.compat.v1.Summary()
        summary.value.add(tag='loss', simple_value=loss)
        summary.value.add(tag='accuracy', simple_value=acc)
        return summary

    def get_summary_obj_lookup(loss, p_at_1):
        summary = tf.compat.v1.Summary()
        summary.value.add(tag='lookup_loss', simple_value=loss)
        summary.value.add(tag='P@1', simple_value=p_at_1)
        return summary

    def train_lookup(step_i):
        batches, info = lookup_train_feeder.get_lookup_train_batches(
            batch_size)
        if not batches:
            raise NoLookupException()

        def get_cls_loss(batch):
            return sess.run([model.get_cls_loss_arr()],
                            feed_dict=model.batch2feed_dict(batch))

        loss_array = get_loss_from_batches(batches, get_cls_loss)

        supervision_for_lookup = train_data_feeder.get_lookup_training_batch(
            loss_array, batch_size, info)

        def lookup_train(batch):
            return sess.run(
                [model.get_lookup_loss(),
                 model.get_p_at_1(), train_lookup_op],
                feed_dict=model.batch2feed_dict(batch))

        avg_loss, p_at_1, _ = lookup_train(supervision_for_lookup)
        train_writer.add_summary(get_summary_obj_lookup(avg_loss, p_at_1),
                                 step_i)
        log.info("Step {0} lookup loss={1:.04f}".format(step_i, avg_loss))
        return avg_loss

    def train_classification(step_i):
        batch = train_data_feeder.get_random_batch(batch_size)
        loss_val, acc, _ = sess.run(
            [model.get_cls_loss(),
             model.get_acc(), train_cls],
            feed_dict=model.batch2feed_dict(batch))
        log.info("Step {0} train loss={1:.04f} acc={2:.03f}".format(
            step_i, loss_val, acc))
        train_writer.add_summary(get_summary_obj(loss_val, acc), step_i)

        return loss_val, acc

    lookup_loss_window = MovingWindow(20)

    def train_classification_w_lookup(step_i):
        data_indices, batch = train_data_feeder.get_lookup_batch(batch_size)
        logits, = sess.run([model.get_lookup_logits()],
                           feed_dict=model.batch2feed_dict(batch))
        term_ranks = np.flip(np.argsort(logits[:, :, 1], axis=1))
        batch = train_data_feeder.augment_dict_info(data_indices, term_ranks)

        loss_val, acc, _ = sess.run(
            [model.get_cls_loss(),
             model.get_acc(), train_cls],
            feed_dict=model.batch2feed_dict(batch))
        log.info("ClsW]Step {0} train loss={1:.04f} acc={2:.03f}".format(
            step_i, loss_val, acc))
        train_writer.add_summary(get_summary_obj(loss_val, acc), step_i)

        return loss_val, acc

    def lookup_enabled(lookup_loss_window, step_i):
        return (step_i > model_config.lookup_min_step
                and lookup_loss_window.get_average() < model_config.lookup_threshold)

    def train_fn(step_i):
        if lookup_enabled(lookup_loss_window, step_i):
            loss, acc = train_classification_w_lookup(step_i)
        else:
            loss, acc = train_classification(step_i)
        if f_train_lookup and step_i % model_config.lookup_train_frequency == 0:
            try:
                lookup_loss = train_lookup(step_i)
                lookup_loss_window.append(lookup_loss, 1)
            except NoLookupException:
                log.warning("No possible lookup found")

        return loss, acc

    def debug_fn(batch):
        y_lookup, = sess.run(
            [model.y_lookup],
            feed_dict=model.batch2feed_dict(batch))
        print(y_lookup)
        return 0, 0

    def valid_fn(step_i):
        if lookup_enabled(lookup_loss_window, step_i):
            valid_fn_w_lookup(step_i)
        else:
            valid_fn_wo_lookup(step_i)

    def valid_fn_wo_lookup(step_i):
        loss_val, acc = valid_runner.run_batches_wo_lookup(dev_batches)
        log.info("Step {0} Dev loss={1:.04f} acc={2:.03f}".format(
            step_i, loss_val, acc))
        test_writer.add_summary(get_summary_obj(loss_val, acc), step_i)
        return acc

    def valid_fn_w_lookup(step_i):
        loss_val, acc = valid_runner.run_batches_w_lookup(dev_batches_w_dict)
        log.info("Step {0} DevW loss={1:.04f} acc={2:.03f}".format(
            step_i, loss_val, acc))
        test_writer.add_summary(get_summary_obj(loss_val, acc), step_i)
        return acc

    def save_fn():
        op = tf.compat.v1.assign(global_step, step_i)
        sess.run([op])
        return save_model_to_dir_path(sess, model_path, global_step)

    n_data = train_data_feeder.get_data_len()
    step_per_epoch = int((n_data + batch_size - 1) / batch_size)
    tf_logging.debug("{} data point -> {} batches / epoch".format(
        n_data, step_per_epoch))
    train_steps = step_per_epoch * FLAGS.num_train_epochs
    tf_logging.debug("Max train step : {}".format(train_steps))
    valid_freq = 100
    save_interval = 60 * 20
    last_save = time.time()

    init_step, = sess.run([global_step])
    print("Initial step : ", init_step)
    for step_i in range(init_step, train_steps):
        if valid_fn is not None:
            if (step_i + 1) % valid_freq == 0:
                valid_fn(step_i)

        if save_fn is not None:
            if time.time() - last_save > save_interval:
                save_fn()
                last_save = time.time()

        loss, acc = train_fn(step_i)

    return save_fn()
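The lookup branch is gated by `lookup_enabled`, which compares a moving average of recent lookup losses against `model_config.lookup_threshold`. A minimal stand-in for `MovingWindow`, inferred from the two call sites above (`append(value, weight)` and `get_average()`; the real implementation may differ):

from collections import deque

class MovingWindow:
    # Weighted moving average over the last `size` appended values.
    def __init__(self, size):
        self.items = deque(maxlen=size)

    def append(self, value, weight):
        self.items.append((value, weight))

    def get_average(self):
        total_weight = sum(w for _, w in self.items)
        if total_weight == 0:
            return 0.0
        return sum(v * w for v, w in self.items) / total_weight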
Example #13
    def load_fn(sess, model_path):
        # 'resume' is captured from the enclosing training-setup scope
        if not resume:
            return load_model_w_scope(sess, model_path, "bert")
        else:
            return load_model(sess, model_path)