def run_batch(self, sess, inps, istrn=True):
    """Run one step of the semi-supervised model.

    Args:
        sess: live tf.Session.
        inps: tensors matching the placeholder list below. In training mode
            this is labeled + unlabeled data; in eval mode only the labeled
            tuple is expected and is reused for the unlabeled placeholders.
        istrn: True for a training step (runs the optimizer), False for eval.

    Returns:
        (res_dict, res_str, summary): scalar metrics keyed by name, their
        formatted string, and the merged summary protobuf.
    """
    plhs = [self.inp_l_plh, self.tgt_l_plh, self.msk_l_plh, self.y_plh,
            self.inp_u_plh, self.tgt_u_plh, self.msk_u_plh]
    if istrn:
        fetch_dict = [['elbo_l', self.elbo_loss_l],
                      ['ppl_l', self.ppl_l],
                      ['kl_l', self.kl_loss_l],
                      ['pred_l', self.predict_loss_l],
                      ['loss_u', self.loss_u],
                      ['train_u', self.train_unlabel],
                      ['klw', self.kl_w]]
        feed_dict = dict(list(zip(plhs, inps)) + [[self.is_training, True]])
        extra_fetches = [self.train_op]
    else:
        fetch_dict = [['pred_l', self.predict_loss_l]]
        # Eval batches carry only the labeled tuple; `inps + inps[:-1]`
        # reuses it (minus the label tensor) to fill the unlabeled
        # placeholders so the graph can still be executed.
        feed_dict = dict(
            list(zip(plhs, inps + inps[:-1])) + [[self.is_training, False]])
        extra_fetches = []
    # Common tail for both modes: merged summary first, then the named
    # scalars, then (train only) the optimizer op.
    fetch = [self.merged] + [t[1] for t in fetch_dict] + extra_fetches
    res = sess.run(fetch, feed_dict)
    res_dict = {name: res[i + 1] for i, (name, _) in enumerate(fetch_dict)}
    res_str = res_to_string(res_dict)
    return res_dict, res_str, res[0]
def validate(valid_dset, model, sess):
    """Evaluate `model` on every batch of `valid_dset` and return the
    averaged metrics rendered as a display string."""
    collected = []
    for inps in valid_dset:
        metrics, _unused_str, _unused_summary = model.run_batch(
            sess, inps, istrn=False)
        collected.append(metrics)
    return res_to_string(average_res(collected))
def run_batch(self, sess, inps, istrn=True):
    """Run one step of the gated classifier.

    Args:
        sess: live tf.Session.
        inps: (input, mask, label) tensors matching `plhs`.
        istrn: True to run the optimizer, False for pure evaluation.

    Returns:
        (res_dict, res_str, summary, [gate_weights, pred]): scalar metrics,
        their formatted string, the merged summary protobuf, and the
        non-scalar gate weights / predictions.
    """
    plhs = [self.input_plh, self.mask_plh, self.label_plh]
    if istrn:
        fetch_dict = [
            ['loss_total', self.loss_total],
            ['loss', self.loss],
            ['reg_l1', self.loss_reg_l1],
            ['reg_diff', self.loss_reg_diff],
            ['reg_sharp', self.loss_reg_sharp],
            ['reg_frobenius', self.loss_reg_frobenius],
            ['acc', self.accuracy],
        ]
        extra_fetches = [self.train_op]
    else:
        # NOTE(review): eval uses 'loss_reg_*' key names while train uses
        # 'reg_*', and omits the frobenius term. Kept as-is so logged metric
        # names seen by downstream averaging/plotting do not change.
        fetch_dict = [
            ['loss_total', self.loss_total],
            ['loss', self.loss],
            ['loss_reg_l1', self.loss_reg_l1],
            ['loss_reg_diff', self.loss_reg_diff],
            ['loss_reg_sharp', self.loss_reg_sharp],
            ['acc', self.accuracy],
        ]
        extra_fetches = []
    feed_dict = dict(list(zip(plhs, inps)) + [[self.is_training_plh, istrn]])
    # Non-scalar fetches come first so the scalar metrics can be located by
    # a fixed offset below.
    fetch_nonscalar = [self.merged, self.gate_weights, self.pred]
    fetch = fetch_nonscalar + [t[1] for t in fetch_dict] + extra_fetches
    res = sess.run(fetch, feed_dict)
    offset = len(fetch_nonscalar)
    res_dict = {name: res[i + offset]
                for i, (name, _) in enumerate(fetch_dict)}
    res_str = res_to_string(res_dict)
    # res[0] is the summary; res[1:offset] is [gate_weights, pred].
    return res_dict, res_str, res[0], res[1:offset]
