def label_predict(hparam, data, model_path) -> np.ndarray:
    """Run the binary classifier at `model_path` over `data`, returning softmax scores.

    Args:
        hparam: hyperparameter object; only `batch_size` is read here.
        data: list of already-encoded (ids, mask, segment) triples.
        model_path: checkpoint path to restore.

    Returns:
        A single np.ndarray of stacked softmax outputs, one row per instance.
        (The original annotation claimed List[np.array]; the function actually
        returns the result of np.concatenate.)
    """
    tprint("building model")
    voca_size = 30522  # BERT-base uncased vocabulary size
    task = transformer_logit(hparam, 2, voca_size, False)
    enc_payload: List[Tuple[np.ndarray, np.ndarray, np.ndarray]] = data
    sout = tf.nn.softmax(task.logits, axis=-1)
    sess = init_session()
    sess.run(tf.global_variables_initializer())
    tprint("loading model")
    load_model(sess, model_path)

    def forward_run(inputs):
        # Batched forward pass; collects per-batch softmax outputs.
        batches = get_batches_ex(inputs, hparam.batch_size, 3)
        logit_list = []
        ticker = TimeEstimator(len(batches))
        for batch in batches:
            x0, x1, x2 = batch
            soft_out, = sess.run([sout],
                                 feed_dict={
                                     task.x_list[0]: x0,
                                     task.x_list[1]: x1,
                                     task.x_list[2]: x2,
                                 })
            logit_list.append(soft_out)
            ticker.tick()
        return np.concatenate(logit_list)

    logits = forward_run(enc_payload)
    return logits
def predict_for_view(hparam, nli_setting, data_loader, data_id, model_path, run_name, modeling_option, tags):
    """Score every instance of `data_id` with both the classifier logits and the
    explain head, then pickle (tokens, logits, per-tag scores) entries."""
    print("predict_nli_ex")
    print("Modeling option: ", modeling_option)

    enc_payload, plain_payload = data_loader.load_plain_text(data_id)
    payload_batches = get_batches_ex(enc_payload, hparam.batch_size, 3)

    task = transformer_nli_pooled(hparam, nli_setting.vocab_size)
    explain_predictor = ExplainPredictor(len(tags), task.model.get_sequence_output(), modeling_option)

    sess = init_session()
    sess.run(tf.global_variables_initializer())
    load_model(sess, model_path)

    out_entries = []
    for b0, b1, b2 in payload_batches:
        fetches = [task.logits, explain_predictor.get_score()]
        logits, ex_logits = sess.run(fetches,
                                     feed_dict={
                                         task.x_list[0]: b0,
                                         task.x_list[1]: b1,
                                         task.x_list[2]: b2,
                                     })
        out_entries.extend(
            (b0[i], logits[i], tuple_list_select(ex_logits, i))
            for i in range(len(b0))
        )
    save_to_pickle(out_entries, "save_view_{}_{}".format(run_name, data_id))
def nli_attribution_predict(hparam, nli_setting, data_loader, explain_tag, method_name, data_id, sub_range, model_path):
    """Compute gradient-based attribution scores for `data_id` and pickle them.

    Args:
        hparam: hyperparameter object.
        nli_setting: NLI config; `vocab_size` is read here.
        data_loader: provides encoded and plain text payloads for `data_id`.
        explain_tag: target tag passed through to explain_by_gradient.
        method_name: gradient attribution method name; also used in the output name.
        sub_range: must be None; slicing is not supported by this entry point.
        model_path: checkpoint to restore.

    Raises:
        ValueError: if `sub_range` is given.
    """
    enc_payload, plain_payload = data_loader.load_plain_text(data_id)
    if sub_range is not None:
        # Narrowed from bare `Exception` to ValueError (still caught by
        # any existing `except Exception` handler).
        raise ValueError("Sub_range is not supported")

    from attribution.gradient import explain_by_gradient
    from attribution.deepexplain.tensorflow import DeepExplain

    sess = init_session()
    with DeepExplain(session=sess, graph=sess.graph) as de:
        task = transformer_nli_pooled_embedding_in(hparam, nli_setting.vocab_size, False)
        softmax_out = tf.nn.softmax(task.logits, axis=-1)
        sess.run(tf.global_variables_initializer())
        load_model(sess, model_path)
        # Embedding-level tensors: DeepExplain attributes with respect to these.
        emb_outputs = task.encoded_embedding_out, task.attention_mask_out
        emb_input = task.encoded_embedding_in, task.attention_mask_in

        def feed_end_input(batch):
            x0, x1, x2 = batch
            return {task.x_list[0]: x0,
                    task.x_list[1]: x1,
                    task.x_list[2]: x2,
                    }

        explains = explain_by_gradient(enc_payload, method_name, explain_tag, sess, de,
                                       feed_end_input, emb_outputs, emb_input, softmax_out)
        pred_list = predict_translate(explains, data_loader, enc_payload, plain_payload)
        save_to_pickle(pred_list, "pred_{}_{}".format(method_name, data_id))
