def __init__(self, checkpoint_path, num_classes, hyper_params):
        """Build the prediction + dense-label interpolation graph and restore weights.

        Args:
            checkpoint_path: checkpoint prefix handed to ``Saver.restore``.
            num_classes: number of classes passed to ``model.get_model``.
            hyper_params: model hyper-parameter dict; ``"num_point"`` is read here.

        After construction ``self.ops`` maps names to the input placeholders and
        the ``dense_labels`` / ``dense_colors`` output tensors, and ``self.sess``
        is a session with the checkpoint restored. The predicted sparse labels
        are wired directly into the interpolation sub-graph, which assumes
        batch_size == 1.
        """
        with tf.device("/gpu:0"):
            # Network input and the train/inference switch.
            centered_points_pl, _, _ = model.get_placeholders(
                hyper_params["num_point"], hyperparams=hyper_params)
            training_pl = tf.placeholder(tf.bool, shape=())

            logits, _ = model.get_model(
                centered_points_pl,
                training_pl,
                num_classes,
                hyperparams=hyper_params,
            )
            # argmax over axis 2, flatten (1, num_sparse_points) ->
            # (num_sparse_points,), and cast to int32 for interpolation.
            sparse_labels = tf.cast(
                tf.reshape(tf.argmax(logits, axis=2), [-1]), tf.int32)

            # The Saver covers the variables created up to this point.
            saver = tf.train.Saver()

            # Interpolation sub-graph (assumes batch_size == 1).
            sparse_points_pl = tf.placeholder(tf.float32, (None, None, 3))
            flat_sparse_points = tf.reshape(sparse_points_pl, [-1, 3])
            dense_points_pl = tf.placeholder(tf.float32, (None, 3))
            knn_pl = tf.placeholder(tf.int32, ())
            dense_labels, dense_colors = interpolate_label_with_color(
                flat_sparse_points, sparse_labels, dense_points_pl, knn_pl)

        self.ops = dict(
            pl_sparse_points_centered_batched=centered_points_pl,
            pl_sparse_points_batched=sparse_points_pl,
            pl_dense_points=dense_points_pl,
            pl_is_training=training_pl,
            pl_knn=knn_pl,
            dense_labels=dense_labels,
            dense_colors=dense_colors,
        )

        # Grow GPU memory on demand and fall back to CPU for ops without GPU kernels.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        session_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=session_config)
        saver.restore(self.sess, checkpoint_path)
        print("Model restored")
    def __init__(self, checkpoint_path, num_classes, hyper_params):
        """Build the prediction and label-interpolation graphs, then restore weights.

        Args:
            checkpoint_path: checkpoint prefix handed to ``Saver.restore``.
            num_classes: number of classes passed to ``model.get_model``.
            hyper_params: model hyper-parameter dict; ``"num_point"`` is read here.

        ``self.ops`` exposes the model input placeholders, the ``pred`` tensor,
        and the interpolation placeholders plus its ``sparse_indices`` output.
        The interpolation inputs are independent placeholders (not wired to
        ``pred``), so predicted labels must be fed back in explicitly. The
        interpolation graph assumes batch_size == 1.
        """
        with tf.device("/gpu:0"):
            # Placeholders
            pl_points, _, _ = model.get_placeholders(hyper_params["num_point"],
                                                     hyperparams=hyper_params)
            pl_is_training = tf.placeholder(tf.bool, shape=())
            # Print the static shape: tf.shape() would print only a symbolic
            # Tensor object (and add an unused op to the graph).
            print("pl_points shape", pl_points.get_shape())

            # Prediction
            pred, _ = model.get_model(pl_points,
                                      pl_is_training,
                                      num_classes,
                                      hyperparams=hyper_params)

            # The Saver covers the variables created up to this point.
            saver = tf.train.Saver()

            # Graph for interpolating labels
            # Assuming batch_size == 1 for simplicity
            pl_sparse_points = tf.placeholder(tf.float32, (None, 3))
            pl_sparse_labels = tf.placeholder(tf.int32, (None, ))
            pl_dense_points = tf.placeholder(tf.float32, (None, 3))
            pl_knn = tf.placeholder(tf.int32, ())
            sparse_indices = interpolate_label(pl_sparse_points,
                                               pl_sparse_labels,
                                               pl_dense_points, pl_knn)

        self.ops = {
            "pl_points": pl_points,
            "pl_is_training": pl_is_training,
            "pred": pred,
            "pl_sparse_points": pl_sparse_points,
            "pl_sparse_labels": pl_sparse_labels,
            "pl_dense_points": pl_dense_points,
            "pl_knn": pl_knn,
            "sparse_indices": sparse_indices,
        }

        # Restore checkpoint to session; grow GPU memory on demand and allow
        # CPU fallback for ops without GPU kernels.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        config.log_device_placement = False
        self.sess = tf.Session(config=config)
        saver.restore(self.sess, checkpoint_path)
        print("Model restored")
