def __init__(self, checkpoint_path, num_classes, hyper_params):
    # Get ops from graph
    with tf.device("/gpu:0"):
        # Placeholders
        pl_sparse_points_centered_batched, _, _ = model.get_placeholders(
            hyper_params["num_point"], hyperparams=hyper_params)
        pl_is_training = tf.placeholder(tf.bool, shape=())

        # Prediction
        pred, _ = model.get_model(
            pl_sparse_points_centered_batched,
            pl_is_training,
            num_classes,
            hyperparams=hyper_params,
        )
        sparse_labels_batched = tf.argmax(pred, axis=2)
        # (1, num_sparse_points) -> (num_sparse_points,)
        sparse_labels = tf.reshape(sparse_labels_batched, [-1])
        sparse_labels = tf.cast(sparse_labels, tf.int32)

        # Saver
        saver = tf.train.Saver()

        # Graph for interpolating labels
        # Assuming batch_size == 1 for simplicity
        pl_sparse_points_batched = tf.placeholder(tf.float32, (None, None, 3))
        sparse_points = tf.reshape(pl_sparse_points_batched, [-1, 3])
        pl_dense_points = tf.placeholder(tf.float32, (None, 3))
        pl_knn = tf.placeholder(tf.int32, ())
        dense_labels, dense_colors = interpolate_label_with_color(
            sparse_points, sparse_labels, pl_dense_points, pl_knn)

    self.ops = {
        "pl_sparse_points_centered_batched": pl_sparse_points_centered_batched,
        "pl_sparse_points_batched": pl_sparse_points_batched,
        "pl_dense_points": pl_dense_points,
        "pl_is_training": pl_is_training,
        "pl_knn": pl_knn,
        "dense_labels": dense_labels,
        "dense_colors": dense_colors,
    }

    # Restore checkpoint to session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False
    self.sess = tf.Session(config=config)
    saver.restore(self.sess, checkpoint_path)
    print("Model restored")
def __init__(self, checkpoint_path, num_classes, hyper_params):
    # Get ops from graph
    with tf.device("/gpu:0"):
        # Placeholders
        pl_points, _, _ = model.get_placeholders(hyper_params["num_point"],
                                                 hyperparams=hyper_params)
        pl_is_training = tf.placeholder(tf.bool, shape=())
        # Use the static shape here; tf.shape() would only print a symbolic tensor.
        print("pl_points shape", pl_points.get_shape())

        # Prediction
        pred, _ = model.get_model(pl_points, pl_is_training, num_classes,
                                  hyperparams=hyper_params)

        # Saver
        saver = tf.train.Saver()

        # Graph for interpolating labels
        # Assuming batch_size == 1 for simplicity
        pl_sparse_points = tf.placeholder(tf.float32, (None, 3))
        pl_sparse_labels = tf.placeholder(tf.int32, (None,))
        pl_dense_points = tf.placeholder(tf.float32, (None, 3))
        pl_knn = tf.placeholder(tf.int32, ())
        sparse_indices = interpolate_label(pl_sparse_points, pl_sparse_labels,
                                           pl_dense_points, pl_knn)

    self.ops = {
        "pl_points": pl_points,
        "pl_is_training": pl_is_training,
        "pred": pred,
        "pl_sparse_points": pl_sparse_points,
        "pl_sparse_labels": pl_sparse_labels,
        "pl_dense_points": pl_dense_points,
        "pl_knn": pl_knn,
        "sparse_indices": sparse_indices,
    }

    # Restore checkpoint to session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False
    self.sess = tf.Session(config=config)
    saver.restore(self.sess, checkpoint_path)
    print("Model restored")
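# --- Usage sketch (not part of the original source) ---
# The two __init__ methods above presumably live in predictor classes that expose
# self.sess and self.ops. Assuming the second variant is wrapped in a class named
# Predictor (a hypothetical name), a prediction-plus-interpolation call could look
# roughly like the following; shapes and the knn value are illustrative only.
import numpy as np

def demo_predict_and_interpolate(predictor, hyper_params):
    # Fake point clouds, only to show the expected shapes and feed_dict wiring.
    num_point = hyper_params["num_point"]
    sparse_points = np.random.rand(1, num_point, 3).astype(np.float32)
    dense_points = np.random.rand(100000, 3).astype(np.float32)

    # 1) Run the network on the batched sparse points and take per-point argmax.
    pred_val = predictor.sess.run(
        predictor.ops["pred"],
        feed_dict={
            predictor.ops["pl_points"]: sparse_points,
            predictor.ops["pl_is_training"]: False,
        })
    sparse_labels = np.argmax(pred_val, axis=2).reshape(-1).astype(np.int32)

    # 2) Propagate the sparse labels to the dense cloud via k-NN interpolation.
    dense_result = predictor.sess.run(
        predictor.ops["sparse_indices"],
        feed_dict={
            predictor.ops["pl_sparse_points"]: sparse_points.reshape(-1, 3),
            predictor.ops["pl_sparse_labels"]: sparse_labels,
            predictor.ops["pl_dense_points"]: dense_points,
            predictor.ops["pl_knn"]: 3,
        })
    return dense_result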
def test(FLAGS):
    # read data
    dataset = DataSet(fpath=FLAGS.test_file,
                      seqlen=FLAGS.seq_len,
                      n_classes=FLAGS.num_classes,
                      num_feature=FLAGS.num_feature,
                      is_raw=FLAGS.is_raw,
                      need_shuffle=False)
    # set character set size
    FLAGS.charset_size = dataset.charset_size

    with tf.Graph().as_default():
        # placeholder
        placeholders = get_placeholders(FLAGS)

        # get inference
        pred, layers = inference(placeholders['data'], FLAGS,
                                 for_training=False)
        prob = tf.nn.softmax(pred)

        # calculate prediction
        _hit_op = tf.equal(tf.argmax(pred, 1),
                           tf.argmax(placeholders['labels'], 1))
        hit_op = tf.reduce_sum(tf.cast(_hit_op, tf.float32))

        # create saver
        saver = tf.train.Saver()

        # summary
        summary_op = tf.summary.merge_all()

        with tf.Session() as sess:
            # load model
            ckpt = tf.train.latest_checkpoint(
                os.path.dirname(FLAGS.checkpoint_path))
            if tf.train.checkpoint_exists(ckpt):
                saver.restore(sess, ckpt)
                global_step = ckpt.split('/')[-1].split('-')[-1]
                logging(
                    'Successfully loaded model from %s at step=%s.' %
                    (ckpt, global_step), FLAGS)
            else:
                logging("[ERROR] Checkpoint does not exist", FLAGS)
                return

            # iterate over batches
            hit_count = 0.0
            total_count = 0
            pred_list = []
            label_list = []

            logging("%s: starting test." % (datetime.now()), FLAGS)
            start_time = time.time()
            total_batch_size = math.ceil(dataset._num_data / FLAGS.batch_size)

            for step, (data, labels) in enumerate(
                    dataset.iter_once(FLAGS.batch_size)):
                hits, pred_val = sess.run([hit_op, prob],
                                          feed_dict={
                                              placeholders['data']: data,
                                              placeholders['labels']: labels
                                          })

                hit_count += np.sum(hits)
                total_count += len(data)

                for i, p in enumerate(pred_val):
                    pred_list.append(p[0])
                    label_list.append(labels[i][0])

                if step % FLAGS.log_interval == 0:
                    duration = time.time() - start_time
                    sec_per_batch = duration / FLAGS.log_interval
                    examples_per_sec = FLAGS.batch_size / sec_per_batch
                    logging(
                        '%s: [%d batches out of %d] (%.1f examples/sec; '
                        '%.3f sec/batch)' %
                        (datetime.now(), step, total_batch_size,
                         examples_per_sec, sec_per_batch), FLAGS)
                    start_time = time.time()

            # micro precision and AUC
            auc_val = roc_auc_score(label_list, pred_list)
            logging(
                "%s: micro-precision = %.5f, auc = %.5f" %
                (datetime.now(), (hit_count / total_count), auc_val), FLAGS)

            # threshold the probabilities and compute confusion-matrix metrics;
            # with labels=[1, 0], ravel() yields (TP, FN, FP, TN) for positive class 1
            pred_y = [1 if i > FLAGS.threshold else 0 for i in pred_list]
            TP, FN, FP, TN = confusion_matrix(label_list, pred_y,
                                              labels=[1, 0]).ravel()
            Sensitivity = round(TP / (TP + FN), 4) if TP + FN > 0 else 0
            Specificity = round(TN / (FP + TN), 4) if FP + TN > 0 else 0
            Precision = round(TP / (TP + FP), 4) if TP + FP > 0 else 0
            Accuracy = round((TP + TN) / (TP + FP + TN + FN), 4)
            MCC = round(
                ((TP * TN) - (FP * FN)) /
                math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN)),
                4) if (TP + FP > 0 and FP + TN > 0 and TP + FN > 0
                       and TN + FN > 0) else 0
            F1 = round((2 * TP) / ((2 * TP) + FP + FN), 4)
            Prec = round(hit_count / total_count, 4)
            AUC = round(auc_val, 4)

            with open(FLAGS.out_file, 'a') as fout:
                fout.write(
                    f"{global_step},{datetime.now()},{TP},{FP},{TN},{FN},"
                    f"{Sensitivity},{Specificity},{Precision},{Accuracy},"
                    f"{MCC},{F1},{AUC}\n")
            logging(
                f"TP={TP}, FP={FP}, TN={TN}, FN={FN}, Sens={Sensitivity}, "
                f"Spec={Specificity}, Prec={Precision}, Acc={Accuracy}, "
                f"MCC={MCC}, F1={F1}, AUC={auc_val}", FLAGS)
def train():
    """Train the model on a single GPU."""
    with tf.Graph().as_default():
        stacker, stack_validation, stack_train = init_stacking()

        with tf.device("/gpu:" + str(PARAMS["gpu"])):
            pointclouds_pl, labels_pl, smpws_pl = model.get_placeholders(
                PARAMS["num_point"], hyperparams=PARAMS)
            is_training_pl = tf.compat.v1.placeholder(tf.bool, shape=())

            # Note the global_step=batch parameter passed to minimize():
            # it tells the optimizer to increment the 'batch' variable
            # every time it runs a training step.
            batch = tf.Variable(0)
            bn_decay = get_bn_decay(batch)
            tf.compat.v1.summary.scalar("bn_decay", bn_decay)

            print("--- Get model and loss")
            # Get model and loss
            pred, end_points = model.get_model(
                pointclouds_pl,
                is_training_pl,
                NUM_CLASSES,
                hyperparams=PARAMS,
                bn_decay=bn_decay,
            )
            loss = model.get_loss(pred, labels_pl, smpws_pl, end_points)
            tf.compat.v1.summary.scalar("loss", loss)

            # Compute accuracy
            correct = tf.equal(tf.argmax(pred, 2),
                               tf.compat.v1.to_int64(labels_pl))
            accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(
                PARAMS["batch_size"] * PARAMS["num_point"])
            tf.compat.v1.summary.scalar("accuracy", accuracy)

            # Compute mean intersection over union
            mean_intersection_over_union, update_iou_op = tf.compat.v1.metrics.mean_iou(
                tf.compat.v1.to_int32(labels_pl),
                tf.compat.v1.to_int32(tf.argmax(pred, 2)), NUM_CLASSES)
            tf.compat.v1.summary.scalar(
                "mIoU", tf.compat.v1.to_float(mean_intersection_over_union))

            print("--- Get training operator")
            # Get training operator
            learning_rate = get_learning_rate(batch)
            tf.compat.v1.summary.scalar("learning_rate", learning_rate)
            if PARAMS["optimizer"] == "momentum":
                optimizer = tf.compat.v1.train.MomentumOptimizer(
                    learning_rate, momentum=PARAMS["momentum"])
            else:
                assert PARAMS["optimizer"] == "adam"
                optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate)
            train_op = optimizer.minimize(loss, global_step=batch)

            # Add ops to save and restore all the variables.
            saver = tf.compat.v1.train.Saver()

        # Create a session
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        config.log_device_placement = False
        sess = tf.compat.v1.Session(config=config)

        # Add summary writers
        merged = tf.compat.v1.summary.merge_all()
        train_writer = tf.compat.v1.summary.FileWriter(
            os.path.join(PARAMS["logdir"], "train"), sess.graph)
        validation_writer = tf.compat.v1.summary.FileWriter(
            os.path.join(PARAMS["logdir"], "validation"), sess.graph)

        # Init variables
        sess.run(tf.compat.v1.global_variables_initializer())
        sess.run(tf.compat.v1.local_variables_initializer())  # important for mIoU

        ops = {
            "pointclouds_pl": pointclouds_pl,
            "labels_pl": labels_pl,
            "smpws_pl": smpws_pl,
            "is_training_pl": is_training_pl,
            "pred": pred,
            "loss": loss,
            "train_op": train_op,
            "merged": merged,
            "step": batch,
            "end_points": end_points,
            "update_iou": update_iou_op,
        }

        # Train for PARAMS["max_epoch"] epochs
        best_acc = 0
        for epoch in range(PARAMS["max_epoch"]):
            print("in epoch", epoch)
            print("max_epoch", PARAMS["max_epoch"])

            log_string("**** EPOCH %03d ****" % (epoch))
            sys.stdout.flush()

            # Train one epoch
            train_one_epoch(sess, ops, train_writer, stack_train)

            # Evaluate, save, and compute the accuracy
            if epoch % 5 == 0:
                acc = eval_one_epoch(sess, ops, validation_writer,
                                     stack_validation)
                if acc > best_acc:
                    best_acc = acc
                    save_path = saver.save(
                        sess,
                        os.path.join(PARAMS["logdir"],
                                     "best_model_epoch_%03d.ckpt" % (epoch)),
                    )
                    log_string("Model saved in file: %s" % save_path)
                    print("Model saved in file: %s" % save_path)

            # Save the variables to disk.
            if epoch % 10 == 0:
                save_path = saver.save(
                    sess, os.path.join(PARAMS["logdir"], "model.ckpt"))
                log_string("Model saved in file: %s" % save_path)
                print("Model saved in file: %s" % save_path)

    # Kill the process, close the file and exit
    stacker.terminate()
    LOG_FOUT.close()
    sys.exit()
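# --- Sketch (assumption, not from the original source) ---
# train_one_epoch() and eval_one_epoch() are called above but not shown here.
# Assuming stack_train behaves like a queue of pre-batched
# (points, labels, sample_weights) tuples, one training epoch might look roughly
# like this; the real batching interface and number of batches may differ.
def train_one_epoch_sketch(sess, ops, train_writer, stack, num_batches):
    for _ in range(num_batches):
        batch_points, batch_labels, batch_weights = stack.get()  # hypothetical interface
        feed_dict = {
            ops["pointclouds_pl"]: batch_points,
            ops["labels_pl"]: batch_labels,
            ops["smpws_pl"]: batch_weights,
            ops["is_training_pl"]: True,
        }
        # One optimization step; also refresh summaries and the running mIoU.
        summary, step, _, loss_val, _ = sess.run(
            [ops["merged"], ops["step"], ops["train_op"], ops["loss"],
             ops["update_iou"]],
            feed_dict=feed_dict)
        train_writer.add_summary(summary, step)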
def train(FLAGS):
    # first make bag of words
    worddict = WordDict(files=[FLAGS.train_file, FLAGS.test_file],
                        k=FLAGS.k,
                        logpath=FLAGS.log_dir)
    FLAGS.word_size = worddict.size

    # read data
    dataset = DataSet(fpath=FLAGS.train_file,
                      n_classes=FLAGS.num_classes,
                      wd=worddict,
                      need_shuffle=True)

    with tf.Graph().as_default():
        global_step = tf.placeholder(tf.int32)
        placeholders = get_placeholders(FLAGS)

        pred = inference(placeholders['data'], FLAGS, for_training=True)

        tf.losses.softmax_cross_entropy(placeholders['labels'], pred)
        loss = tf.losses.get_total_loss()

        _acc_op = tf.equal(tf.argmax(pred, 1),
                           tf.argmax(placeholders['labels'], 1))
        acc_op = tf.reduce_mean(tf.cast(_acc_op, tf.float32))

        train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(loss)

        # Create a saver.
        saver = tf.train.Saver(max_to_keep=None)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            if tf.train.checkpoint_exists(FLAGS.prev_checkpoint_path):
                restorer = tf.train.Saver()
                restorer.restore(sess, FLAGS.prev_checkpoint_path)
                logging(
                    '%s: Pre-trained model restored from %s' %
                    (datetime.now(), FLAGS.prev_checkpoint_path), FLAGS)
                step = int(
                    FLAGS.prev_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            else:
                step = 0

            for data, labels in dataset.iter_once(FLAGS.batch_size):
                start_time = time.time()
                _, loss_val, acc_val = sess.run(
                    [train_op, loss, acc_op],
                    feed_dict={
                        placeholders['data']: data,
                        placeholders['labels']: labels,
                        global_step: step
                    })
                duration = time.time() - start_time

                assert not np.isnan(loss_val), 'Model diverged'

                if step > 0 and step % FLAGS.log_interval == 0:
                    examples_per_sec = FLAGS.batch_size / float(duration)
                    format_str = (
                        '%s: step %d, loss = %.2f, acc = %.2f '
                        '(%.1f examples/sec; %.3f sec/batch)')
                    logging(
                        format_str % (datetime.now(), step, loss_val, acc_val,
                                      examples_per_sec, duration), FLAGS)

                if step > 0 and step % FLAGS.save_interval == 0:
                    saver.save(sess, FLAGS.checkpoint_path, global_step=step)

                # counter
                step += 1

            # save for last
            saver.save(sess, FLAGS.checkpoint_path, global_step=step - 1)
def test(FLAGS):
    # first make bag of words
    worddict = WordDict(FLAGS.embedding_file)
    FLAGS.word_size = worddict.size

    # read data
    dataset = DataSet(fpath=FLAGS.test_file,
                      n_classes=FLAGS.num_classes,
                      wd=worddict,
                      need_shuffle=False)

    with tf.Graph().as_default():
        # placeholder
        placeholders = get_placeholders(FLAGS)

        # get inference
        pred = inference(placeholders['data'], FLAGS, for_training=False)

        # calculate prediction
        pred_label_op = tf.argmax(pred, 1)
        label_op = tf.argmax(placeholders['labels'], 1)
        _hit_op = tf.equal(pred_label_op, label_op)
        hit_op = tf.reduce_sum(tf.cast(_hit_op, tf.float32))

        # create saver
        saver = tf.train.Saver()

        # summary
        summary_op = tf.summary.merge_all()

        with tf.Session() as sess:
            # load model
            ckpt = tf.train.latest_checkpoint(
                os.path.dirname(FLAGS.checkpoint_path))
            if tf.train.checkpoint_exists(ckpt):
                saver.restore(sess, ckpt)
                global_step = ckpt.split('/')[-1].split('-')[-1]
                logging(
                    'Successfully loaded model from %s at step=%s.' %
                    (ckpt, global_step), FLAGS)
            else:
                logging("[ERROR] Checkpoint does not exist", FLAGS)
                return

            # summary writer
            summary_writer = tf.summary.FileWriter(FLAGS.log_dir,
                                                   graph=sess.graph)

            # iterate over batches
            hit_count = 0.0
            total_count = 0
            results = []

            logging("%s: starting test." % (datetime.now()), FLAGS)
            start_time = time.time()
            total_batch_size = math.ceil(dataset._num_data / FLAGS.batch_size)

            for step, (data, labels) in enumerate(
                    dataset.iter_once(FLAGS.batch_size)):
                # use a new name so the graph tensor `pred` is not shadowed
                hits, pred_val, lb = sess.run(
                    [hit_op, pred_label_op, label_op],
                    feed_dict={
                        placeholders['data']: data,
                        placeholders['labels']: labels
                    })

                hit_count += np.sum(hits)
                total_count += len(data)

                for i, p in enumerate(pred_val):
                    results.append((p, lb[i]))

                if step % FLAGS.log_interval == 0:
                    duration = time.time() - start_time
                    sec_per_batch = duration / FLAGS.log_interval
                    examples_per_sec = FLAGS.batch_size / sec_per_batch
                    logging(
                        '%s: [%d batches out of %d] (%.1f examples/sec; '
                        '%.3f sec/batch)' %
                        (datetime.now(), step, total_batch_size,
                         examples_per_sec, sec_per_batch), FLAGS)
                    start_time = time.time()

            # micro precision
            logging(
                "%s: micro-precision = %.5f" %
                (datetime.now(), (hit_count / total_count)), FLAGS)

            # write result
            outpath = os.path.join(FLAGS.log_dir, "out.txt")
            with open(outpath, 'w') as fw:
                for p, l in results:
                    fw.write("%d\t%d\n" % (int(l), int(p)))
def test(FLAGS):
    # read data
    dataset = DataSet(fpath=FLAGS.test_file,
                      seqlen=FLAGS.seq_len,
                      n_classes=FLAGS.num_classes,
                      need_shuffle=False)
    FLAGS.charset_size = dataset.charset_size

    with tf.Graph().as_default():
        # placeholder
        placeholders = get_placeholders(FLAGS)

        # get inference
        pred, layers = inference(placeholders['data'], FLAGS,
                                 for_training=False)

        # calculate prediction
        _hit_op = tf.equal(tf.argmax(pred, 1),
                           tf.argmax(placeholders['labels'], 1))
        hit_op = tf.reduce_sum(tf.cast(_hit_op, tf.float32))

        # create saver
        saver = tf.train.Saver()

        # argmax of hidden1
        h1_argmax_ops = []
        for op in layers['conv']:
            h1_argmax_ops.append(tf.argmax(op, axis=2))

        with tf.Session() as sess:
            # load model
            ckpt = tf.train.latest_checkpoint(
                os.path.dirname(FLAGS.checkpoint_path))
            if tf.train.checkpoint_exists(ckpt):
                saver.restore(sess, ckpt)
                global_step = ckpt.split('/')[-1].split('-')[-1]
                print('Successfully loaded model from %s at step=%s.' %
                      (ckpt, global_step))
            else:
                print("[ERROR] Checkpoint does not exist")
                return

            # iterate over batches
            hit_count = 0.0
            total_count = 0

            # top matching proteins per filter
            # top_matches = [([], [])] * FLAGS.hidden1
            wlens = [4, 8, 12, 16, 20]
            hsize = int(FLAGS.hidden1 / 5)
            motif_matches = (defaultdict(list), defaultdict(list))

            print("%s: starting test." % (datetime.now()))
            start_time = time.time()
            total_batch_size = math.ceil(dataset._num_data / FLAGS.batch_size)

            for step, (data, labels, raws) in enumerate(
                    dataset.iter_once(FLAGS.batch_size, with_raw=True)):
                res_run = sess.run([hit_op, h1_argmax_ops] + layers['conv'],
                                   feed_dict={
                                       placeholders['data']: data,
                                       placeholders['labels']: labels
                                   })
                hits = res_run[0]
                max_idxs = res_run[1]  # shape = (wlens, N, 1, # of filters)
                motif_filters = res_run[2:]

                # mf.shape = (N, 1, l-w+1, # of filters)
                for i in range(len(motif_filters)):
                    s = motif_filters[i].shape
                    motif_filters[i] = np.transpose(
                        motif_filters[i], (0, 1, 3, 2)).reshape(
                            (s[0], s[3], s[2]))

                # mf.shape = (N, # of filters, l-w+1)
                for gidx, mf in enumerate(motif_filters):
                    wlen = wlens[gidx]
                    for ridx, row in enumerate(mf):
                        for fidx, vals in enumerate(row):
                            # for each filter, get the max value and its index
                            max_idx = max_idxs[gidx][ridx][0][fidx]
                            # max_idx = np.argmax(vals)
                            max_val = vals[max_idx]
                            hidx = gidx * hsize + fidx
                            if max_val > 0:
                                # get sequence
                                rawseq = raws[ridx][1]
                                subseq = rawseq[max_idx:max_idx + wlen]
                                # heappush(top_matches[hidx], (max_val, subseq))
                                motif_matches[0][hidx].append(max_val)
                                motif_matches[1][hidx].append(subseq)

                hit_count += np.sum(hits)
                total_count += len(data)

                if step % FLAGS.log_interval == 0:
                    duration = time.time() - start_time
                    sec_per_batch = duration / FLAGS.log_interval
                    examples_per_sec = FLAGS.batch_size / sec_per_batch
                    print('%s: [%d batches out of %d] (%.1f examples/sec; '
                          '%.3f sec/batch)' %
                          (datetime.now(), step, total_batch_size,
                           examples_per_sec, sec_per_batch))
                    start_time = time.time()

            # micro precision
            # print("%s: micro-precision = %.5f" %
            #       (datetime.now(), (hit_count / total_count)))

            # sort top lists and write them to file
            print('%s: write result to file' % (datetime.now()))
            for fidx in motif_matches[0]:
                val_lst = motif_matches[0][fidx]
                seq_lst = motif_matches[1][fidx]

                # top k
                k = wlens[int(fidx / hsize)] * 25
                l = min(k, len(val_lst)) * -1
                tidxs = np.argpartition(val_lst, l)[l:]

                with open(
                        "/home/kimlab/project/CCC/tmp/logos/test/p%d.txt" %
                        fidx, 'w') as fw:
                    for idx in tidxs:
                        fw.write("%f\t%s\n" % (val_lst[idx], seq_lst[idx]))

                if fidx % 50 == 0:
                    print('%s: [%d filters out of %d]' %
                          (datetime.now(), fidx, FLAGS.hidden1))
def test(FLAGS):
    # read data
    dataset = DataSet(fpath=FLAGS.test_file,
                      seqlen=FLAGS.seq_len,
                      n_classes=FLAGS.num_classes,
                      need_shuffle=False)
    # set character set size
    FLAGS.charset_size = dataset.charset_size

    with tf.Graph().as_default():
        # placeholder
        placeholders = get_placeholders(FLAGS)

        # get inference
        pred, layers = inference(placeholders['data'], FLAGS,
                                 for_training=False)
        prob = tf.nn.softmax(pred)

        # calculate prediction
        _hit_op = tf.equal(tf.argmax(pred, 1),
                           tf.argmax(placeholders['labels'], 1))
        hit_op = tf.reduce_sum(tf.cast(_hit_op, tf.float32))

        # create saver
        saver = tf.train.Saver()

        # summary
        summary_op = tf.summary.merge_all()

        with tf.Session() as sess:
            # load model
            ckpt = tf.train.latest_checkpoint(
                os.path.dirname(FLAGS.checkpoint_path))
            if tf.train.checkpoint_exists(ckpt):
                saver.restore(sess, ckpt)
                global_step = ckpt.split('/')[-1].split('-')[-1]
                logging(
                    'Successfully loaded model from %s at step=%s.' %
                    (ckpt, global_step), FLAGS)
            else:
                logging("[ERROR] Checkpoint does not exist", FLAGS)
                return

            # iterate over batches
            hit_count = 0.0
            total_count = 0
            pred_list = []
            label_list = []

            logging("%s: starting test." % (datetime.now()), FLAGS)
            start_time = time.time()
            total_batch_size = math.ceil(dataset._num_data / FLAGS.batch_size)

            for step, (data, labels) in enumerate(
                    dataset.iter_once(FLAGS.batch_size)):
                hits, pred_val = sess.run([hit_op, prob],
                                          feed_dict={
                                              placeholders['data']: data,
                                              placeholders['labels']: labels
                                          })

                hit_count += np.sum(hits)
                total_count += len(data)

                for i, p in enumerate(pred_val):
                    pred_list.append(p[0])
                    label_list.append(labels[i][0])

                if step % FLAGS.log_interval == 0:
                    duration = time.time() - start_time
                    sec_per_batch = duration / FLAGS.log_interval
                    examples_per_sec = FLAGS.batch_size / sec_per_batch
                    logging(
                        '%s: [%d batches out of %d] (%.1f examples/sec; '
                        '%.3f sec/batch)' %
                        (datetime.now(), step, total_batch_size,
                         examples_per_sec, sec_per_batch), FLAGS)
                    start_time = time.time()

            # micro precision and AUC
            auc_val = roc_auc_score(label_list, pred_list)
            logging(
                "%s: micro-precision = %.5f, auc = %.5f" %
                (datetime.now(), (hit_count / total_count), auc_val), FLAGS)

            # optionally save per-example predictions
            if FLAGS.save_prediction:
                with open(FLAGS.save_prediction, 'w') as fout:
                    with open(FLAGS.test_file) as f:
                        for i, line in enumerate(f.readlines()):
                            line = line.strip()
                            seq_id = line.split('\t')[2]
                            label = label_list[i]
                            predict = pred_list[i]
                            # str() the numeric fields so join() gets strings
                            print('\t'.join([seq_id, str(predict), str(label)]),
                                  file=fout)
def test(FLAGS):
    # read data
    dataset = DataSet(fpath=FLAGS.test_file,
                      seqlen=FLAGS.seq_len,
                      n_classes=FLAGS.num_classes,
                      need_shuffle=False)
    FLAGS.charset_size = dataset.charset_size

    with tf.Graph().as_default():
        # placeholder
        placeholders = get_placeholders(FLAGS)

        # get inference
        pred, layers = inference(placeholders['data'], FLAGS,
                                 for_training=False)

        # calculate prediction
        label_op = tf.argmax(pred, 1)
        prob_op = tf.nn.softmax(pred)
        # _hit_op = tf.equal(tf.argmax(pred, 1), tf.argmax(placeholders['labels'], 1))
        # hit_op = tf.reduce_sum(tf.cast(_hit_op, tf.float32))

        # create saver
        saver = tf.train.Saver()

        # argmax of hidden1
        h1_argmax_ops = []
        for op in layers['conv']:
            h1_argmax_ops.append(tf.argmax(op, axis=2))

        with tf.Session() as sess:
            # load model
            # ckpt = tf.train.latest_checkpoint(os.path.dirname(FLAGS.checkpoint_path))
            ckpt = FLAGS.checkpoint_path
            if tf.train.checkpoint_exists(ckpt):
                saver.restore(sess, ckpt)
                global_step = ckpt.split('/')[-1].split('-')[-1]
                print('Successfully loaded model from %s at step=%s.' %
                      (ckpt, global_step))
            else:
                print("[ERROR] Checkpoint does not exist")
                return

            # iterate over batches
            hit_count = 0.0
            total_count = 0

            wlens = FLAGS.window_lengths
            hsizes = FLAGS.num_windows
            motif_matches = (defaultdict(list), defaultdict(list))
            pred_labels = []
            pred_prob = []

            print("%s: starting test." % (datetime.now()))
            start_time = time.time()
            total_batch_size = math.ceil(dataset._num_data / FLAGS.batch_size)

            for step, (data, labels, raws) in enumerate(
                    dataset.iter_once(FLAGS.batch_size, with_raw=True)):
                res_run = sess.run(
                    [label_op, prob_op, h1_argmax_ops] + layers['conv'],
                    feed_dict={
                        placeholders['data']: data,
                        placeholders['labels']: labels
                    })
                pred_label = res_run[0]
                pred_arr = res_run[1]
                max_idxs = res_run[2]  # shape = (wlens, N, 1, # of filters)
                motif_filters = res_run[3:]

                for i, l in enumerate(pred_label):
                    pred_labels.append(l)
                    # probability of the hard-coded class of interest (index 388)
                    pred_prob.append(pred_arr[i][388])

                # mf.shape = (N, 1, l-w+1, # of filters)
                for i in range(len(motif_filters)):
                    s = motif_filters[i].shape
                    motif_filters[i] = np.transpose(
                        motif_filters[i], (0, 1, 3, 2)).reshape(
                            (s[0], s[3], s[2]))

                # mf.shape = (N, # of filters, l-w+1)
                for gidx, mf in enumerate(motif_filters):
                    wlen = wlens[gidx]
                    hsize = hsizes[gidx]
                    for ridx, row in enumerate(mf):
                        for fidx, vals in enumerate(row):
                            # for each filter, get the max value and its index
                            max_idx = max_idxs[gidx][ridx][0][fidx]
                            # max_idx = np.argmax(vals)
                            max_val = vals[max_idx]
                            hidx = gidx * hsize + fidx
                            if max_val > 0:
                                # get sequence
                                rawseq = raws[ridx][1]
                                subseq = rawseq[max_idx:max_idx + wlen]
                                motif_matches[0][hidx].append(max_val)
                                motif_matches[1][hidx].append(subseq)

                # hit_count += np.sum(hits)
                total_count += len(data)

                if step % FLAGS.log_interval == 0:
                    duration = time.time() - start_time
                    sec_per_batch = duration / FLAGS.log_interval
                    examples_per_sec = FLAGS.batch_size / sec_per_batch
                    print('%s: [%d batches out of %d] (%.1f examples/sec; '
                          '%.3f sec/batch)' %
                          (datetime.now(), step, total_batch_size,
                           examples_per_sec, sec_per_batch))
                    start_time = time.time()

            print(pred_labels)

            # sort top lists and report which filters had the highest activations
            mean_acts = {}
            on_acts = {}
            print('%s: write result to file' % (datetime.now()))
            for fidx in motif_matches[0]:
                val_lst = motif_matches[0][fidx]
                seq_lst = motif_matches[1][fidx]

                # keep every match (l is the negative length of the list)
                # k = wlens[int(fidx / hsize)] * 25
                k = 30
                # l = min(k, len(val_lst)) * -1
                l = len(val_lst) * -1
                tidxs = np.argpartition(val_lst, l)[l:]

                # track activations
                acts = 0.0
                opath = os.path.join(FLAGS.motif_outpath, "p%d.txt" % fidx)
                with open(opath, 'w') as fw:
                    for idx in tidxs:
                        fw.write("%f\t%s\n" % (val_lst[idx], seq_lst[idx]))
                        acts += val_lst[idx]

                # l is negative, so this equals acts / len(val_lst)
                mean_acts[fidx] = acts / l * -1
                on_acts[fidx] = len(val_lst)

                if fidx % 50 == 0:
                    print('%s: [%d filters out of %d]' %
                          (datetime.now(), fidx, sum(FLAGS.num_windows)))

            # report mean activations, sorted by how often each filter fired
            with open(os.path.join(FLAGS.motif_outpath, "report.txt"),
                      'w') as fw:
                for i in sorted(on_acts, key=on_acts.get, reverse=True):
                    fw.write("%f\t%f\t%d\n" %
                             (on_acts[i] / total_count, mean_acts[i], i))

            with open(os.path.join(FLAGS.motif_outpath, "predictions.txt"),
                      'w') as fw:
                for i, p in enumerate(pred_labels):
                    fw.write("%s\t%f\n" % (str(p), pred_prob[i]))
def train(FLAGS):
    # read data
    dataset = DataSet(fpath=FLAGS.train_file,
                      seqlen=FLAGS.seq_len,
                      n_classes=FLAGS.num_classes,
                      need_shuffle=True)
    # set character set size
    FLAGS.charset_size = dataset.charset_size

    with tf.Graph().as_default():
        # get placeholders
        global_step = tf.placeholder(tf.int32)
        placeholders = get_placeholders(FLAGS)

        # prediction
        pred, layers = inference(placeholders['data'], FLAGS,
                                 for_training=True)

        # loss
        # slim.losses.softmax_cross_entropy(pred, placeholders['labels'])
        # class_weight = tf.constant([[1.0, 5.0]])
        # weight_per_label = tf.transpose(tf.matmul(placeholders['labels'],
        #                                           tf.transpose(class_weight)))
        # loss = tf.multiply(weight_per_label,
        #                    tf.nn.softmax_cross_entropy_with_logits(
        #                        labels=placeholders['labels'], logits=pred))
        # loss = tf.losses.compute_weighted_loss(loss)
        tf.losses.softmax_cross_entropy(placeholders['labels'], pred)
        loss = tf.losses.get_total_loss()

        # accuracy
        _acc_op = tf.equal(tf.argmax(pred, 1),
                           tf.argmax(placeholders['labels'], 1))
        acc_op = tf.reduce_mean(tf.cast(_acc_op, tf.float32))

        # optimization
        train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(loss)
        # train_op = tf.train.RMSPropOptimizer(FLAGS.learning_rate).minimize(loss)

        # Create a saver.
        saver = tf.train.Saver(max_to_keep=None)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            if tf.train.checkpoint_exists(FLAGS.prev_checkpoint_path):
                if FLAGS.fine_tuning:
                    logging('%s: Fine Tuning Experiment!' % (datetime.now()),
                            FLAGS)
                    restore_variables = slim.get_variables_to_restore(
                        exclude=FLAGS.fine_tuning_layers)
                    restorer = tf.train.Saver(restore_variables)
                else:
                    restorer = tf.train.Saver()
                restorer.restore(sess, FLAGS.prev_checkpoint_path)
                logging(
                    '%s: Pre-trained model restored from %s' %
                    (datetime.now(), FLAGS.prev_checkpoint_path), FLAGS)
                step = int(
                    FLAGS.prev_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            else:
                step = 0

            # iterate over the training data
            # for data, labels in dataset.iter_batch(FLAGS.batch_size, 5):
            for data, labels in dataset.iter_once(FLAGS.batch_size):
                start_time = time.time()
                _, loss_val, acc_val = sess.run(
                    [train_op, loss, acc_op],
                    feed_dict={
                        placeholders['data']: data,
                        placeholders['labels']: labels,
                        global_step: step
                    })
                duration = time.time() - start_time

                assert not np.isnan(loss_val), 'Model diverged'

                # logging
                if step > 0 and step % FLAGS.log_interval == 0:
                    examples_per_sec = FLAGS.batch_size / float(duration)
                    format_str = (
                        '%s: step %d, loss = %.2f, acc = %.2f '
                        '(%.1f examples/sec; %.3f sec/batch)')
                    logging(
                        format_str % (datetime.now(), step, loss_val, acc_val,
                                      examples_per_sec, duration), FLAGS)

                # save model
                if step > 0 and step % FLAGS.save_interval == 0:
                    saver.save(sess, FLAGS.checkpoint_path, global_step=step)

                # counter
                step += 1

            # save for last
            saver.save(sess, FLAGS.checkpoint_path, global_step=step - 1)
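# --- Sketch (assumption, not from the original source) ---
# get_placeholders() is used by every train()/test() variant above but is not shown.
# Assuming the input is a one-hot-encoded sequence of shape
# (batch, seq_len, charset_size) and the labels are one-hot class vectors, it might
# look roughly like this; the real feature layout may differ (e.g. extra channels).
def get_placeholders_sketch(FLAGS):
    return {
        'data': tf.placeholder(tf.float32,
                               shape=(None, FLAGS.seq_len, FLAGS.charset_size),
                               name='data'),
        'labels': tf.placeholder(tf.float32,
                                 shape=(None, FLAGS.num_classes),
                                 name='labels'),
    }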