def eval_fidelity_gradient(hparam, nli_setting, flat_dev_batches, explain_tag, method_name, model_path):
    """Score `flat_dev_batches` with a gradient attribution method and return
    the fidelity accuracy list computed by eval_fidelity."""
    from attribution.gradient import explain_by_gradient
    from attribution.deepexplain.tensorflow import DeepExplain

    sess = init_session()
    with DeepExplain(session=sess, graph=sess.graph) as de:
        task = transformer_nli_pooled_embedding_in(hparam, nli_setting.vocab_size, False)
        softmax_out = tf.nn.softmax(task.logits, axis=-1)
        sess.run(tf.global_variables_initializer())
        load_model(sess, model_path)

        # Embedding-level in/out tensors used by DeepExplain.
        emb_outputs = task.encoded_embedding_out, task.attention_mask_out
        emb_input = task.encoded_embedding_in, task.attention_mask_in

        def feed_end_input(one_batch):
            t0, t1, t2 = one_batch
            return {task.x_list[0]: t0,
                    task.x_list[1]: t1,
                    task.x_list[2]: t2,
                    }

        def forward_runs(insts):
            # Two-step forward pass: tokens -> embeddings, then embeddings -> softmax.
            collected = []
            for one_batch in get_batches_ex(insts, hparam.batch_size, 3):
                enc, att = sess.run(emb_outputs, feed_dict=feed_end_input(one_batch))
                probs, = sess.run([softmax_out],
                                  feed_dict={task.encoded_embedding_in: enc,
                                             task.attention_mask_in: att})
                collected.append(probs)
            return np.concatenate(collected)

        contrib_score = explain_by_gradient(flat_dev_batches, method_name, explain_tag,
                                            sess, de, feed_end_input,
                                            emb_outputs, emb_input, softmax_out)
        print("contrib_score", len(contrib_score))
        print("flat_dev_batches", len(flat_dev_batches))
        return eval_fidelity(contrib_score, flat_dev_batches, forward_runs, explain_tag)