def run(args, model, sess, label_dset, unlabel_dset, valid_dset, test_dset,
        explogger):
    """Semi-supervised training loop: each unlabeled batch is paired with a
    labeled batch, periodically logging metrics, validating, testing, and
    checkpointing according to the intervals in `args`."""
    step = 0
    recent = []
    # TensorBoard writer for training summaries (also dumps the graph).
    writer = tf.summary.FileWriter(args.save_dir + '/train', sess.graph)
    tick = time.time()
    total_time = 0
    for _epoch in range(args.max_epoch):
        for batch_u in unlabel_dset:
            step += 1
            batch_l = get_batch(label_dset)
            gen_time = time.time() - tick
            tick = time.time()
            metrics, _metrics_str, summary = model.run_batch(
                sess, batch_l + batch_u, istrn=True, args=args)
            run_time = time.time() - tick
            metrics['run_time'] = run_time
            metrics['gen_time'] = gen_time
            # Keep a sliding window of the last 200 batches for averaging.
            recent = (recent + [metrics])[-200:]
            total_time += gen_time + run_time
            if step % args.show_every == 0:
                writer.add_summary(summary, step)
                explogger.message(res_to_string(average_res(recent)), True)
            if args.validate_every != -1 and step % args.validate_every == 0:
                explogger.message(
                    'VALIDATE:' + validate(valid_dset, model, sess, args),
                    True)
                explogger.message(
                    'TEST:' + validate(test_dset, model, sess, args), True)
            if args.save_every != -1 and step % args.save_every == 0:
                ckpt_fn = os.path.join(args.save_dir, args.log_prefix)
                explogger.message("Saving checkpoint model: {} ......".format(
                    args.save_dir))
                model.saver.save(sess, ckpt_fn, write_meta_graph=False,
                                 global_step=step)
            tick = time.time()
def train_and_validate(args, model, sess, train_dset, valid_dset, test_dset,
                       explogger, vocab, class_map):
    """Supervised training loop over `train_dset`, with periodic metric
    logging, validation/test evaluation, and checkpoint saving driven by the
    intervals configured in `args`."""
    step = 0
    recent = []
    # TensorBoard writer for training summaries (also dumps the graph).
    writer = tf.summary.FileWriter(args.save_dir + '/train', sess.graph)
    tick = time.time()
    total_time = 0
    for _epoch in range(args.max_epoch):
        batches = threaded_generator(train_dset, 200)
        for inps in batches:
            step += 1
            gen_time = time.time() - tick
            tick = time.time()
            metrics, _metrics_str, summary, _gates = model.run_batch(
                sess, inps, istrn=True)
            run_time = time.time() - tick
            metrics['run_time'] = run_time
            metrics['gen_time'] = gen_time
            # Keep a sliding window of the last 200 batches for averaging.
            recent = (recent + [metrics])[-200:]
            total_time += gen_time + run_time
            if step % args.show_every == 0:
                writer.add_summary(summary, step)
                explogger.message(res_to_string(average_res(recent)), True)
            if args.validate_every != -1 and step % args.validate_every == 0:
                explogger.message(
                    'VALIDATE: ' + validate(valid_dset, model, sess, args,
                                            vocab, class_map), True)
                explogger.message(
                    'TEST: ' + validate(test_dset, model, sess, args,
                                        vocab, class_map), True)
            if args.save_every != -1 and step % args.save_every == 0:
                ckpt_fn = os.path.join(args.save_dir, args.log_prefix)
                explogger.message("Saving checkpoint model: {} ......".format(
                    args.save_dir))
                model.saver.save(sess, ckpt_fn, write_meta_graph=False,
                                 global_step=step)
            tick = time.time()