Exemplo n.º 3
0
def test(FLAGS):
    """Evaluate the latest checkpoint as a binary classifier and append metrics.

    Restores the most recent checkpoint found next to ``FLAGS.checkpoint_path``.
    It runs inference over ``FLAGS.test_file`` once and logs two figures:
    micro-precision (argmax hit rate) and the ROC AUC of the class-0 softmax
    probability. It then thresholds that probability at ``FLAGS.threshold``.
    Finally it appends one CSV row to ``FLAGS.out_file``:
    ``step,time,TP,FP,TN,FN,Sens,Spec,Prec,Acc,MCC,F1,AUC``.

    Returns None. Logs an error and returns early if no checkpoint is found.
    """
    # read data
    dataset = DataSet(fpath=FLAGS.test_file,
                      seqlen=FLAGS.seq_len,
                      n_classes=FLAGS.num_classes,
                      num_feature=FLAGS.num_feature,
                      is_raw=FLAGS.is_raw,
                      need_shuffle=False)
    # set character set size
    FLAGS.charset_size = dataset.charset_size

    with tf.Graph().as_default():
        placeholders = get_placeholders(FLAGS)

        pred, _ = inference(placeholders['data'],
                            FLAGS,
                            for_training=False)

        prob = tf.nn.softmax(pred)
        # Number of samples whose argmax prediction matches the argmax label.
        _hit_op = tf.equal(tf.argmax(pred, 1),
                           tf.argmax(placeholders['labels'], 1))
        hit_op = tf.reduce_sum(tf.cast(_hit_op, tf.float32))

        saver = tf.train.Saver()

        with tf.Session() as sess:
            # latest_checkpoint() returns None when no checkpoint exists, and
            # checkpoint_exists(None) would raise TypeError instead of letting
            # us report the problem.
            ckpt = tf.train.latest_checkpoint(
                os.path.dirname(FLAGS.checkpoint_path))
            if ckpt and tf.train.checkpoint_exists(ckpt):
                saver.restore(sess, ckpt)
                global_step = ckpt.split('/')[-1].split('-')[-1]
                logging(
                    'Succesfully loaded model from %s at step=%s.' %
                    (ckpt, global_step), FLAGS)
            else:
                logging("[ERROR] Checkpoint not exist", FLAGS)
                return

            hit_count = 0.0
            total_count = 0
            pred_list = []  # softmax probability of class 0, per sample
            label_list = []  # one-hot label column 0, per sample

            logging("%s: starting test." % (datetime.now()), FLAGS)
            start_time = time.time()
            total_batch_size = math.ceil(dataset._num_data / FLAGS.batch_size)

            for step, (data, labels) in enumerate(
                    dataset.iter_once(FLAGS.batch_size)):
                hits, pred_val = sess.run([hit_op, prob],
                                          feed_dict={
                                              placeholders['data']: data,
                                              placeholders['labels']: labels
                                          })

                hit_count += np.sum(hits)
                total_count += len(data)

                for i, p in enumerate(pred_val):
                    pred_list.append(p[0])
                    label_list.append(labels[i][0])

                if step % FLAGS.log_interval == 0:
                    duration = time.time() - start_time
                    sec_per_batch = duration / FLAGS.log_interval
                    examples_per_sec = FLAGS.batch_size / sec_per_batch
                    logging(
                        '%s: [%d batches out of %d] (%.1f examples/sec; %.3f '
                        'sec/batch)' % (datetime.now(), step, total_batch_size,
                                        examples_per_sec, sec_per_batch),
                        FLAGS)
                    start_time = time.time()

            auc_val = roc_auc_score(label_list, pred_list)
            logging(
                "%s: micro-precision = %.5f, auc = %.5f" %
                (datetime.now(), (hit_count / total_count), auc_val), FLAGS)

            pred_y = [1 if p > FLAGS.threshold else 0 for p in pred_list]

            # With labels=[1, 0] the rows/cols are ordered (1, 0), so ravel()
            # yields (true 1, pred 1), (true 1, pred 0), (true 0, pred 1),
            # (true 0, pred 0). Unpacking these as TN, FP, FN, TP treats the
            # value 1 (one-hot column 0) as the negative class.
            # NOTE(review): presumably class 0 is the negative label; confirm.
            # Converting to Python ints stops the MCC denominator (a product
            # of four counts) from silently overflowing numpy int64.
            TN, FP, FN, TP = (int(count) for count in confusion_matrix(
                label_list, pred_y, labels=[1, 0]).ravel())

            (Sensitivity, Specificity, Precision, Accuracy, MCC,
             F1) = _binary_metrics(TP, FP, TN, FN)
            AUC = round(auc_val, 4)

            with open(FLAGS.out_file, 'a') as fout:
                fout.write(
                    f"{global_step},{datetime.now()},{TP},{FP},{TN},{FN},{Sensitivity},{Specificity},{Precision},{Accuracy},{MCC},{F1},{AUC}\n"
                )

            logging(
                f"TP={TP}, FP={FP}, TN={TN}, FN={FN}, Sens={Sensitivity}, Spec={Specificity}, Prec={Precision}, Acc={Accuracy}, MCC={MCC}, F1={F1}, AUC={auc_val}",
                FLAGS)


def _binary_metrics(tp, fp, tn, fn):
    """Return (sensitivity, specificity, precision, accuracy, mcc, f1).

    Each value is rounded to 4 decimal places. A ratio whose denominator is
    zero is reported as 0 instead of raising. Pass Python ints: the MCC
    denominator multiplies four counts and can exceed the int64 range.
    """
    def ratio(num, den):
        return round(num / den, 4) if den > 0 else 0

    sensitivity = ratio(tp, tp + fn)
    specificity = ratio(tn, fp + tn)
    precision = ratio(tp, tp + fp)
    accuracy = ratio(tp + tn, tp + fp + tn + fn)
    # Counts are non-negative, so the product is zero iff any factor is zero.
    mcc_den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = ratio(tp * tn - fp * fn, mcc_den)
    f1 = ratio(2 * tp, 2 * tp + fp + fn)
    return sensitivity, specificity, precision, accuracy, mcc, f1
Exemplo n.º 4
0
def train():
    """Train the model on a single GPU.

    Starts the data stacker, then builds and trains the model (see
    ``_train_in_graph``). Afterwards it terminates the stacker, closes
    ``LOG_FOUT`` and exits the process with ``sys.exit()``. The stacker and
    log file are also cleaned up when training raises, so the exception
    propagates without leaving the stacker running.
    """
    with tf.Graph().as_default():
        stacker, stack_validation, stack_train = init_stacking()
        try:
            _train_in_graph(stack_train, stack_validation)
        finally:
            # Kill the process and close the file on every exit path.
            stacker.terminate()
            LOG_FOUT.close()
        sys.exit()