def baseline_predict(hparam, nli_setting, data, method_name, model_path) -> List[np.ndarray]:
    """Explain `data` with one of the perturbation baseline methods.

    Builds a binary transformer classifier, restores `model_path`, then runs
    the baseline named by `method_name` with the fixed tag "mismatch".

    Args:
        hparam: hyperparameter object; `batch_size` is read here.
        nli_setting: NLI config; used only by the 'idf' method to load data.
        data: list of already-encoded (ids, mask, segment) triples.
        method_name: key into the method table below.
        model_path: checkpoint to restore.

    Returns:
        List of per-instance score arrays (one entry per element of `data`).
    """
    tprint("building model")
    voca_size = 30522  # BERT-base uncased vocabulary size
    task = transformer_logit(hparam, 2, voca_size, False)
    enc_payload: List[Tuple[np.ndarray, np.ndarray, np.ndarray]] = data
    sout = tf.nn.softmax(task.logits, axis=-1)
    sess = init_session()
    sess.run(tf.global_variables_initializer())
    tprint("loading model")
    load_model(sess, model_path)

    def forward_run(inputs):
        # Batched forward pass; returns stacked softmax outputs.
        batches = get_batches_ex(inputs, hparam.batch_size, 3)
        logit_list = []
        ticker = TimeEstimator(len(batches))
        for batch in batches:
            x0, x1, x2 = batch
            soft_out, = sess.run([sout],
                                 feed_dict={
                                     task.x_list[0]: x0,
                                     task.x_list[1]: x1,
                                     task.x_list[2]: x2,
                                 })
            logit_list.append(soft_out)
            ticker.tick()
        return np.concatenate(logit_list)

    def idf_explain(enc_payload, explain_tag, forward_run):
        # IDF baseline needs corpus statistics from the training split.
        train_batches, dev_batches = get_nli_data(hparam, nli_setting)
        idf_scorer = IdfScorer(train_batches)
        return idf_scorer.explain(enc_payload, explain_tag, forward_run)

    todo_list = [
        ('deletion_seq', explain_by_seq_deletion),
        ('replace_token', explain_by_replace),
        ('term_deletion', explain_by_term_deletion),
        ('term_replace', explain_by_term_replace),
        ('random', explain_by_random),
        ('idf', idf_explain),
        ('deletion', explain_by_deletion),
        ('LIME', explain_by_lime),
    ]
    method_dict = dict(todo_list)
    method = method_dict[method_name]
    explain_tag = "mismatch"
    explains: List[np.ndarray] = method(enc_payload, explain_tag, forward_run)
    return explains
def nli_baseline_predict(hparam, nli_setting, data_loader, explain_tag, method_name, data_id, sub_range, model_path):
    """Run a perturbation-baseline explanation method over `data_id` and pickle
    the translated predictions, optionally restricted to a sub-range."""
    enc_payload, plain_payload = data_loader.load_plain_text(data_id)
    assert enc_payload is not None
    assert plain_payload is not None

    name_format = "pred_{}_" + data_id
    if sub_range is not None:
        # Restrict to the requested [st, ed) slice and tag the output name.
        st, ed = [int(t) for t in sub_range.split(",")]
        enc_payload = enc_payload[st:ed]
        plain_payload = plain_payload[st:ed]
        name_format = "pred_{}_" + data_id + "__{}_{}".format(st, ed)
    print(name_format)

    task = transformer_nli_pooled(hparam, nli_setting.vocab_size)
    sout = tf.nn.softmax(task.logits, axis=-1)
    sess = init_session()
    sess.run(tf.global_variables_initializer())
    load_model(sess, model_path)

    def forward_run(inputs):
        # Batched softmax forward pass over the encoded triples.
        outputs = []
        for b0, b1, b2 in get_batches_ex(inputs, hparam.batch_size, 3):
            soft_out, = sess.run([sout],
                                 feed_dict={
                                     task.x_list[0]: b0,
                                     task.x_list[1]: b1,
                                     task.x_list[2]: b2,
                                 })
            outputs.append(soft_out)
        return np.concatenate(outputs)

    def idf_explain(enc_payload, explain_tag, forward_run):
        # The IDF baseline needs statistics from the training split.
        train_batches, dev_batches = get_nli_data(hparam, nli_setting)
        return IdfScorer(train_batches).explain(enc_payload, explain_tag, forward_run)

    method_dict = {
        'deletion_seq': explain_by_seq_deletion,
        'random': explain_by_random,
        'idf': idf_explain,
        'deletion': explain_by_deletion,
        'LIME': explain_by_lime,
    }
    explains = method_dict[method_name](enc_payload, explain_tag, forward_run)
    pred_list = predict_translate(explains, data_loader, enc_payload, plain_payload)
    save_to_pickle(pred_list, name_format.format(method_name))
def __init__(self, hparam, nli_setting, model_path, modeling_option, tags_list=nli_info.tags):
    """Build the explain graph, open a session, and restore `model_path`."""
    self.tags = tags_list
    self.num_tags = len(tags_list)
    self.batch_size = hparam.batch_size
    self.define_graph(hparam, nli_setting, modeling_option)
    self.sess = init_session()
    self.sess.run(tf.global_variables_initializer())
    load_model(self.sess, model_path)