def validate(valid_dset, model, sess, args, vocab, class_map):
    """Evaluate the model on `valid_dset`, dumping per-token gate weights
    (gate.log), per-sentence predictions (predict.log) and a confusion
    matrix (confusion_matrix.log) into args.save_dir.

    Args:
        valid_dset: iterable of batches (input_ids, mask, labels).
        model: model exposing run_batch(sess, batch, istrn=False).
        sess: live tf.Session.
        args: namespace providing save_dir.
        vocab: id -> token mapping used for readable dumps.
        class_map: class id -> class name mapping.

    Returns:
        Averaged metrics rendered as a display string.
    """
    res_list = []
    threaded_it = threaded_generator(valid_dset, 200)
    list_predict = []
    list_target = []
    gate_fn = os.path.join(args.save_dir, 'gate.log')
    pred_fn = os.path.join(args.save_dir, 'predict.log')
    # BUG FIX: use context managers so the log files are closed even when a
    # batch raises (previously wf/rf were only closed on the happy path).
    with open(gate_fn, 'w') as wf, open(pred_fn, 'w') as rf:
        rf.write('pred\ttarget\ttxt\n')
        for batch in threaded_it:
            res_dict, res_str, summary, [gate_weights, pred] = \
                model.run_batch(sess, batch, istrn=False)
            res_list.append(res_dict)
            batch_size = batch[0].shape[0]
            # One "token<TAB>gate" line per token, blank line per sentence.
            for sidx in range(batch_size):
                for idx, w in enumerate(batch[0][sidx]):
                    wf.write(vocab[w] + '\t')
                    wf.write('{0:.3f}\t'.format(gate_weights[sidx][idx][0])
                             + '\n')
                wf.write('\n')
            pred = [class_map[idx] for idx in pred]
            tgt = [class_map[idx] for idx in batch[2]]
            list_predict.extend(pred)
            list_target.extend(tgt)
            # Drop padding (id 0) before rendering the input text.
            inp = [[vocab[idx] for idx in s if idx != 0] for s in batch[0]]
            for i in range(batch_size):
                rf.write(str(pred[i]) + '\t')
                rf.write(str(tgt[i]) + '\t')
                rf.write(''.join(inp[i]) + '\n')
    labels = sorted(list(class_map.values()))
    # `labels` passed by keyword: modern scikit-learn removed the positional
    # form — NOTE(review): confirm this is sklearn.metrics.confusion_matrix.
    confusion = confusion_matrix(list_target, list_predict, labels=labels)
    outstr = utils.confusion_matrix_to_string(confusion, labels)
    # BUG FIX: this file was previously opened but never closed.
    with open(os.path.join(args.save_dir, 'confusion_matrix.log'), 'w') as cf:
        cf.write(outstr)
    return res_to_string(average_res(res_list))
def run(args, model, sess, train_dset, valid_dset, test_dset, explogger):
    """Debug/inspection entry point: fetch the model's gate weights for the
    first training batch, print them token-by-token, and return.

    BUG FIX: the original body continued with a full training loop after an
    unconditional `return` inside the first loop iteration — all of it
    unreachable, and it referenced an undefined `res_dict`. The dead code
    has been removed; observable behavior is unchanged.
    """
    # Load the token->id vocab and invert it to id->token for printing.
    with open(args.vocab_path, 'rb') as f:
        vocab = pkl.load(f)
    vocab = {int(vocab[k]): k for k in vocab}
    # Kept for behavior parity: constructing the FileWriter writes the graph
    # event file to disk as a side effect.
    tb_writer = tf.summary.FileWriter(args.save_dir + '/train', sess.graph)
    for batch in train_dset:
        plhs = [model.input_plh, model.mask_plh, model.label_plh,
                model.is_training]
        gates = sess.run(model.weights, dict(zip(plhs, list(batch) + [True])))
        print(gates)
        print(gates.shape)
        # Print "token  gate" pairs for the first 20 sentences.
        # NOTE(review): assumes batch size >= 20 — confirm against the loader.
        for sidx in range(20):
            for idx, w in enumerate(batch[0][sidx]):
                print(vocab[w] + '\t', end='')
                print('{0:.3f}\t'.format(gates[sidx][idx][0]))
            print()
        return