Example #1
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", dest="model", type=str, required=True)
    parser.add_argument("-q", "--query", dest="query", type=str, required=True)
    parser.add_argument("-c",
                        "--cutoff",
                        dest="cutoff",
                        type=float,
                        nargs="+",
                        default=[0.95])
    parser.add_argument("-o",
                        "--output",
                        dest="output",
                        type=str,
                        required=True)
    parser.add_argument("-n",
                        "--normalize",
                        dest="normalize",
                        default=False,
                        action="store_true")
    parser.add_argument("-d",
                        "--device",
                        dest="device",
                        type=str,
                        default=None)
    parser.add_argument("--clean", dest="clean", type=str, default=None)
    cmd_args = parser.parse_args()
    cmd_args.output_path = os.path.dirname(cmd_args.output)
    os.makedirs(cmd_args.output_path, exist_ok=True)
    os.environ["CUDA_VISIBLE_DEVICES"] = cmd_args.device or \
        str(utils.pick_gpu_lowest_memory())  # os.environ values must be str
    return cmd_args
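Every example below delegates GPU selection to utils.pick_gpu_lowest_memory(), which is never shown. Here is a minimal sketch of such a helper, assuming nvidia-smi is on the PATH and that the function returns the integer index of the GPU with the most free memory (the str() casts in several examples suggest an integer return value):

import subprocess

def pick_gpu_lowest_memory():
    # Sketch only, not the actual utils implementation: query per-GPU free
    # memory via nvidia-smi and return the index of the least-loaded GPU.
    output = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=memory.free",
         "--format=csv,noheader,nounits"],
        encoding="utf-8")
    free_memory = [int(line) for line in output.strip().splitlines()]
    # "lowest memory" here means lowest memory usage, i.e. most free memory
    return max(range(len(free_memory)), key=lambda i: free_memory[i])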
Example #2
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", dest="input", type=str, required=True)
    parser.add_argument("-g", "--genes", dest="genes", type=str, default=None)
    parser.add_argument("-b", "--batch-effect", dest="batch_effect", type=str, default=None)
    parser.add_argument("-o", "--output", dest="output", type=str, required=True)

    parser.add_argument("--n-latent", dest="n_latent", type=int, default=10)
    parser.add_argument("--n-hidden", dest="n_hidden", type=int, default=128)
    parser.add_argument("--n-layers", dest="n_layers", type=int, default=1)

    parser.add_argument("--supervision", dest="supervision", type=str, default=None)
    parser.add_argument("--label-fraction", dest="label_fraction", type=float, default=None)
    parser.add_argument("--label-priority", dest="label_priority", type=str, default=None)

    parser.add_argument("--n-epochs", dest="n_epochs", type=int, default=1000)
    parser.add_argument("--patience", dest="patience", type=int, default=30)
    parser.add_argument("--learning-rate", dest="lr", type=float, default=1e-3)

    parser.add_argument("-s", "--seed", dest="seed", type=int, default=None)
    parser.add_argument("-d", "--device", dest="device", type=str, default=None)
    parser.add_argument("--clean", dest="clean", type=str, default=None)
    cmd_args = parser.parse_args()
    cmd_args.output_path = os.path.dirname(cmd_args.output)
    os.makedirs(cmd_args.output_path, exist_ok=True)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(utils.pick_gpu_lowest_memory()) \
        if cmd_args.device is None else cmd_args.device
    if cmd_args.seed is not None:
        np.random.seed(cmd_args.seed)
        torch.manual_seed(cmd_args.seed)
    return cmd_args
Example #3
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", dest="input", type=str, required=True)
    parser.add_argument("-o",
                        "--output",
                        dest="output",
                        type=str,
                        required=True)
    parser.add_argument("-m", "--model", dest="model", type=str, required=True)
    parser.add_argument("-b",
                        "--batch-effect",
                        dest="batch_effect",
                        type=str,
                        required=True)
    parser.add_argument("-t",
                        "--target",
                        dest="target",
                        type=str,
                        choices=["zeros", "first", "ones"],
                        default="zeros")
    parser.add_argument("-d",
                        "--device",
                        dest="device",
                        type=str,
                        default=None)
    parser.add_argument("--clean", dest="clean", type=str, default=None)
    cmd_args = parser.parse_args()
    os.makedirs(os.path.dirname(cmd_args.output), exist_ok=True)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(utils.pick_gpu_lowest_memory()) \
        if cmd_args.device is None else cmd_args.device
    return cmd_args
Example #4
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", dest="input", type=str, required=True)
    parser.add_argument("-g", "--genes", dest="genes", type=str, required=True)
    parser.add_argument("-o", "--output", dest="output", type=str, required=True)

    parser.add_argument("--n-latent", dest="n_latent", type=int, default=32)
    parser.add_argument("--n-hidden", dest="n_hidden", type=int, default=64)
    parser.add_argument("--n-layers", dest="n_layers", type=int, default=1)

    parser.add_argument("--n-epochs", dest="n_epochs", type=int, default=1000)
    parser.add_argument("--patience", dest="patience", type=int, default=30)

    parser.add_argument("-s", "--seed", dest="seed", type=int, default=None)  # Not exactly be reproducible though
    parser.add_argument("-t", "--threads", dest="threads", type=int, default=None)
    parser.add_argument("-d", "--device", dest="device", type=str, default=None)
    parser.add_argument("--clean", dest="clean", type=str, default=None)

    cmd_args = parser.parse_args()
    cmd_args.output_path = os.path.dirname(cmd_args.output)
    os.makedirs(cmd_args.output_path, exist_ok=True)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(utils.pick_gpu_lowest_memory()) \
        if cmd_args.device is None else cmd_args.device
    return cmd_args
Example #5
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", dest="input", type=str, required=True)
    parser.add_argument("-g", "--genes", dest="genes", type=str, default=None)
    parser.add_argument("-b",
                        "--batch-effect",
                        dest="batch_effect",
                        type=str,
                        default=None)
    parser.add_argument("-o",
                        "--output",
                        dest="output",
                        type=str,
                        required=True)

    parser.add_argument("-s", "--seed", dest="seed", type=int, default=None)
    parser.add_argument("-d",
                        "--device",
                        dest="device",
                        type=str,
                        default=None)
    parser.add_argument("--clean", dest="clean", type=str, default=None)
    cmd_args = parser.parse_args()
    cmd_args.output_path = os.path.dirname(cmd_args.output)
    os.makedirs(cmd_args.output_path, exist_ok=True)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(utils.pick_gpu_lowest_memory()) \
        if cmd_args.device is None else cmd_args.device
    return cmd_args
Example #6
def main(cmd_args):

    cb.message.info("Loading index...")
    os.environ["CUDA_VISIBLE_DEVICES"] = str(utils.pick_gpu_lowest_memory()) \
        if cmd_args.device is None else cmd_args.device
    blast = cb.blast.BLAST.load(cmd_args.index)
    if cmd_args.subsample_ref is not None:
        cb.message.info("Subsampling reference...")
        subsample_idx = np.random.RandomState(cmd_args.seed).choice(
            blast.ref.shape[0], cmd_args.subsample_ref, replace=False)
        blast.ref = blast.ref[subsample_idx, :]
        blast.latent = blast.latent[subsample_idx] \
            if blast.latent is not None else None
        blast.cluster = blast.cluster[subsample_idx] \
            if blast.cluster is not None else None
        blast.posterior = blast.posterior[subsample_idx] \
            if blast.posterior is not None else None
        blast.nearest_neighbors = None
        blast.empirical = None
        blast._force_components()

    cb.message.info("Reading query...")
    query = cb.data.ExprDataSet.read_dataset(cmd_args.query)
    if cmd_args.clean:
        query = utils.clean_dataset(query, cmd_args.clean)

    if cmd_args.align:
        cb.message.info("Aligning...")
        unipath = "/tmp/cb/" + cb.utils.rand_hex()
        cb.message.info("Using temporary path: " + unipath)
        blast = blast.align(query, path=unipath)

    cb.message.info("BLASTing...")
    start_time = time.time()
    hits = blast.query(query,
                       n_neighbors=cmd_args.n_neighbors).reconcile_models()

    time_per_cell = None
    prediction_dict = {}
    for cutoff in cmd_args.cutoff:
        prediction_dict[cutoff] = hits.filter(
            by=cmd_args.filter_by, cutoff=cutoff).annotate(
                cmd_args.annotation,
                min_hits=cmd_args.min_hits)[cmd_args.annotation]
        if time_per_cell is None:
            time_per_cell = (time.time() - start_time) * 1000 / len(
                prediction_dict[cutoff])
    print("Time per cell: %.3fms" % time_per_cell)

    cb.message.info("Saving result...")
    if os.path.exists(cmd_args.output):
        os.remove(cmd_args.output)
    for cutoff in prediction_dict:
        cb.data.write_hybrid_path(
            prediction_dict[cutoff],
            "%s//prediction/%s" % (cmd_args.output, str(cutoff)))
    cb.data.write_hybrid_path(time_per_cell, "//".join(
        (cmd_args.output, "time")))
Example #7
def main(cmd_args):

    cb.message.info("Reading data...")
    dataset = cb.data.ExprDataSet.read_dataset(cmd_args.ref)
    if cmd_args.clean:
        dataset = utils.clean_dataset(dataset, cmd_args.clean)

    os.environ["CUDA_VISIBLE_DEVICES"] = str(utils.pick_gpu_lowest_memory()) \
        if cmd_args.device is None else cmd_args.device
    models = [cb.directi.DIRECTi.load(model) for model in cmd_args.models]

    cb.message.info("Building Cell BLAST index...")
    blast = cb.blast.BLAST(models, dataset, n_posterior=cmd_args.n_posterior)

    cb.message.info("Saving index...")
    blast.save(cmd_args.output_path)

    cb.message.info("Done!")
Example #8
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", dest="model", type=str, required=True)
    parser.add_argument("-r", "--ref", dest="ref", type=str, required=True)
    parser.add_argument("-q", "--query", dest="query", type=str, required=True)
    parser.add_argument("-o",
                        "--output",
                        dest="output",
                        type=str,
                        required=True)
    parser.add_argument("-a",
                        "--annotation",
                        dest="annotation",
                        type=str,
                        default="cell_ontology_class")
    parser.add_argument("--n-neighbors",
                        dest="n_neighbors",
                        type=int,
                        default=10)
    parser.add_argument("--min-hits", dest="min_hits", type=int, default=2)
    parser.add_argument("-c",
                        "--cutoff",
                        dest="cutoff",
                        type=float,
                        nargs="+",
                        default=[0.1])
    parser.add_argument("-s", "--seed", dest="seed", type=int, default=None)
    parser.add_argument("-d",
                        "--device",
                        dest="device",
                        type=str,
                        default=None)
    parser.add_argument("--subsample-ref",
                        dest="subsample_ref",
                        type=int,
                        default=None)
    parser.add_argument("--clean", dest="clean", type=str, default=None)
    cmd_args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(utils.pick_gpu_lowest_memory()) \
        if cmd_args.device is None else cmd_args.device
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    keras.backend.set_session(tf.Session(config=config))
    return cmd_args
Example #9
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", dest="input", type=str, required=True)
    parser.add_argument("-g", "--genes", dest="genes", type=str, required=True)
    parser.add_argument("-o",
                        "--output",
                        dest="output",
                        type=str,
                        required=True)

    parser.add_argument("--n-latent", dest="n_latent", type=int, default=3)
    parser.add_argument("--n-epochs", dest="n_epochs", type=int, default=100)
    # Reduce the number of epochs to the authors' recommended maximum:
    # Dhaka does not support early stopping, and we see numerical
    # instability with larger numbers of epochs.

    parser.add_argument("-s", "--seed", dest="seed", type=int,
                        default=None)  # Not exactly be reproducible though
    parser.add_argument("-d",
                        "--device",
                        dest="device",
                        type=str,
                        default=None)
    parser.add_argument("--clean", dest="clean", type=str, default=None)
    cmd_args = parser.parse_args()
    cmd_args.output_path = os.path.dirname(cmd_args.output)
    os.makedirs(cmd_args.output_path, exist_ok=True)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(utils.pick_gpu_lowest_memory()) \
        if cmd_args.device is None else cmd_args.device
    if cmd_args.seed is not None:
        np.random.seed(cmd_args.seed)
        random.seed(cmd_args.seed)
        tf.set_random_seed(cmd_args.seed)
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    sess = tf.Session(graph=tf.get_default_graph(), config=tf_config)
    K.set_session(sess)
    return cmd_args
Example #10
import tensorflow as tf
import numpy as np
import utils
import pickle
import time
import argparse
import os
from models.unrolledplan.model import IMP
# only for GPU instances, comment it out otherwise
os.environ["CUDA_VISIBLE_DEVICES"] = str(utils.pick_gpu_lowest_memory())

def parse_args():

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--inner-horizon', type=int, default=5, help='length of RNN rollout horizon')
    parser.add_argument('--outer-horizon', type=int, default=5, help='length of BC loss horizon')
    parser.add_argument('--num-plan-updates', type=int, default=8, help='number of planning update steps before BC loss')
    parser.add_argument('--n-hidden', type=int, default=1, help='number of hidden layers to encode after conv')
    parser.add_argument('--obs-latent-dim', type=int, default=128, help='obs latent space dim')
    parser.add_argument('--act-latent-dim', type=int, default=128, help='act latent space dim')
    parser.add_argument('--meta-gradient-clip-value', type=float, default=25., help='meta gradient clip value')
    parser.add_argument('--batch-size', type=int, default=128, help='batch size')
    parser.add_argument('--test-batch-size', type=int, default=128, help='test batch size')
    parser.add_argument('--il-lr-0', type=float, default=0.5, help='il_lr_0')
    parser.add_argument('--il-lr', type=float, default=0.25, help='il_lr')
    parser.add_argument('--ol-lr', type=float, default=0.0035, help='ol_lr')
    parser.add_argument('--num-batch-updates', type=int, default=100000, help='number of minibatch updates')
    parser.add_argument('--testing-frequency', type=int, default=2000, help='how frequently to get stats for test data')
    parser.add_argument('--log-file', type=str, default='log', help='name of log file to dump test data stats')
    parser.add_argument('--log-directory', type=str, default='log', help='name of log directory to dump checkpoints')
    parser.add_argument('--huber', dest='huber_loss', action='store_true', help='whether to use Huber Loss')
Example #11
def train_network(n,
                  tr_data,
                  tr_labels,
                  ts_data,
                  ts_labels,
                  num_epochs=20,
                  batch_size=1024):

    gpu_num = str(utils.pick_gpu_lowest_memory())
    with tf.device('/device:GPU:' + gpu_num):
        sample_size = tr_data.shape[0]
        if batch_size is None:
            # sample_size itself will be the batch_size.
            batch_size = sample_size

        num_batches = math.ceil(sample_size / batch_size)
        total_iterations = num_epochs * num_batches
        last_progress_in_percent = 0

        # Train network: loss and train_step update

        # config = tf.ConfigProto(device_count = {'GPU': 2})
        # with tf.Session(config=config) as sess:
        with tf.Session() as sess:

            sess.run(n.init_op)  # variables initializer
            # old_weights = [w.eval() for w in n.weights] # a list of weights.

            for epoch in range(num_epochs):  # per epoch
                for batch in range(num_batches):  # per batch

                    # batch_start and batch_end index
                    batch_start = batch * batch_size
                    batch_end = batch_start + batch_size

                    # run train_step and loss
                    # you can train and compute loss at the same time like this:
                    _, loss_val = sess.run(
                        (n.train_step, n.loss),
                        feed_dict={
                            n.x: tr_data[batch_start:batch_end],
                            n.y_: tr_labels[batch_start:batch_end],
                        })
                    new_weights = [w.eval() for w in n.weights]

                    # Print status at certain intervals
                    iteration = epoch * num_batches + batch
                    progress = (iteration + 1) / total_iterations
                    progress_in_percent = math.floor(progress * 100)

                    if iteration == 0 or last_progress_in_percent != progress_in_percent:
                        # Calculate change in weights after each batch
                        '''
                        weight_change = np.asarray([np.mean(np.abs(old - new))
                                                    for old, new in zip(old_weights, new_weights)])
                        '''
                        # Calculate training and testing accuracy
                        training_accuracy = n.accuracy.eval(feed_dict={
                            n.x: tr_data,
                            n.y_: tr_labels
                        })
                        testing_accuracy = n.accuracy.eval(feed_dict={
                            n.x: ts_data,
                            n.y_: ts_labels
                        })
                        # Print status
                        print(
                            "\r{:4d}/{} [{}] ({:3.0f}%), L: {:f}, ATrain: {:4.1f}%, ATest: {:4.1f}%"
                            .format(epoch + 1, num_epochs,
                                    get_progress_arrow(20, progress),
                                    progress * 100, loss_val,
                                    training_accuracy * 100,
                                    testing_accuracy * 100),
                            end="")

                    # progress and weights update:
                    last_progress_in_percent = progress_in_percent
                    old_weights = new_weights

                    # every 5 epoch print an additional new line.
                    if batch == 0 and epoch % 5 == 0:
                        # After the first batch of some epochs print a new line to keep a history in console
                        print("")

            print("")  # New line

            # check the accuracy after training.
            training_accuracy = n.accuracy.eval(feed_dict={
                n.x: tr_data,
                n.y_: tr_labels
            })
            testing_accuracy = n.accuracy.eval(feed_dict={
                n.x: ts_data,
                n.y_: ts_labels
            })
            print(
                "Finished training with {:4.1f}% training and {:4.1f}% testing accuracy"
                .format(training_accuracy * 100, testing_accuracy * 100))

            # Return training and testing error rate
            return 1 - training_accuracy, 1 - testing_accuracy
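Example 11 calls get_progress_arrow(), which the snippet never defines. A plausible minimal implementation, assuming it renders a fixed-width text arrow (e.g. "=========>          " for width 20 at 50% progress):

def get_progress_arrow(width, progress):
    # Guess at the undefined helper: width is the bar width in characters,
    # progress is a fraction in [0, 1].
    filled = int(round(width * progress))
    if filled >= width:
        return "=" * width
    return "=" * filled + ">" + " " * (width - filled - 1)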
Example #12
def main(cmd_args):
    cb.message.info("Reading data...")
    dataset = cb.data.ExprDataSet.read_dataset(cmd_args.input)
    if not cmd_args.no_normalize:
        dataset = dataset.normalize()
    if cmd_args.clean:
        dataset = utils.clean_dataset(dataset, cmd_args.clean)

    if cmd_args.supervision is not None and cmd_args.label_fraction is not None:
        label = dataset.obs[cmd_args.supervision]
        if cmd_args.label_priority is not None:
            label_priority = dataset.obs[cmd_args.label_priority].values
        else:
            _label_priority = np.random.uniform(size=label.shape[0])
            label_priority = np.empty(len(_label_priority))
            for l in np.unique(label):  # Group percentile
                mask = label == l
                label_priority[mask] = (scipy.stats.rankdata(
                    _label_priority[mask]) - 1) / (mask.sum() - 1)
        exclude_mask = label_priority < np.percentile(
            label_priority, (1 - cmd_args.label_fraction) * 100)
        dataset.obs.loc[exclude_mask, cmd_args.supervision] = np.nan

    latent_module_kwargs = dict(lambda_reg=cmd_args.lambda_prior_reg)
    if cmd_args.supervision is not None:
        latent_module_kwargs["lambda_sup"] = cmd_args.lambda_sup
    prob_module_kwargs = dict(lambda_reg=cmd_args.lambda_prob_reg)
    rmbatch_module_kwargs = dict(lambda_reg=cmd_args.lambda_rmbatch_reg)

    os.environ["CUDA_VISIBLE_DEVICES"] = str(utils.pick_gpu_lowest_memory()) \
        if cmd_args.device is None else cmd_args.device
    start_time = time.time()
    model = cb.directi.fit_DIRECTi(
        dataset,
        genes=None if cmd_args.genes is None else dataset.uns[cmd_args.genes],
        latent_dim=cmd_args.latent_dim,
        cat_dim=cmd_args.cat_dim,
        supervision=cmd_args.supervision,
        batch_effect=cmd_args.batch_effect,
        h_dim=cmd_args.h_dim,
        depth=cmd_args.depth,
        prob_module=cmd_args.prob_module,
        rmbatch_module=cmd_args.rmbatch_module,
        latent_module_kwargs=latent_module_kwargs,
        prob_module_kwargs=prob_module_kwargs,
        rmbatch_module_kwargs=rmbatch_module_kwargs,
        optimizer=cmd_args.optimizer,
        learning_rate=cmd_args.learning_rate,
        batch_size=cmd_args.batch_size,
        val_split=cmd_args.val_split,
        epoch=cmd_args.epoch,
        patience=cmd_args.patience,
        progress_bar=True,
        random_seed=cmd_args.seed,
        path=cmd_args.output_path)
    model.save()

    cb.message.info("Saving results...")
    inferred_latent = model.inference(dataset)
    cb.data.write_hybrid_path(time.time() - start_time,
                              "%s//time" % cmd_args.output)
    if "exclude_mask" in globals():
        cb.data.write_hybrid_path(~exclude_mask,
                                  "%s//supervision" % cmd_args.output)
    cb.data.write_hybrid_path(inferred_latent, "%s//latent" % cmd_args.output)
    try:  # If intrinsic clustering is used
        cb.data.write_hybrid_path(
            model.clustering(dataset)[0], "%s//cluster" % cmd_args.output)
    except Exception:
        pass
Example #13
import os, sys
import tensorflow as tf
sys.path.append("../")
import utils
print("GPU with lowest memory %s" % str(utils.pick_gpu_lowest_memory()))
os.environ["CUDA_VISIBLE_DEVICES"] = str(utils.pick_gpu_lowest_memory())
print(tf.__version__)
config = tf.ConfigProto(log_device_placement=True)
config.gpu_options.allow_growth = True
session = tf.Session(config=config)
print(session)

import h5py
import numpy

import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--input", help="input hdf5 file", type=str)
parser.add_argument("--tag", help="save name for weights file", type=str)
args = parser.parse_args()

f = h5py.File(args.input, "r")

object_features, object_features_validation, object_features_data = \
    f['object'], f['object_validation'], f['object_data']
global_features, global_features_validation, global_features_data = \
    f['global'], f['global_validation'], f['global_data']
label, label_validation, label_data = \
    f['label'], f['label_validation'], f['label_data']
process_id, process_id_validation, process_id_data = \
    f['process_id'], f['process_id_validation'], f['process_id_data']
Example #14
                    help='model to be attacked')
parser.add_argument('--seed', type=int, default=1, help='random seed')
parser.add_argument('--query', type=int, help='Query limit allowed')
parser.add_argument('--save', type=str, default='', help='exp_id')
parser.add_argument('--exp_tag', type=str, default='')
parser.add_argument('--gpu',
                    type=str,
                    default='auto',
                    help='GPU id; "auto" picks the GPU with the lowest memory usage')

args = parser.parse_args()

#### env
np.random.seed(args.seed)
torch.manual_seed(args.seed)
gpu = utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
torch.cuda.set_device(gpu)
print('gpu:', gpu)

#### macros
attack_list = {
    "PGD": PGD,
    "Sign_OPT": OPT_attack_sign_SGD,
    "Sign_OPT_lf": OPT_attack_sign_SGD_lf,
    "CW": CW,
    "OPT_attack": OPT_attack,
    "HSJA": HSJA,
    "OPT_attack_lf": OPT_attack_lf,
    "FGSM": FGSM,
    "NES": NES,
    "Bandit": Bandit,
Example #15
def main(cmd_args):

    cb.message.info("Reading data...")
    genes = np.loadtxt(os.path.join(cmd_args.model, "genes.txt"), dtype=str)  # np.str was removed in NumPy 1.24
    ref = cb.data.ExprDataSet.read_dataset(cmd_args.ref)
    ref = utils.clean_dataset(ref, cmd_args.clean) if cmd_args.clean else ref
    ref = ref.to_anndata()
    if cmd_args.subsample_ref is not None:
        subsample_idx = np.random.RandomState(cmd_args.seed).choice(
            ref.shape[0], cmd_args.subsample_ref, replace=False)
        ref = ref[subsample_idx, :]
    ref_label = ref.obs[cmd_args.annotation].values
    ref = dca_modpp.io.normalize(ref,
                                 genes,
                                 filter_min_counts=False,
                                 size_factors=10000,
                                 normalize_input=False,
                                 logtrans_input=True)
    cb.message.info("Loading model...")
    os.environ["CUDA_VISIBLE_DEVICES"] = str(utils.pick_gpu_lowest_memory()) \
        if cmd_args.device is None else cmd_args.device
    model = keras.models.load_model(os.path.join(cmd_args.model, "model.h5"))

    cb.message.info("Projecting to latent space...")
    ref_latent = model.predict({
        "count": ref.X,
        "size_factors": ref.obs.size_factors
    })
    nn = sklearn.neighbors.NearestNeighbors().fit(ref_latent)

    cb.message.info("Building empirical distribution...")
    np.random.seed(cmd_args.seed)
    idx1 = np.random.choice(ref_latent.shape[0], size=N_EMPIRICAL)
    idx2 = np.random.choice(ref_latent.shape[0], size=N_EMPIRICAL)
    empirical = np.sort(
        np.sqrt(np.sum(np.square(ref_latent[idx1] - ref_latent[idx2]),
                       axis=1)))

    cb.message.info("Querying...")
    query = cb.data.ExprDataSet.read_dataset(cmd_args.query)
    query = query[:, np.union1d(query.var_names, genes)]
    query = utils.clean_dataset(query, cmd_args.clean) \
        if cmd_args.clean else query
    query = query.to_anndata()
    start_time = time.time()
    query = dca_modpp.io.normalize(query,
                                   genes,
                                   filter_min_counts=False,
                                   size_factors=10000,
                                   normalize_input=False,
                                   logtrans_input=True)
    query_latent = model.predict({
        "count": query.X,
        "size_factors": query.obs.size_factors
    })
    nnd, nni = nn.kneighbors(query_latent, n_neighbors=cmd_args.n_neighbors)
    pval = np.empty_like(nnd, np.float32)
    time_per_cell = None
    prediction_dict = collections.defaultdict(list)

    for cutoff in cmd_args.cutoff:
        for i in range(nnd.shape[0]):
            for j in range(nnd.shape[1]):
                pval[i, j] = np.searchsorted(empirical,
                                             nnd[i, j]) / empirical.size
            uni, count = np.unique(ref_label[nni[i][pval[i] < cutoff]],
                                   return_counts=True)
            total_count = count.sum()
            if total_count < cmd_args.min_hits:
                prediction_dict[cutoff].append("rejected")
                continue
            argmax = np.argmax(count)
            if count[argmax] / total_count <= MAJORITY_THRESHOLD:
                prediction_dict[cutoff].append("ambiguous")
                continue
            prediction_dict[cutoff].append(uni[argmax])
        prediction_dict[cutoff] = np.array(prediction_dict[cutoff])
        if time_per_cell is None:
            time_per_cell = (time.time() - start_time) * 1000 / len(
                prediction_dict[cutoff])
    print("Time per cell: %.3fms" % time_per_cell)

    cb.message.info("Saving results...")
    if os.path.exists(cmd_args.output):
        os.remove(cmd_args.output)
    for cutoff in prediction_dict:
        cb.data.write_hybrid_path(
            prediction_dict[cutoff],
            "%s//prediction/%s" % (cmd_args.output, str(cutoff)))
    cb.data.write_hybrid_path(nni, "//".join((cmd_args.output, "nni")))
    cb.data.write_hybrid_path(nnd, "//".join((cmd_args.output, "nnd")))
    cb.data.write_hybrid_path(pval, "//".join((cmd_args.output, "pval")))
    cb.data.write_hybrid_path(time_per_cell, "//".join(
        (cmd_args.output, "time")))