def _train_in_graph(stack_train, stack_validation):
    """Build the training graph in the current default graph and run all epochs.

    Trains for ``PARAMS["max_epoch"]`` epochs. Every 5 epochs it evaluates on
    ``stack_validation`` and saves ``best_model_epoch_XXX.ckpt`` when the
    validation accuracy improves. Every 10 epochs it (over)writes
    ``model.ckpt``. Checkpoints and summaries go under ``PARAMS["logdir"]``.
    """
    with tf.device("/gpu:" + str(PARAMS["gpu"])):
        pointclouds_pl, labels_pl, smpws_pl = model.get_placeholders(
            PARAMS["num_point"], hyperparams=PARAMS)
        is_training_pl = tf.compat.v1.placeholder(tf.bool, shape=())

        # Note the global_step=batch parameter to minimize().
        # That tells the optimizer to increment `batch` on every training
        # step; the BN-decay and learning-rate schedules read it.
        batch = tf.Variable(0)
        bn_decay = get_bn_decay(batch)
        # compat.v1 summaries are what compat.v1 merge_all() below collects.
        tf.compat.v1.summary.scalar("bn_decay", bn_decay)

        print("--- Get model and loss")
        pred, end_points = model.get_model(
            pointclouds_pl,
            is_training_pl,
            NUM_CLASSES,
            hyperparams=PARAMS,
            bn_decay=bn_decay,
        )
        loss = model.get_loss(pred, labels_pl, smpws_pl, end_points)
        tf.compat.v1.summary.scalar("loss", loss)

        # Compute accuracy
        correct = tf.equal(tf.argmax(pred, 2),
                           tf.compat.v1.to_int64(labels_pl))
        accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(
            PARAMS["batch_size"] * PARAMS["num_point"])
        tf.compat.v1.summary.scalar("accuracy", accuracy)

        # Compute mean intersection over union; its running counters are
        # local variables, initialised below.
        mean_intersection_over_union, update_iou_op = tf.compat.v1.metrics.mean_iou(
            tf.compat.v1.to_int32(labels_pl),
            tf.compat.v1.to_int32(tf.argmax(pred, 2)), NUM_CLASSES)
        tf.compat.v1.summary.scalar(
            "mIoU", tf.compat.v1.to_float(mean_intersection_over_union))

        print("--- Get training operator")
        learning_rate = get_learning_rate(batch)
        tf.compat.v1.summary.scalar("learning_rate", learning_rate)
        if PARAMS["optimizer"] == "momentum":
            # compat.v1 path: tf.train.MomentumOptimizer does not exist in TF2.
            optimizer = tf.compat.v1.train.MomentumOptimizer(
                learning_rate, momentum=PARAMS["momentum"])
        else:
            assert PARAMS["optimizer"] == "adam"
            optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate)
        train_op = optimizer.minimize(loss, global_step=batch)

        # Add ops to save and restore all the variables.
        saver = tf.compat.v1.train.Saver()

    # Create a session
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False
    sess = tf.compat.v1.Session(config=config)

    # Add summary writers
    merged = tf.compat.v1.summary.merge_all()
    train_writer = tf.compat.v1.summary.FileWriter(
        os.path.join(PARAMS["logdir"], "train"), sess.graph)
    validation_writer = tf.compat.v1.summary.FileWriter(
        os.path.join(PARAMS["logdir"], "validation"), sess.graph)

    try:
        # Init variables
        sess.run(tf.compat.v1.global_variables_initializer())
        sess.run(
            tf.compat.v1.local_variables_initializer())  # important for mIoU

        ops = {
            "pointclouds_pl": pointclouds_pl,
            "labels_pl": labels_pl,
            "smpws_pl": smpws_pl,
            "is_training_pl": is_training_pl,
            "pred": pred,
            "loss": loss,
            "train_op": train_op,
            "merged": merged,
            "step": batch,
            "end_points": end_points,
            "update_iou": update_iou_op,
        }

        # Train for PARAMS["max_epoch"] epochs
        best_acc = 0
        for epoch in range(PARAMS["max_epoch"]):
            print("in epoch", epoch)
            print("max_epoch", PARAMS["max_epoch"])

            log_string("**** EPOCH %03d ****" % (epoch))
            sys.stdout.flush()

            # Train one epoch
            train_one_epoch(sess, ops, train_writer, stack_train)

            # Evaluate every 5 epochs; the best-model check belongs to the
            # evaluation step because `acc` is only fresh right after it.
            if epoch % 5 == 0:
                acc = eval_one_epoch(sess, ops, validation_writer,
                                     stack_validation)

                if acc > best_acc:
                    best_acc = acc
                    save_path = saver.save(
                        sess,
                        os.path.join(PARAMS["logdir"],
                                     "best_model_epoch_%03d.ckpt" % (epoch)),
                    )
                    log_string("Model saved in file: %s" % save_path)
                    print("Model saved in file: %s" % save_path)

            # Save the variables to disk.
            if epoch % 10 == 0:
                save_path = saver.save(
                    sess, os.path.join(PARAMS["logdir"], "model.ckpt"))
                log_string("Model saved in file: %s" % save_path)
                print("Model saved in file: %s" % save_path)
    finally:
        # Flush pending summaries to disk and release the session.
        train_writer.close()
        validation_writer.close()
        sess.close()