def __init__(self, hparam, nli_setting, model_path, method_name):
    """Restore an NLI model and select the perturbation explain method by name."""
    self.task = transformer_nli_pooled(hparam, nli_setting.vocab_size, False)
    self.sout = tf.nn.softmax(self.task.logits, axis=-1)
    self.batch_size = hparam.batch_size
    self.sess = init_session()
    self.sess.run(tf.global_variables_initializer())
    load_model(self.sess, model_path)
    method_dict = {
        'deletion_seq': explain_by_seq_deletion,
        'random': explain_by_random,
        'deletion': explain_by_deletion,
        'LIME': explain_by_lime,
    }
    self.method = method_dict[method_name]
def init_fn_generic(sess, start_type, start_model_path):
    """Initialize session variables from a checkpoint according to `start_type`.

    Args:
        sess: TF session whose variables are (partially) restored.
        start_type: one of "cls", "cls_new", "cls_ex", "as_is",
            "cls_ex_for_pairing", "bert", "cold". "cold" leaves the
            freshly initialized variables untouched.
        start_model_path: checkpoint path (unused for "cold").

    Raises:
        ValueError: on an unknown `start_type` (was `assert False`, which
            is stripped under `python -O`).
    """
    if start_type == "cls":
        load_model_with_blacklist(sess, start_model_path, ["explain", "explain_optimizer"])
    elif start_type == "cls_new":
        load_model_with_blacklist(
            sess, start_model_path, ["explain", "explain_optimizer", "optimizer"])
    elif start_type in ("cls_ex", "as_is"):
        # Both types perform a plain full restore.
        load_model(sess, start_model_path)
    elif start_type == "cls_ex_for_pairing":
        load_model_with_blacklist(sess, start_model_path, ["match_predictor", "match_optimizer"])
    elif start_type == "bert":
        load_model_w_scope(sess, start_model_path, ["bert"])
    elif start_type == "cold":
        pass  # train from scratch: keep the fresh initialization
    else:
        raise ValueError("Unknown start_type: {}".format(start_type))
def load_fn(sess, model_path):
    """Thin adapter matching the loader-callback signature: full restore."""
    return load_model(sess, model_path)
def load_last_saved_model(self, model_path):
    """Restore the newest checkpoint found under `model_path` into self.sess."""
    last_saved = get_latest_model_path_from_dir_path(model_path)
    # Log before restoring (not after, as before) so the attempted path is
    # recorded even if load_model fails; matches train_nli_w_dict's pattern.
    tf_logging.info("Loading previous model from {}".format(last_saved))
    load_model(self.sess, last_saved)
