예제 #1
0
def test_all_benchmarks():
    """Smoke-test the full benchmark suite with a single training epoch.

    Passes the relative path ``tests/data/`` as ``save_path``; it is resolved
    against the current working directory of the test run.
    """
    output_dir = 'tests/data/'
    all_benchmarks(save_path=output_dir, n_epochs=1)
예제 #2
0
def test_all_benchmarks(save_path):
    """Run every benchmark for one epoch, saving under the caller-supplied path.

    ``save_path`` is presumably a pytest fixture — confirm in conftest.
    """
    run_kwargs = dict(n_epochs=1, save_path=save_path)
    all_benchmarks(**run_kwargs)
예제 #3
0
def test_all_benchmarks(save_path):
    """Run every benchmark for one epoch with plotting disabled.

    ``save_path`` is presumably a pytest fixture — confirm in conftest.
    ``show_plot=False`` is passed so the test does not request plot display.
    """
    epochs = 1
    all_benchmarks(
        n_epochs=epochs,
        save_path=save_path,
        show_plot=False,
    )
예제 #4
0
def test_all_benchmarks():
    """Run all benchmarks for a single epoch.

    Every other ``all_benchmarks`` argument is left at its default.
    """
    single_epoch = 1
    all_benchmarks(n_epochs=single_epoch)
예제 #5
0
def test_all_benchmarks():
    """Run all benchmarks for one epoch with ``unit_test=True`` passed through."""
    all_benchmarks(
        unit_test=True,
        n_epochs=1,
    )
예제 #6
0
        help="whether to use cuda (will apply only if cuda is available")
    # Fixed: the help text for --all and --benchmark was copy-pasted from the
    # CUDA flag ("whether to use cuda ...", with an unbalanced parenthesis),
    # so `--help` described them wrongly.
    parser.add_argument(
        "--all",
        action='store_true',
        help="whether to run all benchmarks (takes precedence over the other run modes)")
    # NOTE(review): args.benchmark is not read anywhere in the visible code —
    # confirm it is used elsewhere, or remove the flag.
    parser.add_argument(
        "--benchmark",
        action='store_true',
        help="whether to run benchmarks")
    parser.add_argument("--url",
                        type=str,
                        help="the url for downloading gene_dataset")
    args = parser.parse_args()

    n_epochs = args.epochs
    # --nocuda opts out of CUDA; use_cuda is forwarded to every run mode below.
    use_cuda = not args.nocuda
    # Run modes are checked in priority order: --all, then --harmonization,
    # then --annotation; otherwise a single model is trained on one dataset.
    if args.all:
        all_benchmarks(n_epochs=n_epochs, use_cuda=use_cuda)
    elif args.harmonization:
        harmonization_benchmarks(n_epochs=n_epochs, use_cuda=use_cuda)
    elif args.annotation:
        annotation_benchmarks(n_epochs=n_epochs, use_cuda=use_cuda)
    else:
        dataset = load_datasets(args.dataset, url=args.url)
        # NOTE(review): n_batches is multiplied by args.nobatches. If that flag
        # is a store_true boolean, this passes n_batches when --nobatches IS
        # given and 0 otherwise, which looks inverted — confirm against the
        # flag's definition.
        model = available_models[args.model](
            dataset.nb_genes, dataset.n_batches * args.nobatches,
            dataset.n_labels)
        # 'VAE' is trained with plain variational inference; every other model
        # uses the joint semi-supervised variant.
        if args.model == 'VAE':
            inference_cls = VariationalInference
        else:
            inference_cls = JointSemiSupervisedVariationalInference
        infer = inference_cls(model, dataset, use_cuda=use_cuda)
        infer.train(n_epochs=n_epochs)