Exemplo n.º 5
0
def train(FLAGS):
    """Train the bag-of-words classifier for one pass over ``FLAGS.train_file``.

    A ``WordDict`` is built from both the train and test files, and its size
    is stored on ``FLAGS.word_size``. If ``FLAGS.prev_checkpoint_path`` names
    an existing checkpoint, weights are restored from it. The step counter
    then resumes one past the number after the last ``-`` in its file name.
    Checkpoints are written every ``FLAGS.save_interval`` steps (step > 0) and
    once more after the final batch. Raises AssertionError if the loss
    becomes NaN.
    """
    worddict = WordDict(files=[FLAGS.train_file, FLAGS.test_file],
                        k=FLAGS.k,
                        logpath=FLAGS.log_dir)
    FLAGS.word_size = worddict.size

    dataset = DataSet(fpath=FLAGS.train_file,
                      n_classes=FLAGS.num_classes,
                      wd=worddict,
                      need_shuffle=True)

    with tf.Graph().as_default():
        # Fed on every step, but no op in this graph consumes it.
        step_pl = tf.placeholder(tf.int32)
        placeholders = get_placeholders(FLAGS)
        data_pl = placeholders['data']
        labels_pl = placeholders['labels']

        logits = inference(data_pl, FLAGS, for_training=True)

        # Registers the cross-entropy in the losses collection; the total
        # loss is then read back from that collection.
        tf.losses.softmax_cross_entropy(labels_pl, logits)
        loss = tf.losses.get_total_loss()

        correct = tf.equal(tf.argmax(logits, 1), tf.argmax(labels_pl, 1))
        acc_op = tf.reduce_mean(tf.cast(correct, tf.float32))

        train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(loss)

        # max_to_keep=None: every saved checkpoint is kept.
        saver = tf.train.Saver(max_to_keep=None)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            prev_ckpt = FLAGS.prev_checkpoint_path
            step = 0
            if tf.train.checkpoint_exists(prev_ckpt):
                tf.train.Saver().restore(sess, prev_ckpt)
                logging(
                    '%s: Pre-trained model restored from %s' %
                    (datetime.now(), prev_ckpt), FLAGS)
                # Resume after the step encoded in the "<name>-<step>" suffix.
                step = int(prev_ckpt.split('/')[-1].split('-')[-1]) + 1

            for data, labels in dataset.iter_once(FLAGS.batch_size):
                tic = time.time()
                _, loss_val, acc_val = sess.run(
                    [train_op, loss, acc_op],
                    feed_dict={
                        data_pl: data,
                        labels_pl: labels,
                        step_pl: step,
                    })
                elapsed = time.time() - tic

                assert not np.isnan(loss_val), 'Model diverge'

                if step > 0:
                    if step % FLAGS.log_interval == 0:
                        throughput = FLAGS.batch_size / float(elapsed)
                        logging(
                            '%s: step %d, loss = %.2f, acc = %.2f (%.1f examples/sec; %.3f '
                            'sec/batch)' % (datetime.now(), step, loss_val,
                                            acc_val, throughput, elapsed),
                            FLAGS)
                    if step % FLAGS.save_interval == 0:
                        saver.save(sess, FLAGS.checkpoint_path,
                                   global_step=step)

                step += 1

            # Final checkpoint, tagged with the last step actually run.
            saver.save(sess, FLAGS.checkpoint_path, global_step=step - 1)
Exemplo n.º 6
0
def test(FLAGS):
    """Evaluate the latest checkpoint on ``FLAGS.test_file`` and dump predictions.

    First builds a ``WordDict`` from ``FLAGS.embedding_file`` and stores its
    size on ``FLAGS.word_size``. Then:

    - restores the most recent checkpoint next to ``FLAGS.checkpoint_path``
      and writes the graph to ``FLAGS.log_dir`` for TensorBoard;
    - runs inference over the test set once and logs micro-precision (the
      fraction of samples whose argmax prediction matches the argmax label);
    - writes one tab-separated ``label prediction`` line per sample to
      ``FLAGS.log_dir/out.txt``.

    Returns None. Logs an error and returns early if no checkpoint is found.
    """
    worddict = WordDict(FLAGS.embedding_file)

    FLAGS.word_size = worddict.size

    # read data
    dataset = DataSet(fpath=FLAGS.test_file,
                      n_classes=FLAGS.num_classes,
                      wd=worddict,
                      need_shuffle=False)

    with tf.Graph().as_default():
        placeholders = get_placeholders(FLAGS)

        pred = inference(placeholders['data'], FLAGS, for_training=False)

        # Per-sample predicted / true class indices and the per-batch hit count.
        pred_label_op = tf.argmax(pred, 1)
        label_op = tf.argmax(placeholders['labels'], 1)
        _hit_op = tf.equal(pred_label_op, label_op)
        hit_op = tf.reduce_sum(tf.cast(_hit_op, tf.float32))

        saver = tf.train.Saver()

        with tf.Session() as sess:
            # latest_checkpoint() returns None when no checkpoint exists, and
            # checkpoint_exists(None) would raise TypeError instead of letting
            # us report the problem.
            ckpt = tf.train.latest_checkpoint(
                os.path.dirname(FLAGS.checkpoint_path))
            if ckpt and tf.train.checkpoint_exists(ckpt):
                saver.restore(sess, ckpt)
                global_step = ckpt.split('/')[-1].split('-')[-1]
                logging(
                    'Succesfully loaded model from %s at step=%s.' %
                    (ckpt, global_step), FLAGS)
            else:
                logging("[ERROR] Checkpoint not exist", FLAGS)
                return

            # Write the graph for TensorBoard; closing flushes the event file.
            summary_writer = tf.summary.FileWriter(FLAGS.log_dir,
                                                   graph=sess.graph)
            summary_writer.close()

            hit_count = 0.0
            total_count = 0
            results = []  # (predicted index, true index) per sample

            logging("%s: starting test." % (datetime.now()), FLAGS)
            start_time = time.time()
            total_batch_size = math.ceil(dataset._num_data / FLAGS.batch_size)

            for step, (data, labels) in enumerate(
                    dataset.iter_once(FLAGS.batch_size)):
                # Distinct names so the `pred` logits tensor is not shadowed.
                hits, pred_labels, true_labels = sess.run(
                    [hit_op, pred_label_op, label_op],
                    feed_dict={
                        placeholders['data']: data,
                        placeholders['labels']: labels
                    })

                hit_count += np.sum(hits)
                total_count += len(data)

                results.extend(zip(pred_labels, true_labels))

                if step % FLAGS.log_interval == 0:
                    duration = time.time() - start_time
                    sec_per_batch = duration / FLAGS.log_interval
                    examples_per_sec = FLAGS.batch_size / sec_per_batch
                    logging(
                        '%s: [%d batches out of %d] (%.1f examples/sec; %.3f '
                        'sec/batch)' % (datetime.now(), step, total_batch_size,
                                        examples_per_sec, sec_per_batch),
                        FLAGS)
                    start_time = time.time()

            # micro precision
            logging(
                "%s: micro-precision = %.5f" % (datetime.now(),
                                                (hit_count / total_count)),
                FLAGS)

            # write result: one "label<TAB>prediction" line per sample
            outpath = os.path.join(FLAGS.log_dir, "out.txt")
            with open(outpath, 'w') as fw:
                for p, l in results:
                    fw.write("%d\t%d\n" % (int(l), int(p)))
