def run_predict(self, args):
        """Load a model and write predictions for a FASTA or .npy input file.

        The model is one of:

        - the built-in sensitive model (``args.sensitive``);
        - the built-in rapid model (``args.rapid``);
        - a Keras model loaded from ``args.custom``. It is loaded inside a
          TPUStrategy scope when ``self.tpu_resolver`` is set.

        Built-in models are loaded with ``training_mode=False``.

        If ``args.output`` is None, it is set in place to
        ``<input path without extension>_predictions.npy``. Prediction is then
        dispatched to ``compare_rc`` (``args.rc_check``), ``predict_npy``
        (``args.array``) or ``predict_fasta`` (otherwise).

        Args:
            args: Parsed command-line namespace.
        """
        if args.tpu is None:
            # CPU/GPU configuration is only applied when not running on a TPU.
            config_cpus(args.n_cpus)
            config_gpus(args.gpus)
        if args.output is None:
            input_stem = os.path.splitext(args.input)[0]
            args.output = input_stem + "_predictions.npy"

        if args.sensitive or args.rapid:
            # Built-in models: pick the matching loader, then load for inference.
            load_builtin = (self.bloader.load_sensitive_model if args.sensitive
                            else self.bloader.load_rapid_model)
            model = load_builtin(training_mode=False, tpu_resolver=self.tpu_resolver)
        elif self.tpu_resolver is None:
            model = tf.keras.models.load_model(args.custom)
        else:
            # Custom model on TPU: load it inside the TPU strategy's scope.
            strategy = tf.distribute.experimental.TPUStrategy(self.tpu_resolver)
            with strategy.scope():
                model = tf.keras.models.load_model(args.custom)

        # All three prediction entry points take the same batching options.
        predict_kwargs = {"replicates": args.replicates, "batch_size": args.batch_size}
        if args.rc_check:
            compare_rc(model, args.input, args.output, args.plot_kind, args.alpha,
                       **predict_kwargs)
        elif args.array:
            predict_npy(model, args.input, args.output, **predict_kwargs)
        else:
            predict_fasta(model, args.input, args.output, args.n_cpus, **predict_kwargs)
    def run_train(self, args):
        """Parse the training configuration and train the NN on Illumina reads.

        The configuration comes from one of three sources:

        - the built-in sensitive preset (``args.sensitive``);
        - the built-in rapid preset (``args.rapid``);
        - the INI file at ``args.custom``.

        Data/label paths and a run name given on the command line override
        the configured values. A new run name also sets ``log_dir`` to
        ``<log_superpath>/<run_name>-logs``. The network is then built,
        compiled and trained.

        Args:
            args: Parsed command-line namespace.

        Raises:
            OSError: (e.g. FileNotFoundError) if the custom config file cannot
                be opened. ``ConfigParser.read`` would silently skip such a
                file and leave the parser empty.
        """
        if args.tpu is None:
            # CPU/GPU configuration is only applied when not running on a TPU.
            config_cpus(args.n_cpus)
            config_gpus(args.gpus)
        if args.sensitive:
            paprconfig = self.bloader.get_sensitive_training_config()
        elif args.rapid:
            paprconfig = self.bloader.get_rapid_training_config()
        else:
            config = configparser.ConfigParser()
            # Open the file explicitly. ConfigParser.read() ignores files it
            # cannot open, which would hand RCConfig an empty parser instead
            # of reporting the bad path.
            with open(args.custom) as config_file:
                config.read_file(config_file)
            paprconfig = RCConfig(config)

        # Command-line values take precedence over the loaded configuration.
        if args.train_data:
            paprconfig.x_train_path = args.train_data
        if args.train_labels:
            paprconfig.y_train_path = args.train_labels
        if args.val_data:
            paprconfig.x_val_path = args.val_data
        if args.val_labels:
            paprconfig.y_val_path = args.val_labels
        if args.run_name:
            paprconfig.runname = args.run_name
            # Keep the log directory in sync with the overridden run name.
            paprconfig.log_dir = os.path.join(paprconfig.log_superpath,
                                              "{runname}-logs".format(runname=paprconfig.runname))

        paprconfig.set_tpu_resolver(self.tpu_resolver)
        paprnet = RCNet(paprconfig)
        paprnet.load_data()
        paprnet.compile_model()
        paprnet.train()
def run_tester(args):
    """Set up the runtime and forward the test options to ``run_tests``.

    Outside TPU mode, the CPU count is the value returned by ``config_cpus``
    and GPUs are configured. In TPU mode ``args.n_cpus_rec`` is used as given.
    When ``args.custom`` is set, ``args.command`` is reset to None in place
    before being forwarded.
    """
    resolver = global_setup(args)
    on_tpu = args.tpu is not None
    if on_tpu:
        n_cpus = args.n_cpus_rec
    else:
        n_cpus = config_cpus(args.n_cpus_rec)
        config_gpus(args.gpus)
    if args.custom:
        args.command = None
    # NOTE(review): this calls a module-level ``run_tests`` with six positional
    # arguments. That does not match the ``run_tests(self, args)`` method
    # signature elsewhere in this file. Confirm the intended target.
    run_tests(args.command, args.model, n_cpus, args.keep, args.scale, resolver)
 def run_tests(self, args):
     """Create a Tester from the command-line options and run it.

     Outside TPU mode, CPUs and GPUs are configured first. ``args.scale`` is
     then multiplied by ``max(1, <value returned by config_gpus>)``, presumably
     the number of GPUs in use (confirm). In TPU mode the CPU count is
     ``multiprocessing.cpu_count()`` and the scale is used unchanged.
     """
     if args.tpu is not None:
         n_cpus = multiprocessing.cpu_count()
         scale = args.scale
     else:
         n_cpus = config_cpus(args.n_cpus)
         gpu_factor = max(1, config_gpus(args.gpus))
         scale = args.scale * gpu_factor
     Tester(n_cpus, self.builtin_configs, self.builtin_weights,
            args.explain, args.gwpa, args.all, args.quick, args.keep, scale,
            tpu_resolver=self.tpu_resolver,
            input_modes=args.input_modes,
            additivity_check=not args.no_check,
            large=args.large).run_tests()
def run_receiver(args):
    """Set up the runtime, build a Receiver and run it on the requested cycles.

    Outside TPU mode, the CPU count is the value returned by ``config_cpus``
    and GPUs are configured. In TPU mode ``args.n_cpus_rec`` is used as given.
    When ``args.custom`` is set, ``args.command`` is reset to None in place
    before the Receiver is built.

    ``args.cycle_list`` is split on "," and each item is converted with
    ``int``. ``args.barcodes`` is split on "," with no further cleanup.
    """
    resolver = global_setup(args)
    if args.tpu is not None:
        n_cpus = args.n_cpus_rec
    else:
        n_cpus = config_cpus(args.n_cpus_rec)
        config_gpus(args.gpus)
    if args.custom:
        args.command = None

    receiver = Receiver(
        args.command,
        model=args.model,
        read_length=args.read_length,
        input_dir=args.rec_in_dir,
        output_dir=args.rec_out_dir,
        n_cpus=n_cpus,
        threshold=args.threshold,
        tpu_resolver=resolver,
    )
    receiver.run(
        cycles=list(map(int, args.cycle_list.split(','))),
        barcodes=args.barcodes.split(','),
        mode=args.format,
        discard_neg=args.discard_neg,
    )