def train_nli_w_dict(run_name, model: DictReaderInterface, model_path, model_config, data_feeder_loader, model_init_fn):
    """Train an NLI model with dictionary-lookup support.

    Alternates classification steps with lookup-head training once the lookup
    loss falls below a threshold, validates every `valid_freq` steps, and
    checkpoints on a wall-clock timer.

    Args:
        run_name: name used for the summary writers and logging.
        model: DictReaderInterface exposing losses, accuracy, and feed dicts.
        model_path: directory for loading/saving checkpoints.
        model_config: config providing lookup_min_step, lookup_threshold,
            lookup_train_frequency and (optionally) use_two_lr.
        data_feeder_loader: provides the train and dev data feeders.
        model_init_fn: optional initializer called when no checkpoint exists.

    Returns:
        Whatever save_model_to_dir_path returns for the final checkpoint.
    """
    print("Train nli :", run_name)  # fixed typo: was "Train nil :"
    batch_size = FLAGS.train_batch_size
    f_train_lookup = "lookup" in FLAGS.train_op
    tf_logging.debug("Building graph")

    with tf.compat.v1.variable_scope("optimizer"):
        lr = FLAGS.learning_rate
        lr2 = lr * 0.1  # lookup head trains with a 10x smaller learning rate
        if model_config.compare_attrib_value_safe("use_two_lr", True):
            tf_logging.info("Using two lr for each parts")
            train_cls, global_step = get_train_op_sep_lr(
                model.get_cls_loss(), lr, 5, "dict")
        else:
            train_cls, global_step = train_module.get_train_op(
                model.get_cls_loss(), lr)
        train_lookup_op, global_step = train_module.get_train_op(
            model.get_lookup_loss(), lr2, global_step)

    sess = train_module.init_session()
    sess.run(tf.compat.v1.global_variables_initializer())
    train_writer, test_writer = setup_summary_writer(run_name)

    # Resume from the newest checkpoint if one exists; otherwise fall back to
    # the caller-provided initializer.
    last_saved = get_latest_model_path_from_dir_path(model_path)
    if last_saved:
        tf_logging.info("Loading previous model from {}".format(last_saved))
        load_model(sess, last_saved)
    elif model_init_fn is not None:
        model_init_fn(sess)

    log = log_module.train_logger()
    train_data_feeder = data_feeder_loader.get_train_feeder()
    dev_data_feeder = data_feeder_loader.get_dev_feeder()
    lookup_train_feeder = train_data_feeder
    valid_runner = WSSDRRunner(model, dev_data_feeder.augment_dict_info, sess)

    # Pre-sample fixed dev batches so every validation run sees the same data.
    dev_batches = []
    n_dev_batch = 100
    dev_batches_w_dict = dev_data_feeder.get_all_batches(batch_size, True)[:n_dev_batch]
    for _ in range(n_dev_batch):
        dev_batches.append(dev_data_feeder.get_random_batch(batch_size))
        dev_batches_w_dict.append(dev_data_feeder.get_lookup_batch(batch_size))

    def get_summary_obj(loss, acc):
        # Scalar summary for classification loss/accuracy.
        summary = tf.compat.v1.Summary()
        summary.value.add(tag='loss', simple_value=loss)
        summary.value.add(tag='accuracy', simple_value=acc)
        return summary

    def get_summary_obj_lookup(loss, p_at_1):
        # Scalar summary for the lookup head.
        summary = tf.compat.v1.Summary()
        summary.value.add(tag='lookup_loss', simple_value=loss)
        summary.value.add(tag='P@1', simple_value=p_at_1)
        return summary

    def train_lookup(step_i):
        # One lookup-training step, supervised by the current cls loss.
        batches, info = lookup_train_feeder.get_lookup_train_batches(
            batch_size)
        if not batches:
            raise NoLookupException()

        def get_cls_loss(batch):
            return sess.run([model.get_cls_loss_arr()],
                            feed_dict=model.batch2feed_dict(batch))

        loss_array = get_loss_from_batches(batches, get_cls_loss)
        supervision_for_lookup = train_data_feeder.get_lookup_training_batch(
            loss_array, batch_size, info)

        def lookup_train(batch):
            return sess.run(
                [model.get_lookup_loss(), model.get_p_at_1(), train_lookup_op],
                feed_dict=model.batch2feed_dict(batch))

        avg_loss, p_at_1, _ = lookup_train(supervision_for_lookup)
        train_writer.add_summary(get_summary_obj_lookup(avg_loss, p_at_1), step_i)
        log.info("Step {0} lookup loss={1:.04f}".format(step_i, avg_loss))
        return avg_loss

    def train_classification(step_i):
        # One plain classification step on a random batch.
        batch = train_data_feeder.get_random_batch(batch_size)
        loss_val, acc, _ = sess.run(
            [model.get_cls_loss(), model.get_acc(), train_cls],
            feed_dict=model.batch2feed_dict(batch))
        log.info("Step {0} train loss={1:.04f} acc={2:.03f}".format(
            step_i, loss_val, acc))
        train_writer.add_summary(get_summary_obj(loss_val, acc), step_i)
        return loss_val, acc

    lookup_loss_window = MovingWindow(20)

    def train_classification_w_lookup(step_i):
        # Classification step whose batch is augmented with lookup term ranks.
        data_indices, batch = train_data_feeder.get_lookup_batch(batch_size)
        logits, = sess.run([model.get_lookup_logits()],
                           feed_dict=model.batch2feed_dict(batch))
        # Rank terms by descending lookup score (column 1 of the logits).
        term_ranks = np.flip(np.argsort(logits[:, :, 1], axis=1))
        batch = train_data_feeder.augment_dict_info(data_indices, term_ranks)
        loss_val, acc, _ = sess.run(
            [model.get_cls_loss(), model.get_acc(), train_cls],
            feed_dict=model.batch2feed_dict(batch))
        log.info("ClsW]Step {0} train loss={1:.04f} acc={2:.03f}".format(
            step_i, loss_val, acc))
        train_writer.add_summary(get_summary_obj(loss_val, acc), step_i)
        return loss_val, acc

    def lookup_enabled(lookup_loss_window, step_i):
        # Lookup kicks in after the warm-up period once its loss is low enough.
        return step_i > model_config.lookup_min_step\
               and lookup_loss_window.get_average() < model_config.lookup_threshold

    def train_fn(step_i):
        if lookup_enabled(lookup_loss_window, step_i):
            loss, acc = train_classification_w_lookup((step_i))
        else:
            loss, acc = train_classification(step_i)
        if f_train_lookup and step_i % model_config.lookup_train_frequency == 0:
            try:
                lookup_loss = train_lookup(step_i)
                lookup_loss_window.append(lookup_loss, 1)
            except NoLookupException:
                log.warning("No possible lookup found")
        return loss, acc

    def debug_fn(batch):
        # Inspection helper; not wired into the main loop.
        y_lookup, = sess.run([model.y_lookup, ],
                             feed_dict=model.batch2feed_dict(batch))
        print(y_lookup)
        return 0, 0

    def valid_fn(step_i):
        if lookup_enabled(lookup_loss_window, step_i):
            valid_fn_w_lookup(step_i)
        else:
            valid_fn_wo_lookup(step_i)

    def valid_fn_wo_lookup(step_i):
        loss_val, acc = valid_runner.run_batches_wo_lookup(dev_batches)
        log.info("Step {0} Dev loss={1:.04f} acc={2:.03f}".format(
            step_i, loss_val, acc))
        test_writer.add_summary(get_summary_obj(loss_val, acc), step_i)
        return acc

    def valid_fn_w_lookup(step_i):
        loss_val, acc = valid_runner.run_batches_w_lookup(dev_batches_w_dict)
        log.info("Step {0} DevW loss={1:.04f} acc={2:.03f}".format(
            step_i, loss_val, acc))
        test_writer.add_summary(get_summary_obj(loss_val, acc), step_i)
        return acc

    def save_fn():
        # NOTE: reads `step_i` from the enclosing training loop via closure.
        op = tf.compat.v1.assign(global_step, step_i)
        sess.run([op])
        return save_model_to_dir_path(sess, model_path, global_step)

    n_data = train_data_feeder.get_data_len()
    step_per_epoch = int((n_data + batch_size - 1) / batch_size)
    tf_logging.debug("{} data point -> {} batches / epoch".format(
        n_data, step_per_epoch))
    train_steps = step_per_epoch * FLAGS.num_train_epochs
    tf_logging.debug("Max train step : {}".format(train_steps))
    valid_freq = 100
    save_interval = 60 * 20  # seconds between timed checkpoints
    last_save = time.time()

    init_step, = sess.run([global_step])
    print("Initial step : ", init_step)
    # BUG FIX: ensure `step_i` exists for the closing save_fn() even when the
    # loop body never executes (init_step >= train_steps).
    step_i = init_step
    for step_i in range(init_step, train_steps):
        # BUG FIX: the original guarded on an undefined name `dev_fn`, which
        # raises NameError at runtime; guard on the local valid_fn instead,
        # mirroring the save_fn guard below.
        if valid_fn is not None:
            if (step_i + 1) % valid_freq == 0:
                valid_fn(step_i)
        if save_fn is not None:
            if time.time() - last_save > save_interval:
                save_fn()
                last_save = time.time()
        loss, acc = train_fn(step_i)

    return save_fn()
def load_fn(sess, model_path):
    """Restore a checkpoint: full restore when resuming, otherwise only the
    "bert" scope. (`resume` is a free variable from the enclosing scope —
    presumably set by the surrounding training setup; confirm at call site.)"""
    if resume:
        return load_model(sess, model_path)
    return load_model_w_scope(sess, model_path, "bert")