Exemplo n.º 7
0
def test(FLAGS):
    """Collect the top-activating subsequences (motifs) for each conv filter.

    Restores the latest checkpoint next to ``FLAGS.checkpoint_path`` and runs
    the test set once. For every sample and filter it takes the position of
    the maximum activation. When that activation is positive, it records the
    value and the ``wlen``-long subsequence of the raw input starting there.
    Finally, for each filter it writes up to ``wlen * 25`` of the highest
    matches as ``value<TAB>subsequence`` lines (in no particular order) to a
    hard-coded ``.../logos/test/p<filter>.txt`` path.

    NOTE(review): assumes ``layers['conv']`` holds one tensor per entry of
    ``wlens`` with ``FLAGS.hidden1 / 5`` filters each — confirm against
    ``inference``.

    Returns None. Prints an error and returns early if no checkpoint is found.
    """
    # read data
    dataset = DataSet(fpath=FLAGS.test_file,
                      seqlen=FLAGS.seq_len,
                      n_classes=FLAGS.num_classes,
                      need_shuffle=False)

    FLAGS.charset_size = dataset.charset_size

    with tf.Graph().as_default():
        # placeholder
        placeholders = get_placeholders(FLAGS)

        # get inference
        pred, layers = inference(placeholders['data'],
                                 FLAGS,
                                 for_training=False)

        # calculate prediction
        _hit_op = tf.equal(tf.argmax(pred, 1),
                           tf.argmax(placeholders['labels'], 1))
        hit_op = tf.reduce_sum(tf.cast(_hit_op, tf.float32))

        # create saver
        saver = tf.train.Saver()

        # Position of the maximum activation along axis 2 of each conv
        # output; used below to locate the matching subsequence.
        h1_argmax_ops = []
        for op in layers['conv']:
            h1_argmax_ops.append(tf.argmax(op, axis=2))

        with tf.Session() as sess:
            # latest_checkpoint() returns None when no checkpoint exists, and
            # checkpoint_exists(None) would raise TypeError instead of letting
            # us report the problem.
            ckpt = tf.train.latest_checkpoint(
                os.path.dirname(FLAGS.checkpoint_path))
            if ckpt and tf.train.checkpoint_exists(ckpt):
                saver.restore(sess, ckpt)
                global_step = ckpt.split('/')[-1].split('-')[-1]
                print('Succesfully loaded model from %s at step=%s.' %
                      (ckpt, global_step))
            else:
                print("[ERROR] Checkpoint not exist")
                return

            # iter batch
            hit_count = 0.0
            total_count = 0
            wlens = [4, 8, 12, 16, 20]
            # Filters per window length; global filter id = gidx * hsize + fidx.
            hsize = int(FLAGS.hidden1 / 5)
            # (values per filter id, subsequences per filter id)
            motif_matches = (defaultdict(list), defaultdict(list))

            print("%s: starting test." % (datetime.now()))
            start_time = time.time()
            total_batch_size = math.ceil(dataset._num_data / FLAGS.batch_size)

            for step, (data, labels, raws) in enumerate(
                    dataset.iter_once(FLAGS.batch_size, with_raw=True)):
                res_run = sess.run([hit_op, h1_argmax_ops] + layers['conv'],
                                   feed_dict={
                                       placeholders['data']: data,
                                       placeholders['labels']: labels
                                   })

                hits = res_run[0]
                max_idxs = res_run[1]  # shape = (wlens, N, 1, # of filters)
                motif_filters = res_run[2:]

                # mf.shape = (N, 1, l-w+1, # of filters)
                for i in range(len(motif_filters)):
                    s = motif_filters[i].shape
                    motif_filters[i] = np.transpose(motif_filters[i],
                                                    (0, 1, 3, 2)).reshape(
                                                        (s[0], s[3], s[2]))

                # mf.shape = (N, # of filters, l-w+1)
                for gidx, mf in enumerate(motif_filters):
                    wlen = wlens[gidx]
                    for ridx, row in enumerate(mf):
                        for fidx, vals in enumerate(row):
                            # for each filter, get max value and its index
                            max_idx = max_idxs[gidx][ridx][0][fidx]
                            max_val = vals[max_idx]

                            hidx = gidx * hsize + fidx

                            if max_val > 0:
                                # get sequence
                                rawseq = raws[ridx][1]
                                subseq = rawseq[max_idx:max_idx + wlen]
                                motif_matches[0][hidx].append(max_val)
                                motif_matches[1][hidx].append(subseq)

                hit_count += np.sum(hits)
                total_count += len(data)

                if step % FLAGS.log_interval == 0:
                    duration = time.time() - start_time
                    sec_per_batch = duration / FLAGS.log_interval
                    examples_per_sec = FLAGS.batch_size / sec_per_batch
                    print('%s: [%d batches out of %d] (%.1f examples/sec; %.3f '
                          'sec/batch)' %
                          (datetime.now(), step, total_batch_size,
                           examples_per_sec, sec_per_batch))
                    start_time = time.time()

            # Select the top-k matches per filter (argpartition: unsorted)
            # and write each filter's matches to its own file.
            print('%s: write result to file' % (datetime.now()))
            for fidx in motif_matches[0]:
                val_lst = motif_matches[0][fidx]
                seq_lst = motif_matches[1][fidx]
                # top k, scaled by this filter's window length
                k = wlens[int(fidx / hsize)] * 25
                neg_k = min(k, len(val_lst)) * -1
                tidxs = np.argpartition(val_lst, neg_k)[neg_k:]
                with open(
                        "/home/kimlab/project/CCC/tmp/logos/test/p%d.txt" %
                        fidx, 'w') as fw:
                    for idx in tidxs:
                        fw.write("%f\t%s\n" % (val_lst[idx], seq_lst[idx]))

                if fidx % 50 == 0:
                    print('%s: [%d filters out of %d]' %
                          (datetime.now(), fidx, FLAGS.hidden1))
