is_parallel = False
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Select the problem implementation from the "<kind>:" prefix of args.problem.
problem = None
if args.problem.startswith("genotyping:"):
    problem = SbiGenotypingProblem(args.mini_batch_size, code=args.problem,
                                   num_workers=args.num_workers)
elif args.problem.startswith("struct_genotyping:"):
    # struct_genotyping does not support multiprocessing data loading:
    # NOTE(review): if num_workers is forwarded to a torch DataLoader, 1 still
    # starts one worker process and 0 is the in-process setting -- confirm
    # against StructuredSbiGenotypingProblem before changing.
    problem = StructuredSbiGenotypingProblem(args.mini_batch_size, code=args.problem,
                                             num_workers=1)
elif args.problem.startswith("somatic:"):
    problem = SbiSomaticProblem(args.mini_batch_size, code=args.problem,
                                num_workers=args.num_workers)
else:
    print("Unsupported problem: " + args.problem)
    # Raise SystemExit directly instead of calling exit(): that helper is
    # installed by the site module for interactive sessions and is not
    # guaranteed to exist (python -S, frozen/embedded interpreters).
    raise SystemExit(1)


def get_metric_value(all_perfs, query_metric_name):
    """Return the first non-None value of a named metric.

    Each element of ``all_perfs`` is asked for ``query_metric_name`` through
    its ``get_metric`` method, in iteration order; the search stops at the
    first element that answers with something other than ``None``.

    Args:
        all_perfs: Iterable of objects exposing ``get_metric(name)``.
        query_metric_name: Name passed to every ``get_metric`` call.

    Returns:
        The first non-None metric value, or ``None`` when ``all_perfs`` is
        empty or no element provides the metric.
    """
    for perf in all_perfs:
        metric = perf.get_metric(query_metric_name)
        if metric is not None:
            return metric
    # Made explicit: callers get None when the metric is not found.
    return None


def train_once(train_args, train_problem, train_device, train_use_cuda):
    train_problem.describe()
    training_loop_method = None
    testing_loop_method = None
except FileNotFoundError: print("Unable to load model {} from checkpoint".format( args.checkpoint_key)) exit(1) if checkpoint is not None: model = checkpoint['model'] problem = None if args.problem.startswith("genotyping:"): problem = SbiGenotypingProblem(args.mini_batch_size, code=args.problem, drop_last_batch=False, num_workers=args.num_workers) elif args.problem.startswith("somatic:"): problem = SbiSomaticProblem(args.mini_batch_size, code=args.problem, drop_last_batch=False, num_workers=args.num_workers) elif args.problem.startswith("struct_genotyping:"): problem = StructuredSbiGenotypingProblem(args.mini_batch_size, code=args.problem, drop_last_batch=False, num_workers=args.num_workers) else: print("Unsupported problem: " + args.problem) exit(1) if problem is None or model is None: print("no problem or model, aborting") exit(1) domain_descriptor = checkpoint[