コード例 #1
0
    is_parallel = False
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    # Instantiate the problem selected by the "<kind>:" prefix of args.problem.
    problem = None
    if args.problem.startswith("genotyping:"):
        problem = SbiGenotypingProblem(args.mini_batch_size,
                                       code=args.problem,
                                       num_workers=args.num_workers)
    elif args.problem.startswith("struct_genotyping:"):
        # struct_genotyping does not support multiprocessing data loading,
        # so the worker count is pinned to 1 instead of args.num_workers.
        problem = StructuredSbiGenotypingProblem(args.mini_batch_size,
                                                 code=args.problem,
                                                 num_workers=1)
    elif args.problem.startswith("somatic:"):
        problem = SbiSomaticProblem(args.mini_batch_size,
                                    code=args.problem,
                                    num_workers=args.num_workers)
    else:
        print("Unsupported problem: " + args.problem)
        # Raise SystemExit directly rather than calling the bare exit()
        # helper: exit() is injected by the site module for interactive use
        # and is absent under `python -S` or in frozen/embedded interpreters,
        # where this error path would die with NameError instead of status 1.
        raise SystemExit(1)

    def get_metric_value(all_perfs, query_metric_name):
        """Return the first non-None metric named *query_metric_name*.

        Each element of *all_perfs* is asked, in order, for the metric via its
        ``get_metric`` method; the first answer that is not ``None`` is
        returned. ``None`` is returned when no element provides the metric
        (including when *all_perfs* is empty).
        """
        for estimator in all_perfs:
            value = estimator.get_metric(query_metric_name)
            if value is None:
                # This estimator does not provide the metric; try the next.
                continue
            return value
        return None

    def train_once(train_args, train_problem, train_device, train_use_cuda):
        train_problem.describe()
        training_loop_method = None
        testing_loop_method = None
コード例 #2
0
except FileNotFoundError:
    print("Unable to load model {} from checkpoint".format(
        args.checkpoint_key))
    exit(1)

if checkpoint is not None:
    # The model object is stored in the checkpoint under the 'model' key.
    model = checkpoint['model']

# Instantiate the problem selected by the "<kind>:" prefix of args.problem.
# Every variant is built with drop_last_batch=False.
problem = None
if args.problem.startswith("genotyping:"):
    problem = SbiGenotypingProblem(args.mini_batch_size,
                                   code=args.problem,
                                   drop_last_batch=False,
                                   num_workers=args.num_workers)
elif args.problem.startswith("somatic:"):
    problem = SbiSomaticProblem(args.mini_batch_size,
                                code=args.problem,
                                drop_last_batch=False,
                                num_workers=args.num_workers)
elif args.problem.startswith("struct_genotyping:"):
    problem = StructuredSbiGenotypingProblem(args.mini_batch_size,
                                             code=args.problem,
                                             drop_last_batch=False,
                                             num_workers=args.num_workers)
else:
    print("Unsupported problem: " + args.problem)
    # Raise SystemExit directly rather than calling the bare exit() helper:
    # exit() is injected by the site module for interactive use and is absent
    # under `python -S` or in frozen/embedded interpreters.
    raise SystemExit(1)

# NOTE(review): model is only assigned above when a checkpoint was loaded;
# this check relies on model being bound earlier in the file otherwise —
# confirm.
if problem is None or model is None:
    print("no problem or model, aborting")
    raise SystemExit(1)

domain_descriptor = checkpoint[