Exemplo n.º 8
0
def test(FLAGS):
    """Evaluate the latest checkpoint and optionally save per-sample predictions.

    Restores the most recent checkpoint next to ``FLAGS.checkpoint_path``.
    It runs inference over ``FLAGS.test_file`` once and logs micro-precision
    and the ROC AUC of the class-0 softmax probability. If
    ``FLAGS.save_prediction`` is set, it writes one tab-separated
    ``id probability label`` line per test-file line to that path, where
    ``id`` is the third tab-separated field of the line.

    Returns None. Logs an error and returns early if no checkpoint is found.
    """
    # read data
    dataset = DataSet(fpath=FLAGS.test_file,
                      seqlen=FLAGS.seq_len,
                      n_classes=FLAGS.num_classes,
                      need_shuffle=False)
    # set character set size
    FLAGS.charset_size = dataset.charset_size

    with tf.Graph().as_default():
        placeholders = get_placeholders(FLAGS)

        pred, _ = inference(placeholders['data'],
                            FLAGS,
                            for_training=False)

        prob = tf.nn.softmax(pred)
        # Number of samples whose argmax prediction matches the argmax label.
        _hit_op = tf.equal(tf.argmax(pred, 1),
                           tf.argmax(placeholders['labels'], 1))
        hit_op = tf.reduce_sum(tf.cast(_hit_op, tf.float32))

        saver = tf.train.Saver()

        with tf.Session() as sess:
            # latest_checkpoint() returns None when no checkpoint exists, and
            # checkpoint_exists(None) would raise TypeError instead of letting
            # us report the problem.
            ckpt = tf.train.latest_checkpoint(
                os.path.dirname(FLAGS.checkpoint_path))
            if ckpt and tf.train.checkpoint_exists(ckpt):
                saver.restore(sess, ckpt)
                global_step = ckpt.split('/')[-1].split('-')[-1]
                logging(
                    'Succesfully loaded model from %s at step=%s.' %
                    (ckpt, global_step), FLAGS)
            else:
                logging("[ERROR] Checkpoint not exist", FLAGS)
                return

            hit_count = 0.0
            total_count = 0
            pred_list = []  # softmax probability of class 0, per sample
            label_list = []  # one-hot label column 0, per sample

            logging("%s: starting test." % (datetime.now()), FLAGS)
            start_time = time.time()
            total_batch_size = math.ceil(dataset._num_data / FLAGS.batch_size)

            for step, (data, labels) in enumerate(
                    dataset.iter_once(FLAGS.batch_size)):
                hits, pred_val = sess.run([hit_op, prob],
                                          feed_dict={
                                              placeholders['data']: data,
                                              placeholders['labels']: labels
                                          })

                hit_count += np.sum(hits)
                total_count += len(data)

                for i, p in enumerate(pred_val):
                    pred_list.append(p[0])
                    label_list.append(labels[i][0])

                if step % FLAGS.log_interval == 0:
                    duration = time.time() - start_time
                    sec_per_batch = duration / FLAGS.log_interval
                    examples_per_sec = FLAGS.batch_size / sec_per_batch
                    logging(
                        '%s: [%d batches out of %d] (%.1f examples/sec; %.3f '
                        'sec/batch)' % (datetime.now(), step, total_batch_size,
                                        examples_per_sec, sec_per_batch),
                        FLAGS)
                    start_time = time.time()

            auc_val = roc_auc_score(label_list, pred_list)
            logging(
                "%s: micro-precision = %.5f, auc = %.5f" %
                (datetime.now(), (hit_count / total_count), auc_val), FLAGS)

    if FLAGS.save_prediction:
        # NOTE(review): pairs the i-th test-file line with the i-th
        # prediction; assumes DataSet yields exactly one sample per line in
        # file order (need_shuffle=False) — confirm.
        with open(FLAGS.save_prediction, 'w') as fout, \
                open(FLAGS.test_file) as f:
            for i, line in enumerate(f):
                sample_id = line.strip().split('\t')[2]
                # str() is required: the collected values are numpy scalars,
                # which str.join rejects with a TypeError.
                print('\t'.join([sample_id, str(pred_list[i]),
                                 str(label_list[i])]),
                      file=fout)
