def eval_nli(hparam, nli_setting, run_name, dev_batches, model_path, load_fn):
    print("eval_nli :", run_name)
    task = transformer_nli_pooled(hparam, nli_setting.vocab_size, False)
    sess = init_session()
    sess.run(tf.global_variables_initializer())
    if model_path is not None:
        load_fn(sess, model_path)

    def batch2feed_dict(batch):
        x0, x1, x2, y = batch
        feed_dict = {
            task.x_list[0]: x0,
            task.x_list[1]: x1,
            task.x_list[2]: x2,
            task.y: y,
        }
        return feed_dict

    def valid_fn():
        loss_list = []
        acc_list = []
        for batch in dev_batches:
            loss_val, acc = sess.run([task.loss, task.acc],
                                     feed_dict=batch2feed_dict(batch))
            # Append once per example so the final average is weighted
            # by batch size.
            cur_batch_size = len(batch[0])
            for _ in range(cur_batch_size):
                loss_list.append(loss_val)
                acc_list.append(acc)
        return average(acc_list)

    return valid_fn()
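# A minimal, self-contained illustration (not part of the original code) of
# what valid_fn's per-example append loop computes: a batch-size-weighted
# mean of per-batch accuracies. The values below are made up.
def _weighted_acc_sketch():
    import numpy as np
    accs = [0.75, 0.50, 1.00]   # per-batch accuracy, as returned by sess.run
    sizes = [16, 16, 8]         # examples per batch, i.e. len(batch[0])
    # Repeating each accuracy `size` times and then averaging is equivalent
    # to a weighted mean:
    return float(np.average(accs, weights=sizes))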
def predict_for_view(hparam, nli_setting, data_loader, data_id, model_path,
                     run_name, modeling_option, tags):
    print("predict_nli_ex")
    print("Modeling option: ", modeling_option)
    enc_payload, plain_payload = data_loader.load_plain_text(data_id)
    batches = get_batches_ex(enc_payload, hparam.batch_size, 3)
    task = transformer_nli_pooled(hparam, nli_setting.vocab_size)
    explain_predictor = ExplainPredictor(len(tags),
                                         task.model.get_sequence_output(),
                                         modeling_option)
    sess = init_session()
    sess.run(tf.global_variables_initializer())
    load_model(sess, model_path)

    out_entries = []
    for batch in batches:
        x0, x1, x2 = batch
        logits, ex_logits = sess.run(
            [task.logits, explain_predictor.get_score()],
            feed_dict={
                task.x_list[0]: x0,
                task.x_list[1]: x1,
                task.x_list[2]: x2,
            })
        for i in range(len(x0)):
            e = x0[i], logits[i], tuple_list_select(ex_logits, i)
            out_entries.append(e)

    save_to_pickle(out_entries, "save_view_{}_{}".format(run_name, data_id))
def do_predict(
        bert_hp,
        train_config,
        data,
        lms_config,
        modeling_option,
        init_fn,
):
    num_gpu = train_config.num_gpu
    train_batches, dev_batches = data
    lms_model = LMSModel(modeling_option, bert_hp, lms_config, num_gpu)
    sess = init_session()
    sess.run(tf.global_variables_initializer())
    init_fn(sess)

    # Score the training batches in chunks of `step_size` and pickle each
    # chunk separately.
    step_size = 100
    for i in range(100):
        st = i * step_size
        ed = st + step_size
        tprint(st, ed)
        output_d = predict_fn(sess, train_batches[st:ed], lms_model.logits,
                              lms_model.loss_tensor,
                              lms_model.ex_score_tensor,
                              lms_model.per_layer_logit_tensor,
                              lms_model.batch2feed_dict)
        save_path = at_output_dir("lms_scores", str(i))
        save_to_pickle(output_d, save_path)
def fetch_hidden_vector(hparam, vocab_size, data, model_path):
    task = transformer_nli_hidden(hparam, vocab_size, 0, False)
    sess = init_session()
    sess.run(tf.global_variables_initializer())
    load_model_w_scope(sess, model_path, ["bert"])
    batches = get_batches_ex(data, hparam.batch_size, 4)

    def batch2feed_dict(batch):
        x0, x1, x2, y = batch
        feed_dict = {
            task.x_list[0]: x0,
            task.x_list[1]: x1,
            task.x_list[2]: x2,
            task.y: y,
        }
        return feed_dict

    def pred_fn():
        outputs = []
        for batch in batches:
            x0, x1, x2, y = batch
            all_layers, emb_outputs = sess.run(
                [task.all_layers, task.embedding_output],
                feed_dict=batch2feed_dict(batch))
            outputs.append((all_layers, emb_outputs, x0))
        return outputs

    return pred_fn()
def nli_attribution_predict(hparam, nli_setting, data_loader, explain_tag,
                            method_name, data_id, sub_range, model_path):
    enc_payload, plain_payload = data_loader.load_plain_text(data_id)
    if sub_range is not None:
        raise Exception("Sub_range is not supported")

    from attribution.gradient import explain_by_gradient
    from attribution.deepexplain.tensorflow import DeepExplain

    sess = init_session()
    with DeepExplain(session=sess, graph=sess.graph) as de:
        task = transformer_nli_pooled_embedding_in(hparam,
                                                   nli_setting.vocab_size,
                                                   False)
        softmax_out = tf.nn.softmax(task.logits, axis=-1)
        sess.run(tf.global_variables_initializer())
        load_model(sess, model_path)
        emb_outputs = task.encoded_embedding_out, task.attention_mask_out
        emb_input = task.encoded_embedding_in, task.attention_mask_in

        def feed_end_input(batch):
            x0, x1, x2 = batch
            return {
                task.x_list[0]: x0,
                task.x_list[1]: x1,
                task.x_list[2]: x2,
            }

        explains = explain_by_gradient(enc_payload, method_name, explain_tag,
                                       sess, de, feed_end_input, emb_outputs,
                                       emb_input, softmax_out)
        pred_list = predict_translate(explains, data_loader, enc_payload,
                                      plain_payload)
        save_to_pickle(pred_list, "pred_{}_{}".format(method_name, data_id))
def eval_nli(hparam, nli_setting, save_dir, data, model_path, load_fn):
    print("eval_nli :", save_dir)
    task = transformer_nli_pooled(hparam, nli_setting.vocab_size)
    train_batches, dev_batches = data
    sess = init_session()
    sess.run(tf.global_variables_initializer())
    load_fn(sess, model_path)

    def batch2feed_dict(batch):
        x0, x1, x2, y = batch
        feed_dict = {
            task.x_list[0]: x0,
            task.x_list[1]: x1,
            task.x_list[2]: x2,
            task.y: y,
        }
        return feed_dict

    global_step = tf.train.get_or_create_global_step()
    valid_fn = partial(valid_fn_factory, sess, dev_batches, task.loss,
                       task.acc, global_step, batch2feed_dict)
    return valid_fn()
def train_nli_multi_gpu(hparam, nli_setting, save_dir, num_steps, data,
                        model_path, load_fn, n_gpu):
    print("Train nli :", save_dir)
    model_init_fn = lambda: transformer_nli_pooled(hparam,
                                                   nli_setting.vocab_size)
    models = get_multiple_models(model_init_fn, n_gpu)
    losses = [model.loss for model in models]
    gradients = get_averaged_gradients(losses)
    avg_loss = get_avg_loss(models)
    avg_acc = get_avg_tensors_from_models(models, lambda model: model.acc)
    with tf.variable_scope("optimizer"):
        with tf.device("/device:CPU:0"):
            train_cls = get_train_op_from_grads_and_tvars(
                gradients, tf.trainable_variables(), hparam.lr, "adam",
                num_steps)
    print("Number of parameters : ", get_param_num())
    train_batches, dev_batches = data
    sess = init_session()
    sess.run(tf.global_variables_initializer())
    if model_path is not None:
        load_fn(sess, model_path)

    def batch2feed_dict(batch):
        x0, x1, x2, y = batch
        batch_size = len(x0)
        # Slice the batch evenly across GPUs; any remainder when batch_size
        # is not divisible by n_gpu is dropped.
        batch_size_per_gpu = int(batch_size / n_gpu)
        feed_dict = {}
        for gpu_idx in range(n_gpu):
            st = batch_size_per_gpu * gpu_idx
            ed = batch_size_per_gpu * (gpu_idx + 1)
            feed_dict[models[gpu_idx].x_list[0]] = x0[st:ed]
            feed_dict[models[gpu_idx].x_list[1]] = x1[st:ed]
            feed_dict[models[gpu_idx].x_list[2]] = x2[st:ed]
            feed_dict[models[gpu_idx].y] = y[st:ed]
        return feed_dict

    global_step = tf.train.get_or_create_global_step()
    train_classification = partial(train_classification_factory, sess,
                                   avg_loss, avg_acc, train_cls,
                                   batch2feed_dict)
    valid_fn = partial(valid_fn_factory, sess, dev_batches[:100], avg_loss,
                       avg_acc, global_step, batch2feed_dict)

    def save_fn():
        return save_model_to_dir_path(sess, save_dir, global_step)

    init_step, = sess.run([global_step])
    print("Initialize step to {}".format(init_step))
    print("{} train batches".format(len(train_batches)))
    valid_freq = 5000
    save_interval = 7200
    loss, _ = step_runner(train_batches, train_classification, init_step,
                          valid_fn, valid_freq, save_fn, save_interval,
                          num_steps)
    return save_fn()
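# A self-contained sketch (plain Python, invented sizes; not part of the
# original code) of the per-GPU slicing that batch2feed_dict above performs.
# Note the same caveat: when the batch size is not divisible by n_gpu, the
# trailing remainder is silently dropped.
def _split_batch_sketch():
    n_gpu = 2
    x = list(range(10))                  # stand-in for one input column
    per_gpu = len(x) // n_gpu
    return [x[i * per_gpu:(i + 1) * per_gpu] for i in range(n_gpu)]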
def label_predict(hparam, data, model_path) -> List[np.array]:
    tprint("building model")
    voca_size = 30522
    task = transformer_logit(hparam, 2, voca_size, False)
    enc_payload: List[Tuple[np.array, np.array, np.array]] = data
    sout = tf.nn.softmax(task.logits, axis=-1)
    sess = init_session()
    sess.run(tf.global_variables_initializer())
    tprint("loading model")
    load_model(sess, model_path)

    def forward_run(inputs):
        batches = get_batches_ex(inputs, hparam.batch_size, 3)
        logit_list = []
        ticker = TimeEstimator(len(batches))
        for batch in batches:
            x0, x1, x2 = batch
            soft_out, = sess.run([sout],
                                 feed_dict={
                                     task.x_list[0]: x0,
                                     task.x_list[1]: x1,
                                     task.x_list[2]: x2,
                                 })
            logit_list.append(soft_out)
            ticker.tick()
        return np.concatenate(logit_list)

    logits = forward_run(enc_payload)
    return logits
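# Hypothetical consumer of label_predict's return value (illustration only;
# `scores` is assumed to be the concatenated [N, 2] softmax array returned
# above):
def _read_label_predict_output(scores):
    preds = np.argmax(scores, axis=-1)       # hard label per example
    confidence = np.max(scores, axis=-1)     # winning-class probability
    return list(zip(preds.tolist(), confidence.tolist()))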
def __init__(self, hparam, voca_size, start_model_path):
    print("AgreePredictor")
    tf.reset_default_graph()
    self.task = transformer_weight(hparam, voca_size, False)
    self.sess = init_session()
    self.sess.run(tf.global_variables_initializer())
    load_model_w_scope(self.sess, start_model_path, ['bert', 'cls_dense'])
def __init__(self, hparam, voca_size, start_model_path):
    print("run_ukp_ex")
    tf.reset_default_graph()
    self.task = transformer_nli(hparam, voca_size, 5, True)
    self.sess = init_session()
    self.sess.run(tf.global_variables_initializer())
    load_model_w_scope(self.sess, start_model_path,
                       ['bert', 'cls_dense', 'aux_conflict'])
def __init__(self, model_path):
    self.voca_size = 30522
    self.hp = hyperparams.HPFAD()
    self.model_dir = cpath.model_path
    self.task = transformer_next_sent(self.hp, 2, self.voca_size, False)
    self.sess = init_session()
    self.sess.run(tf.global_variables_initializer())
    self.load_model_white(model_path)
    self.batch_size = 64
def train_LMS(bert_hp, train_config, lms_config: LMSConfigI, save_dir,
              nli_data, modeling_option, init_fn):
    tf_logging.info("train_pairing ENTRY")
    train_batches, dev_batches = nli_data
    max_steps = train_config.max_steps
    num_gpu = train_config.num_gpu
    lms_model = LMSModel(modeling_option, bert_hp, lms_config, num_gpu)
    train_cls = lms_model.get_train_op(bert_hp.lr, max_steps)
    global_step = tf.train.get_or_create_global_step()

    run_name = os.path.basename(save_dir)
    train_writer, test_writer = setup_summary_writer(run_name)

    sess = init_session()
    sess.run(tf.global_variables_initializer())
    init_fn(sess)

    def fetch_global_step():
        step, = sess.run([global_step])
        return step

    train_classification = partial(train_fn_factory, sess,
                                   lms_model.loss_tensor,
                                   lms_model.per_layer_loss, train_cls,
                                   lms_model.batch2feed_dict)
    eval_acc = partial(eval_fn_factory, sess, dev_batches[:20],
                       lms_model.loss_tensor, lms_model.per_layer_loss,
                       lms_model.ex_score_tensor,
                       lms_model.per_layer_logit_tensor, global_step,
                       lms_model.batch2feed_dict, test_writer)
    save_fn = partial(save_fn_factory, sess, save_dir, global_step)
    init_step, = sess.run([global_step])

    def train_fn(batch, step_i):
        loss_val, acc = train_classification(batch, step_i)
        summary = tf.Summary()
        summary.value.add(tag='loss', simple_value=loss_val)
        train_writer.add_summary(summary, fetch_global_step())
        train_writer.flush()
        return loss_val, acc

    def valid_fn():
        eval_acc()

    tf_logging.info("Initialize step to {}".format(init_step))
    tf_logging.info("{} train batches".format(len(train_batches)))
    valid_freq = 100
    save_interval = 300
    loss, _ = step_runner(train_batches, train_fn, init_step, valid_fn,
                          valid_freq, save_fn, save_interval, max_steps)
    return save_fn()
def eval_fidelity_gradient(hparam, nli_setting, flat_dev_batches,
                           explain_tag, method_name, model_path):
    from attribution.gradient import explain_by_gradient
    from attribution.deepexplain.tensorflow import DeepExplain

    sess = init_session()
    with DeepExplain(session=sess, graph=sess.graph) as de:
        task = transformer_nli_pooled_embedding_in(hparam,
                                                   nli_setting.vocab_size,
                                                   False)
        softmax_out = tf.nn.softmax(task.logits, axis=-1)
        sess.run(tf.global_variables_initializer())
        load_model(sess, model_path)
        emb_outputs = task.encoded_embedding_out, task.attention_mask_out
        emb_input = task.encoded_embedding_in, task.attention_mask_in

        def feed_end_input(batch):
            x0, x1, x2 = batch
            return {
                task.x_list[0]: x0,
                task.x_list[1]: x1,
                task.x_list[2]: x2,
            }

        def forward_runs(insts):
            alt_batches = get_batches_ex(insts, hparam.batch_size, 3)
            alt_logits = []
            for batch in alt_batches:
                # Two-stage run: first compute the embeddings and attention
                # mask from token ids, then feed them back through the
                # embedding-input placeholders to get predictions.
                enc, att = sess.run(emb_outputs,
                                    feed_dict=feed_end_input(batch))
                logits, = sess.run([softmax_out],
                                   feed_dict={
                                       task.encoded_embedding_in: enc,
                                       task.attention_mask_in: att
                                   })
                alt_logits.append(logits)
            alt_logits = np.concatenate(alt_logits)
            return alt_logits

        contrib_score = explain_by_gradient(flat_dev_batches, method_name,
                                            explain_tag, sess, de,
                                            feed_end_input, emb_outputs,
                                            emb_input, softmax_out)
        print("contrib_score", len(contrib_score))
        print("flat_dev_batches", len(flat_dev_batches))
        acc_list = eval_fidelity(contrib_score, flat_dev_batches,
                                 forward_runs, explain_tag)
        return acc_list
def __init__(self, model_path, num_classes):
    self.voca_size = 30522
    load_names = ['bert', "output_bias", "output_weights"]
    self.hp = hyperparams.HPFAD()
    self.model_dir = cpath.model_path
    self.task = transformer_logit(self.hp, num_classes, self.voca_size,
                                  False)
    self.sess = init_session()
    self.sess.run(tf.global_variables_initializer())
    self.load_model_white(model_path, load_names)
    self.batch_size = 64
def baseline_predict(hparam, nli_setting, data, method_name,
                     model_path) -> List[np.array]:
    tprint("building model")
    voca_size = 30522
    task = transformer_logit(hparam, 2, voca_size, False)
    enc_payload: List[Tuple[np.array, np.array, np.array]] = data
    sout = tf.nn.softmax(task.logits, axis=-1)
    sess = init_session()
    sess.run(tf.global_variables_initializer())
    tprint("loading model")
    load_model(sess, model_path)

    def forward_run(inputs):
        batches = get_batches_ex(inputs, hparam.batch_size, 3)
        logit_list = []
        ticker = TimeEstimator(len(batches))
        for batch in batches:
            x0, x1, x2 = batch
            soft_out, = sess.run([sout],
                                 feed_dict={
                                     task.x_list[0]: x0,
                                     task.x_list[1]: x1,
                                     task.x_list[2]: x2,
                                 })
            logit_list.append(soft_out)
            ticker.tick()
        return np.concatenate(logit_list)

    def idf_explain(enc_payload, explain_tag, forward_run):
        train_batches, dev_batches = get_nli_data(hparam, nli_setting)
        idf_scorer = IdfScorer(train_batches)
        return idf_scorer.explain(enc_payload, explain_tag, forward_run)

    # Dispatch table: method name -> explanation function.
    todo_list = [
        ('deletion_seq', explain_by_seq_deletion),
        ('replace_token', explain_by_replace),
        ('term_deletion', explain_by_term_deletion),
        ('term_replace', explain_by_term_replace),
        ('random', explain_by_random),
        ('idf', idf_explain),
        ('deletion', explain_by_deletion),
        ('LIME', explain_by_lime),
    ]
    method_dict = dict(todo_list)
    method = method_dict[method_name]
    explain_tag = "mismatch"
    explains: List[np.array] = method(enc_payload, explain_tag, forward_run)
    return explains
def nli_baseline_predict(hparam, nli_setting, data_loader, explain_tag,
                         method_name, data_id, sub_range, model_path):
    enc_payload, plain_payload = data_loader.load_plain_text(data_id)
    assert enc_payload is not None
    assert plain_payload is not None
    name_format = "pred_{}_" + data_id
    if sub_range is not None:
        st, ed = [int(t) for t in sub_range.split(",")]
        enc_payload = enc_payload[st:ed]
        plain_payload = plain_payload[st:ed]
        name_format = "pred_{}_" + data_id + "__{}_{}".format(st, ed)
    print(name_format)

    task = transformer_nli_pooled(hparam, nli_setting.vocab_size)
    sout = tf.nn.softmax(task.logits, axis=-1)
    sess = init_session()
    sess.run(tf.global_variables_initializer())
    load_model(sess, model_path)

    def forward_run(inputs):
        batches = get_batches_ex(inputs, hparam.batch_size, 3)
        logit_list = []
        for batch in batches:
            x0, x1, x2 = batch
            soft_out, = sess.run([sout],
                                 feed_dict={
                                     task.x_list[0]: x0,
                                     task.x_list[1]: x1,
                                     task.x_list[2]: x2,
                                 })
            logit_list.append(soft_out)
        return np.concatenate(logit_list)

    def idf_explain(enc_payload, explain_tag, forward_run):
        train_batches, dev_batches = get_nli_data(hparam, nli_setting)
        idf_scorer = IdfScorer(train_batches)
        return idf_scorer.explain(enc_payload, explain_tag, forward_run)

    # Dispatch table: method name -> explanation function.
    todo_list = [
        ('deletion_seq', explain_by_seq_deletion),
        ('random', explain_by_random),
        ('idf', idf_explain),
        ('deletion', explain_by_deletion),
        ('LIME', explain_by_lime),
    ]
    method_dict = dict(todo_list)
    method = method_dict[method_name]
    explains = method(enc_payload, explain_tag, forward_run)
    pred_list = predict_translate(explains, data_loader, enc_payload,
                                  plain_payload)
    save_to_pickle(pred_list, name_format.format(method_name))
def __init__(self, hparam, nli_setting, model_path, modeling_option,
             tags_list=nli_info.tags):
    self.num_tags = len(tags_list)
    self.tags = tags_list
    self.define_graph(hparam, nli_setting, modeling_option)
    self.sess = init_session()
    self.sess.run(tf.global_variables_initializer())
    self.batch_size = hparam.batch_size
    load_model(self.sess, model_path)
def __init__(self, hparam, nli_setting, model_path, method_name):
    self.task = transformer_nli_pooled(hparam, nli_setting.vocab_size, False)
    self.sout = tf.nn.softmax(self.task.logits, axis=-1)
    self.sess = init_session()
    self.sess.run(tf.global_variables_initializer())
    self.batch_size = hparam.batch_size
    load_model(self.sess, model_path)
    todo_list = [
        ('deletion_seq', explain_by_seq_deletion),
        ('random', explain_by_random),
        ('deletion', explain_by_deletion),
        ('LIME', explain_by_lime),
    ]
    method_dict = dict(todo_list)
    self.method = method_dict[method_name]
def train_nli(hparam, nli_setting, save_dir, max_steps, data, model_path,
              load_fn):
    print("Train nli :", save_dir)
    task = transformer_nli_pooled(hparam, nli_setting.vocab_size)
    with tf.variable_scope("optimizer"):
        train_cls = get_train_op2(task.loss, hparam.lr, "adam", max_steps)

    train_batches, dev_batches = data
    sess = init_session()
    sess.run(tf.global_variables_initializer())
    if model_path is not None:
        load_fn(sess, model_path)

    def batch2feed_dict(batch):
        x0, x1, x2, y = batch
        feed_dict = {
            task.x_list[0]: x0,
            task.x_list[1]: x1,
            task.x_list[2]: x2,
            task.y: y,
        }
        return feed_dict

    train_classification = partial(train_classification_factory, sess,
                                   task.loss, task.acc, train_cls,
                                   batch2feed_dict)
    global_step = tf.train.get_or_create_global_step()
    valid_fn = partial(valid_fn_factory, sess, dev_batches[:100], task.loss,
                       task.acc, global_step, batch2feed_dict)
    save_fn = partial(save_fn_factory, sess, save_dir, global_step)

    init_step, = sess.run([global_step])
    print("Initialize step to {}".format(init_step))
    print("{} train batches".format(len(train_batches)))
    valid_freq = 5000
    save_interval = 10000
    loss, _ = step_runner(train_batches, train_classification, init_step,
                          valid_fn, valid_freq, save_fn, save_interval,
                          max_steps)
    return save_fn()
def __init__(self, topic, cheat=False, cheat_topic=None):
    self.voca_size = 30522
    self.topic = topic
    load_names = ['bert', "cls_dense"]
    # "neccesary" is kept as-is: it must match the run names under which
    # the checkpoints were saved.
    if not cheat:
        run_name = "arg_key_neccesary_{}".format(topic)
    else:
        run_name = "arg_key_neccesary_{}".format(cheat_topic)
    self.hp = hyperparams.HPBert()
    self.model_dir = cpath.model_path
    self.data_loader = BertDataLoader(topic, True, self.hp.seq_max,
                                      "bert_voca.txt")
    self.task = transformer_nli(self.hp, self.voca_size, 0, False)
    self.sess = init_session()
    self.sess.run(tf.global_variables_initializer())
    self.merged = tf.summary.merge_all()
    self.load_model_white(run_name, load_names)
    self.batch_size = 512
def do_predict(
        bert_hp,
        train_config,
        batches,
        lms_config,
        modeling_option,
        init_fn,
):
    num_gpu = train_config.num_gpu
    lms_model = LMSModel(modeling_option, bert_hp, lms_config, num_gpu)
    sess = init_session()
    sess.run(tf.global_variables_initializer())
    init_fn(sess)

    output_d = predict_fn(sess, batches, lms_model.logits,
                          lms_model.loss_tensor, lms_model.ex_score_tensor,
                          lms_model.per_layer_logit_tensor,
                          lms_model.batch2feed_dict)
    return output_d
def get_predictor():
    dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")
    cnn = CNN("agree",
              sequence_length=FLAGS.comment_length,
              num_classes=3,
              filter_sizes=[1, 2, 3],
              num_filters=64,
              init_emb=load_local_pickle("init_embedding"),
              embedding_size=FLAGS.embedding_size,
              dropout_prob=dropout_keep_prob)
    input_comment = tf.placeholder(tf.int32,
                                   shape=[None, FLAGS.comment_length],
                                   name="comment_input")
    sout = cnn.network(input_comment)
    sess = init_session()
    batch_size = 512
    path = os.path.join(model_path, "runs", "agree", "model-36570")
    variables = tf.contrib.slim.get_variables_to_restore()
    for v in variables:
        print(v.name)
    loader = tf.train.Saver(variables)
    loader.restore(sess, path)

    def predict(comments):
        batches = get_batches_ex(comments, batch_size, 1)
        all_scores = []
        ticker = TimeEstimator(len(batches))
        for batch in batches:
            scores, = sess.run([sout],
                               feed_dict={
                                   input_comment: batch[0],
                                   dropout_keep_prob: 1.0,
                               })
            all_scores.append(scores)
            ticker.tick()
        return np.concatenate(all_scores)

    return predict
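# Hypothetical usage of the closure returned by get_predictor (illustration
# only; `encoded_comments` and its exact structure are assumptions -- it
# should match whatever get_batches_ex was fed at training time):
#
#   predict = get_predictor()
#   scores = predict(encoded_comments)     # np.ndarray of shape [N, 3]
#   labels = np.argmax(scores, axis=-1)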
def predict(hparam, run_name, dev_batches, model_path, load_fn):
    print("predict :", run_name)
    task = transformer_logit(hparam, 2, hparam.vocab_size, False)
    sess = init_session()
    sess.run(tf.global_variables_initializer())
    if model_path is not None:
        load_fn(sess, model_path)

    def batch2feed_dict(batch):
        x0, x1, x2, y = batch
        feed_dict = {
            task.x_list[0]: x0,
            task.x_list[1]: x1,
            task.x_list[2]: x2,
            task.y: y,
        }
        return feed_dict

    label_list = []
    pred_list = []
    for batch in dev_batches:
        logits, = sess.run([task.logits], feed_dict=batch2feed_dict(batch))
        _, _, _, y = batch
        label_list.append(y)
        pred_list.append(np.argmax(logits, axis=-1))

    pred_list = np.concatenate(pred_list, axis=0)
    label_list = np.concatenate(label_list, axis=0)
    if len(pred_list) != len(label_list):
        print("WARNING: prediction and label counts differ:",
              len(pred_list), len(label_list))
    output = pred_list, label_list
    out_path = os.path.join(output_path,
                            "msmarco_" + run_name.replace("/", "_"))
    pickle.dump(output, open(out_path, "wb"))
def run():
    all_loss = 0
    tower_grads = []
    input_x_list = []
    input_y_list = []
    models = []
    batch_size = 16  # assumed; the original did not define batch_size here
    for gpu_idx in range(2):
        with tf.device("/gpu:{}".format(gpu_idx)):
            with tf.variable_scope("vars", reuse=gpu_idx > 0):
                input_x = placeholder(tf.float32, [None, 10])
                input_y = placeholder(tf.int32, [None])
                input_x_list.append(input_x)
                input_y_list.append(input_y)
                model = FF(input_x, input_y)
                models.append(model)
                tf.get_variable_scope().reuse_variables()
                all_loss += model.task.loss

    tvars = tf.trainable_variables()
    for t in tvars:
        print(t.name)

    for gpu_idx in range(2):
        # Bug fix: differentiate each tower's own loss, not the loss of the
        # last-built model.
        grads = tf.gradients(models[gpu_idx].task.loss, tvars)
        print(grads)
        # Keep track of the gradients across all towers.
        tower_grads.append(grads)

    avg_grads = []
    for t_idx, _ in enumerate(tvars):
        # Bug fix: average the gradients of the *same* variable across
        # towers (the original indexed tower_grads[0][0] and
        # tower_grads[1][1]).
        g1 = tower_grads[0][t_idx]
        g2 = tower_grads[1][t_idx]
        g_avg = (g1 + g2) / 2 if g1 is not None else None
        avg_grads.append(g_avg)

    global_step = tf.Variable(0, name='global_step')
    optimizer = AdamWeightDecayOptimizer(
        learning_rate=0.001,
        weight_decay_rate=0.02,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
    # Bug fix: apply the averaged gradients, not the last tower's.
    train_cls = optimizer.apply_gradients(zip(avg_grads, tvars),
                                          global_step=global_step)

    sess = init_session()
    sess.run(tf.global_variables_initializer())

    def train_classification(i):
        if i % 2 == 0:
            random_input = np.ones([batch_size])
        else:
            random_input = np.zeros([batch_size])
        # NOTE: this overwrite makes the branch above dead; it is kept for
        # fidelity to the original code.
        random_input = np.ones([batch_size])
        # `model` is the last-built tower, so only its loss is printed.
        loss_val, _ = sess.run(
            [model.task.loss, train_cls],
            feed_dict={
                input_x_list[0]: np.ones([batch_size, 10]),
                input_x_list[1]: np.ones([batch_size, 10]),
                input_y_list[0]: np.zeros([batch_size]),
                input_y_list[1]: random_input,
            })
        print(loss_val)

    for i in range(10):
        print("Train")
        train_classification(i)
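# A minimal, self-contained sketch (plain Python, invented values; not part
# of the original code) of the per-variable tower-gradient averaging the
# fixed loop above implements: tower_grads[tower][var] must be indexed by
# the same variable index in every tower.
def _avg_tower_grads_sketch():
    tower_grads = [[1.0, 2.0, None],    # tower 0: grads for 3 variables
                   [3.0, 4.0, None]]    # tower 1: grads for 3 variables
    return [None if gs[0] is None else sum(gs) / len(gs)
            for gs in zip(*tower_grads)]  # -> [2.0, 3.0, None]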
def train_nli_ex_from_payload(
        hparam,
        train_config,
        save_dir,
        data,
        data_loader,
        tags,
        modeling_option,
        init_fn,
):
    print("train_nli_ex")
    max_steps = train_config.max_steps
    num_gpu = train_config.num_gpu
    base_step = 61358

    def load_payload(step_i):
        save_path = os.path.join(cpath.data_path, "nli_payload_1",
                                 str(step_i))
        return pickle.load(open(save_path, "rb"))

    ex_modeling_class = {
        'ce': CrossEntropyModeling,
        'co': CorrelationModeling
    }[modeling_option]
    lr_factor = 0.3

    def build_model():
        main_model = transformer_nli_pooled(hparam, train_config.vocab_size)
        ex_model = ex_modeling_class(main_model.model.sequence_output,
                                     hparam.seq_max, len(tags),
                                     main_model.batch2feed_dict)
        return main_model, ex_model

    if num_gpu == 1:
        print("Using single GPU")
        main_model, ex_model = build_model()
        batch2feed_dict = main_model.batch2feed_dict
        ex_score_tensor = ex_model.get_scores()
        ex_loss_tensor = ex_model.get_loss()
        ex_per_tag_loss = ex_model.get_losses()
        ex_batch_feed2dict = ex_model.batch2feed_dict
        with tf.variable_scope("explain_optimizer"):
            train_ex_op = get_train_op2(ex_loss_tensor,
                                        hparam.lr * lr_factor, "adam2",
                                        max_steps)
    else:
        main_models, ex_models = zip(
            *get_multiple_models(build_model, num_gpu))
        batch2feed_dict = get_batch2feed_dict_for_multi_gpu(main_models)

        def get_loss_tensor(model):
            t = tf.expand_dims(tf.stack(model.get_losses()), 0)
            return t

        ex_per_tag_loss = tf.reduce_mean(get_concat_tensors_from_models(
            ex_models, get_loss_tensor), axis=0)
        ex_score_tensor = get_concat_tensors_list_from_models(
            ex_models, lambda model: model.get_scores())
        print(ex_score_tensor)
        ex_loss_tensor = get_avg_tensors_from_models(
            ex_models, ex_modeling_class.get_loss)
        ex_batch_feed2dict = get_batch2feed_dict_for_multi_gpu(ex_models)
        with tf.variable_scope("explain_optimizer"):
            train_ex_op = get_train_op([m.get_loss() for m in ex_models],
                                       hparam.lr * lr_factor, max_steps)

    global_step = tf.train.get_or_create_global_step()
    explain_dev_data_list = {
        tag: data_loader.get_dev_explain(tag)
        for tag in tags
    }
    run_name = os.path.basename(save_dir)
    train_writer, test_writer = setup_summary_writer(run_name)

    sess = init_session()
    sess.run(tf.global_variables_initializer())
    init_fn(sess)

    def eval_tag():
        print("Eval")
        for label_idx, tag in enumerate(tags):
            print(tag)
            enc_explain_dev, explain_dev = explain_dev_data_list[tag]
            batches = get_batches_ex(enc_explain_dev, hparam.batch_size, 3)
            try:
                ex_logit_list = []
                for batch in batches:
                    ex_logits, = sess.run([ex_score_tensor[label_idx]],
                                          feed_dict=batch2feed_dict(batch))
                    print(ex_logits.shape)
                    ex_logit_list.append(ex_logits)
                ex_logit_list = np.concatenate(ex_logit_list, axis=0)
                print(ex_logit_list.shape)
                assert len(ex_logit_list) == len(explain_dev)

                scores = eval_explain(ex_logit_list, data_loader, tag)
                for metric in scores.keys():
                    print("{}\t{}".format(metric, scores[metric]))
                p_at_1, MAP_score = scores["P@1"], scores["MAP"]
                summary = tf.Summary()
                summary.value.add(tag='{}_P@1'.format(tag),
                                  simple_value=p_at_1)
                summary.value.add(tag='{}_MAP'.format(tag),
                                  simple_value=MAP_score)
                train_writer.add_summary(summary, fetch_global_step())
                train_writer.flush()
            except ValueError as e:
                print(e)
                for ex_logits in ex_logit_list:
                    print(ex_logits.shape)

    # Note: unlike train_self_explain, the explain train_op here is built
    # with get_train_op2, so each explain step advances the global step
    # (train_fn below asserts exactly that).
    def train_explain(step_i):
        def commit_ex_train(batch):
            fd = ex_batch_feed2dict(batch)
            ex_loss, _ = sess.run([ex_per_tag_loss, train_ex_op],
                                  feed_dict=fd)
            return ex_loss

        batch = load_payload(step_i)
        commit_ex_train(batch)

    def fetch_global_step():
        step, = sess.run([global_step])
        return step

    save_fn = partial(save_fn_factory, sess, save_dir, global_step)
    sess.run([global_step.assign(base_step)])
    g_step_check, = sess.run([global_step])
    print("Initialize step to {}".format(g_step_check))

    def train_fn(step_i):
        step_before_cls = fetch_global_step()
        train_explain(step_i)
        step_after_ex = fetch_global_step()
        assert step_after_ex == step_before_cls + 1
        return 0, 0

    def valid_fn():
        eval_tag()

    valid_freq = 1000
    save_interval = 5000
    step_i = base_step
    while step_i < max_steps:
        print(step_i)
        step_i += 1
        train_fn(step_i)
        if valid_fn is not None:
            if step_i % valid_freq == 0:
                valid_fn()

    return save_fn()
def train_self_explain(
        hparam,
        train_config,
        save_dir,
        data,
        data_loader,
        tags,
        modeling_option,
        init_fn,
        tag_informative_fn,
):
    print("train_self_explain")
    max_steps = train_config.max_steps
    num_gpu = train_config.num_gpu
    train_batches, dev_batches = data

    def save_payload(payload, step_i):
        save_path = os.path.join(cpath.data_path, "nli_payload_1",
                                 str(step_i))
        pickle.dump(payload, open(save_path, "wb"))

    ex_modeling_class = {
        'ce': CrossEntropyModeling,
        'co': CorrelationModeling
    }[modeling_option]
    lr_factor = 0.3

    def build_model():
        main_model = transformer_pooled(hparam, train_config.vocab_size)
        ex_model = ex_modeling_class(main_model.model.sequence_output,
                                     hparam.seq_max, len(tags),
                                     main_model.batch2feed_dict)
        return main_model, ex_model

    if num_gpu == 1:
        print("Using single GPU")
        main_model, ex_model = build_model()
        loss_tensor = main_model.loss
        acc_tensor = main_model.acc
        with tf.variable_scope("optimizer"):
            train_cls = get_train_op2(main_model.loss, hparam.lr, "adam",
                                      max_steps)
        batch2feed_dict = main_model.batch2feed_dict
        logits = main_model.logits
        ex_score_tensor = ex_model.get_scores()
        ex_loss_tensor = ex_model.get_loss()
        ex_per_tag_loss = ex_model.get_losses()
        ex_batch_feed2dict = ex_model.batch2feed_dict
        with tf.variable_scope("explain_optimizer"):
            # This op must not advance the global step; only the
            # classification op does.
            train_ex_op = get_train_op_wo_gstep_update2(
                ex_loss_tensor, hparam.lr * lr_factor, "adam2", max_steps)
    else:
        main_models, ex_models = zip(
            *get_multiple_models(build_model, num_gpu))
        loss_tensor = get_avg_loss(main_models)
        acc_tensor = get_avg_tensors_from_models(main_models,
                                                 lambda model: model.acc)
        with tf.variable_scope("optimizer"):
            train_cls = get_train_op([m.loss for m in main_models],
                                     hparam.lr, max_steps)
        batch2feed_dict = get_batch2feed_dict_for_multi_gpu(main_models)
        logits = get_concat_tensors_from_models(main_models,
                                                lambda model: model.logits)

        def get_loss_tensor(model):
            t = tf.expand_dims(tf.stack(model.get_losses()), 0)
            return t

        ex_per_tag_loss = tf.reduce_mean(get_concat_tensors_from_models(
            ex_models, get_loss_tensor), axis=0)
        ex_score_tensor = get_concat_tensors_list_from_models(
            ex_models, lambda model: model.get_scores())
        print(ex_score_tensor)
        ex_loss_tensor = get_avg_tensors_from_models(
            ex_models, ex_modeling_class.get_loss)
        ex_batch_feed2dict = get_batch2feed_dict_for_multi_gpu(ex_models)
        with tf.variable_scope("explain_optimizer"):
            train_ex_op = get_other_train_op_multi_gpu(
                [m.get_loss() for m in ex_models], hparam.lr * lr_factor,
                max_steps)

    global_step = tf.train.get_or_create_global_step()
    if data_loader is not None:
        explain_dev_data_list = {
            tag: data_loader.get_dev_explain(tag)
            for tag in tags
        }
    run_name = os.path.basename(save_dir)
    train_writer, test_writer = setup_summary_writer(run_name)
    information_fn_list = list(
        [partial(tag_informative_fn, t) for t in tags])

    def forward_run(batch):
        result, = sess.run([logits], feed_dict=batch2feed_dict(batch))
        return result

    explain_trainer = ExplainTrainerM(
        information_fn_list,
        len(tags),
        action_to_label=ex_modeling_class.action_to_label,
        get_null_label=ex_modeling_class.get_null_label,
        forward_run=forward_run,
        batch_size=hparam.batch_size,
        num_deletion=train_config.num_deletion,
        g_val=train_config.g_val,
        drop_thres=train_config.drop_thres)

    sess = init_session()
    sess.run(tf.global_variables_initializer())
    init_fn(sess)

    def eval_tag():
        if data_loader is None:
            return
        print("Eval")
        for label_idx, tag in enumerate(tags):
            print(tag)
            enc_explain_dev, explain_dev = explain_dev_data_list[tag]
            batches = get_batches_ex(enc_explain_dev, hparam.batch_size, 3)
            try:
                ex_logit_list = []
                for batch in batches:
                    ex_logits, = sess.run([ex_score_tensor[label_idx]],
                                          feed_dict=batch2feed_dict(batch))
                    print(ex_logits.shape)
                    ex_logit_list.append(ex_logits)
                ex_logit_list = np.concatenate(ex_logit_list, axis=0)
                print(ex_logit_list.shape)
                assert len(ex_logit_list) == len(explain_dev)

                scores = eval_explain(ex_logit_list, data_loader, tag)
                for metric in scores.keys():
                    print("{}\t{}".format(metric, scores[metric]))
                p_at_1, MAP_score = scores["P@1"], scores["MAP"]
                summary = tf.Summary()
                summary.value.add(tag='{}_P@1'.format(tag),
                                  simple_value=p_at_1)
                summary.value.add(tag='{}_MAP'.format(tag),
                                  simple_value=MAP_score)
                train_writer.add_summary(summary, fetch_global_step())
                train_writer.flush()
            except ValueError as e:
                print(e)
                for ex_logits in ex_logit_list:
                    print(ex_logits.shape)

    # The explain train_op must not advance the global step; train_fn
    # below asserts that only the classification step does.
    def train_explain(batch, step_i):
        def commit_ex_train(batch):
            if train_config.save_train_payload:
                save_payload(batch, step_i)
            fd = ex_batch_feed2dict(batch)
            ex_loss, _ = sess.run([ex_per_tag_loss, train_ex_op],
                                  feed_dict=fd)
            return ex_loss

        summary = explain_trainer.train_batch(batch, commit_ex_train)
        train_writer.add_summary(summary, fetch_global_step())

    def fetch_global_step():
        step, = sess.run([global_step])
        return step

    train_classification = partial(train_classification_factory, sess,
                                   loss_tensor, acc_tensor, train_cls,
                                   batch2feed_dict)
    eval_acc = partial(valid_fn_factory, sess, dev_batches[:20], loss_tensor,
                       acc_tensor, global_step, batch2feed_dict)
    save_fn = partial(save_fn_factory, sess, save_dir, global_step)
    init_step, = sess.run([global_step])

    def train_fn(batch, step_i):
        step_before_cls = fetch_global_step()
        loss_val, acc = train_classification(batch, step_i)
        summary = tf.Summary()
        summary.value.add(tag='acc', simple_value=acc)
        summary.value.add(tag='loss', simple_value=loss_val)
        train_writer.add_summary(summary, fetch_global_step())
        train_writer.flush()
        tf_logging.debug("{}".format(step_i))
        step_after_cls = fetch_global_step()
        assert step_after_cls == step_before_cls + 1

        train_explain(batch, step_i)
        step_after_ex = fetch_global_step()
        assert step_after_cls == step_after_ex
        return loss_val, acc

    def valid_fn():
        eval_acc()
        eval_tag()

    print("Initialize step to {}".format(init_step))
    print("{} train batches".format(len(train_batches)))
    valid_freq = 1000
    save_interval = 5000
    loss, _ = step_runner(train_batches, train_fn, init_step, valid_fn,
                          valid_freq, save_fn, save_interval, max_steps)
    return save_fn()