Exemplo n.º 9
0
def test(FLAGS):
    """Evaluate a trained checkpoint on ``FLAGS.test_file`` and dump motif matches.

    For every test example the predicted label and one softmax probability
    are recorded. For every convolution filter, the input subsequence that
    starts at the filter's argmax position is collected whenever that maximum
    activation is positive. Results are written under ``FLAGS.motif_outpath``
    (created if missing):

    * ``p<filter_id>.txt`` -- one ``activation<TAB>subsequence`` line per match.
    * ``report.txt``       -- ``match_fraction<TAB>mean_activation<TAB>filter_id``,
      sorted by number of matches, descending.
    * ``predictions.txt``  -- ``label<TAB>probability`` per test example.

    Filters are numbered globally: group ``g`` (window length
    ``FLAGS.window_lengths[g]``) owns ids
    ``[sum(num_windows[:g]), sum(num_windows[:g + 1]))``.

    Args:
      FLAGS: configuration object. Reads test_file, seq_len, num_classes,
        checkpoint_path, window_lengths, num_windows, batch_size,
        log_interval and motif_outpath. Sets ``FLAGS.charset_size``.

    Returns:
      None. Returns early, after printing an error, if the checkpoint is missing.
    """
    # read data
    dataset = DataSet(fpath=FLAGS.test_file,
                      seqlen=FLAGS.seq_len,
                      n_classes=FLAGS.num_classes,
                      need_shuffle=False)

    FLAGS.charset_size = dataset.charset_size

    with tf.Graph().as_default():
        placeholders = get_placeholders(FLAGS)

        # get inference
        pred, layers = inference(placeholders['data'], FLAGS,
                                 for_training=False)

        # calculate prediction
        label_op = tf.argmax(pred, 1)
        prob_op = tf.nn.softmax(pred)

        saver = tf.train.Saver()

        # Position (along axis 2) of each conv filter's maximum activation,
        # one op per conv group.
        h1_argmax_ops = [tf.argmax(op, axis=2) for op in layers['conv']]

        with tf.Session() as sess:
            ckpt = FLAGS.checkpoint_path
            if not tf.train.checkpoint_exists(ckpt):
                print("[ERROR] Checkpoint not exist")
                return
            saver.restore(sess, ckpt)
            global_step = ckpt.split('/')[-1].split('-')[-1]
            print('Succesfully loaded model from %s at step=%s.' %
                  (ckpt, global_step))

            wlens = FLAGS.window_lengths
            hsizes = FLAGS.num_windows
            # Global id of the first filter of each conv group. Group sizes may
            # differ, so this must be a cumulative sum, not gidx * hsizes[gidx].
            offsets = [sum(hsizes[:g]) for g in range(len(hsizes))]
            num_filters = sum(hsizes)

            # NOTE(review): the probability of class 388 is hard-coded here;
            # presumably the class of interest for this dataset -- confirm.
            prob_class_idx = 388

            # motif_matches[0][filter_id] -> activations,
            # motif_matches[1][filter_id] -> matching subsequences.
            motif_matches = (defaultdict(list), defaultdict(list))
            pred_labels = []
            pred_prob = []
            total_count = 0

            print("%s: starting test." % (datetime.now()))
            start_time = time.time()
            total_batch_size = math.ceil(
                float(dataset._num_data) / FLAGS.batch_size)

            for step, (data, labels, raws) in enumerate(
                    dataset.iter_once(FLAGS.batch_size, with_raw=True)):
                res_run = sess.run(
                    [label_op, prob_op, h1_argmax_ops] + layers['conv'],
                    feed_dict={
                        placeholders['data']: data,
                        placeholders['labels']: labels,
                    })
                pred_label = res_run[0]
                pred_arr = res_run[1]
                # one array per conv group, indexed [example][0][filter]
                max_idxs = res_run[2]
                motif_filters = res_run[3:]

                for label, probs in zip(pred_label, pred_arr):
                    pred_labels.append(label)
                    pred_prob.append(probs[prob_class_idx])

                for gidx, mf in enumerate(motif_filters):
                    s = mf.shape
                    # (N, 1, l-w+1, filters) -> (N, filters, l-w+1)
                    mf = np.transpose(mf, (0, 1, 3, 2)).reshape(
                        (s[0], s[3], s[2]))
                    wlen = wlens[gidx]
                    offset = offsets[gidx]
                    for ridx, row in enumerate(mf):
                        for fidx, vals in enumerate(row):
                            max_idx = max_idxs[gidx][ridx][0][fidx]
                            max_val = vals[max_idx]
                            if max_val > 0:
                                rawseq = raws[ridx][1]
                                hidx = offset + fidx
                                motif_matches[0][hidx].append(max_val)
                                motif_matches[1][hidx].append(
                                    rawseq[max_idx:max_idx + wlen])

                total_count += len(data)

                if step % FLAGS.log_interval == 0:
                    duration = time.time() - start_time
                    sec_per_batch = duration / FLAGS.log_interval
                    examples_per_sec = FLAGS.batch_size / sec_per_batch
                    print('%s: [%d batches out of %d] (%.1f examples/sec; %.3f'
                          'sec/batch)' % (datetime.now(), step,
                                          total_batch_size, examples_per_sec,
                                          sec_per_batch))
                    start_time = time.time()

            print(pred_labels)

            # Avoid losing the whole inference run to a missing output dir.
            if not os.path.isdir(FLAGS.motif_outpath):
                os.makedirs(FLAGS.motif_outpath)

            mean_acts = {}
            on_acts = {}
            print('%s: write result to file' % (datetime.now()))
            for fidx in motif_matches[0]:
                val_lst = motif_matches[0][fidx]
                seq_lst = motif_matches[1][fidx]
                n_matches = len(val_lst)
                # Top-k selection is disabled: kth = -n_matches keeps every
                # match (written in argpartition order).
                tidxs = np.argpartition(val_lst, -n_matches)[-n_matches:]
                acts = 0.0

                opath = os.path.join(FLAGS.motif_outpath, "p%d.txt" % fidx)
                with open(opath, 'w') as fw:
                    for idx in tidxs:
                        fw.write("%f\t%s\n" % (val_lst[idx], seq_lst[idx]))
                        acts += val_lst[idx]

                mean_acts[fidx] = acts / n_matches
                on_acts[fidx] = n_matches

                if fidx % 50 == 0:
                    print('%s: [%d filters out of %d]' %
                          (datetime.now(), fidx, num_filters))

            # report mean acts, most frequently firing filters first
            with open(os.path.join(FLAGS.motif_outpath, "report.txt"),
                      'w') as fw:
                for i in sorted(on_acts, key=on_acts.get, reverse=True):
                    fw.write("%f\t%f\t%d\n" %
                             (float(on_acts[i]) / total_count, mean_acts[i], i))

            with open(os.path.join(FLAGS.motif_outpath, "predictions.txt"),
                      'w') as fw:
                for p, prob in zip(pred_labels, pred_prob):
                    fw.write("%s\t%f\n" % (str(p), prob))
Exemplo n.º 10
0
def train(FLAGS):
    """Train the model on ``FLAGS.train_file`` and save checkpoints.

    Builds the graph, optionally restores ``FLAGS.prev_checkpoint_path``, and
    then runs one optimisation step per batch from
    ``dataset.iter_once(FLAGS.batch_size)`` (presumably a single epoch). When
    ``FLAGS.fine_tuning`` is set, variables matching
    ``FLAGS.fine_tuning_layers`` are excluded from the restore. A checkpoint
    is written every ``FLAGS.save_interval`` steps and once more at the end.

    Args:
      FLAGS: configuration object. Reads train_file, seq_len, num_classes,
        learning_rate, prev_checkpoint_path, fine_tuning,
        fine_tuning_layers, batch_size, log_interval, save_interval and
        checkpoint_path. Sets ``FLAGS.charset_size``.

    Raises:
      AssertionError: if the loss becomes NaN ("Model diverge").
    """
    # read data
    dataset = DataSet(fpath=FLAGS.train_file,
                      seqlen=FLAGS.seq_len,
                      n_classes=FLAGS.num_classes,
                      need_shuffle=True)
    # set character set size
    FLAGS.charset_size = dataset.charset_size

    with tf.Graph().as_default():
        # NOTE(review): global_step is fed every step but no op visible here
        # reads it. It is kept so placeholder/op naming stays unchanged.
        global_step = tf.placeholder(tf.int32)
        placeholders = get_placeholders(FLAGS)

        # prediction
        pred, layers = inference(placeholders['data'],
                                 FLAGS,
                                 for_training=True)

        # The cross-entropy is registered in the tf.losses collection;
        # get_total_loss() sums that collection (per TF docs, regularization
        # losses are included by default).
        tf.losses.softmax_cross_entropy(placeholders['labels'], pred)
        loss = tf.losses.get_total_loss()

        # accuracy: fraction of examples whose argmax matches the label argmax
        _acc_op = tf.equal(tf.argmax(pred, 1),
                           tf.argmax(placeholders['labels'], 1))
        acc_op = tf.reduce_mean(tf.cast(_acc_op, tf.float32))

        # optimization
        train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(loss)

        # keep every checkpoint written during training
        saver = tf.train.Saver(max_to_keep=None)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            if tf.train.checkpoint_exists(FLAGS.prev_checkpoint_path):
                if FLAGS.fine_tuning:
                    logging('%s: Fine Tuning Experiment!' % (datetime.now()),
                            FLAGS)
                    restore_variables = slim.get_variables_to_restore(
                        exclude=FLAGS.fine_tuning_layers)
                    restorer = tf.train.Saver(restore_variables)
                else:
                    restorer = tf.train.Saver()
                restorer.restore(sess, FLAGS.prev_checkpoint_path)
                logging(
                    '%s: Pre-trained model restored from %s' %
                    (datetime.now(), FLAGS.prev_checkpoint_path), FLAGS)
                # Continue numbering after the "<name>-<step>" suffix.
                step = int(
                    FLAGS.prev_checkpoint_path.split('/')[-1].split('-')
                    [-1]) + 1
            else:
                step = 0

            for data, labels in dataset.iter_once(FLAGS.batch_size):
                start_time = time.time()
                _, loss_val, acc_val = sess.run(
                    [train_op, loss, acc_op],
                    feed_dict={
                        placeholders['data']: data,
                        placeholders['labels']: labels,
                        global_step: step
                    })
                duration = time.time() - start_time

                # Explicit check rather than ``assert``: asserts are stripped
                # under ``python -O``, which would silently keep training a
                # diverged model. Same exception type as before.
                if np.isnan(loss_val):
                    raise AssertionError('Model diverge')

                # logging (timing reflects only the current batch)
                if step > 0 and step % FLAGS.log_interval == 0:
                    examples_per_sec = FLAGS.batch_size / float(duration)
                    format_str = (
                        '%s: step %d, loss = %.2f, acc = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    logging(
                        format_str % (datetime.now(), step, loss_val, acc_val,
                                      examples_per_sec, duration), FLAGS)

                # save model
                if step > 0 and step % FLAGS.save_interval == 0:
                    saver.save(sess, FLAGS.checkpoint_path, global_step=step)

                step += 1

            # save for last (step was already advanced past the last batch)
            saver.save(sess, FLAGS.checkpoint_path, global_step=step - 1)