def runner(
    inhandles: Sequence[TextIO],
    outhandle: TextIO,
    pca_handle: Optional[TextIO],
    counts_filepath: Optional[str],
    labels: Sequence[str],
    file_format: FileType,
    model: Model,
):
    """Parse the inputs, count the HMM hits, and write model predictions.

    Args:
        inhandles: Handles that are each passed to ``parsers.parse``.
        outhandle: Destination for the RCD table of the predictions.
        pca_handle: When not ``None``, the model's training data and the new
            predictions are concatenated and written here as a table.
        counts_filepath: When not ``None``, the counts table is written to
            this path before prediction.
        labels: Labels handed to ``cazy_counts_multi`` together with the
            parsed inputs.
        file_format: Format selector forwarded to ``parsers.parse``.
        model: Provides ``hmm_lengths``, ``predict`` and ``training_data``.
    """
    hmm_lengths = model.hmm_lengths

    parsed_inputs = []
    for handle in inhandles:
        parsed_inputs.append(parsers.parse(handle, file_format, hmm_lengths))

    # The model's HMM names define which columns the counts table must have.
    column_names = list(hmm_lengths.keys())
    counts = cazy_counts_multi(parsed_inputs, labels, column_names)

    if counts_filepath is not None:
        counts.write_tsv(counts_filepath)

    predictions = model.predict(counts)
    RCDResult.write_tsv(outhandle, predictions.rcd)

    if pca_handle is None:
        return

    combined = PCAWithLabels.concat([model.training_data, predictions])
    combined.write_tsv(pca_handle)
def test_predict(version, clss, exp_val):
    """Check the predicted nomenclature3 RCD value for class ``clss``.

    Pretty much a repeat of test_rcds.
    """
    with open(model_filepath(version), "rb") as model_handle:
        model = Model.read(model_handle)

    hmm_names = list(model.hmm_lengths.keys())
    hmmer_path = test_files(version)["hmmer_text"]

    with open(hmmer_path, "r") as hmmer_handle:
        hits = HMMER.from_file(hmmer_handle, model.hmm_lengths, "hmmer3-text")
        # Consume the hits while the file is still open, in case the parser
        # reads from the handle lazily.
        counts = cazy_counts_multi([hits], ["test"], hmm_names)

    prediction = model.predict(counts)

    values_by_key = {}
    for result in prediction.rcd:
        # Only one sample was supplied, so every result must carry its label.
        assert result.label == "test"
        key = (result.label, result.nomenclature, result.nomenclature_class)
        values_by_key[key] = result.value

    # Agreement is required to 5 decimal places.
    observed = values_by_key[("test", "nomenclature3", clss)]
    assert_almost_equal(exp_val, observed, decimal=5)
def test_parse_hmmer_text_output(version, idx, col, exp_val):
    """Check attribute ``col`` of the ``idx``-th record parsed from HMMER text."""
    with open(model_filepath(version), "rb") as model_handle:
        model = Model.read(model_handle)

    hmmer_path = test_files(version)["hmmer_text"]

    with open(hmmer_path, "r") as hmmer_handle:
        # Materialise the records before the handle is closed.
        records = list(
            HMMER.from_file(hmmer_handle, model.hmm_lengths, "hmmer3-text")
        )

    record = records[idx]
    assert getattr(record, col) == exp_val
def test_cazy_counts(version, hmm, exp_val):
    """Check the count reported for the column named ``hmm``."""
    with open(model_filepath(version), "rb") as model_handle:
        model = Model.read(model_handle)

    hmm_names = list(model.hmm_lengths.keys())
    hmmer_path = test_files(version)["hmmer_text"]

    with open(hmmer_path, "r") as hmmer_handle:
        hits = HMMER.from_file(hmmer_handle, model.hmm_lengths, "hmmer3-text")
        # Count while the file is still open, in case the parser reads from
        # the handle lazily.
        counts = cazy_counts(hits, hmm_names)

    # Counts are looked up by the position of the HMM in the column list.
    position = hmm_names.index(hmm)
    assert counts[position] == exp_val
def test_pca(version, pc, exp_val):
    """Check the value of component ``pc`` in the first row of the PCA output."""
    with open(model_filepath(version), "rb") as model_handle:
        model = Model.read(model_handle)

    hmm_names = list(model.hmm_lengths.keys())
    hmmer_path = test_files(version)["hmmer_text"]

    with open(hmmer_path, "r") as hmmer_handle:
        hits = HMMER.from_file(hmmer_handle, model.hmm_lengths, "hmmer3-text")
        # Count while the file is still open, in case the parser reads from
        # the handle lazily.
        counts = cazy_counts_multi([hits], ["test"], hmm_names)

    transformed = model.predict(counts).pca

    # Agreement is required to 5 decimal places.
    pc_position = transformed.columns.index(pc)
    observed = transformed.arr[0, pc_position]
    assert_almost_equal(exp_val, observed, decimal=5)
def main():  # noqa
    """The cli interface to CATAStrophy.

    Parses the command line, resolves the sample labels and the model, runs
    the pipeline through ``runner``, and turns every failure into a message
    on stderr followed by ``sys.exit`` with one of the ``EXIT_*`` codes:

    * argument errors exit with the error's own ``errno``;
    * a label/input count mismatch exits with ``EXIT_CLI``;
    * ``ParseError`` and ``HMMError`` exit with ``EXIT_INPUT_FORMAT``;
    * ``OSError`` and ``MemoryError`` exit with ``EXIT_SYSERR``;
    * ``KeyboardInterrupt`` exits with ``EXIT_KEYBOARD``;
    * anything else exits with ``EXIT_UNKNOWN`` after printing a bug-report
      request, the exception and its traceback.
    """
    try:
        args = cli(prog=sys.argv[0], args=sys.argv[1:])
    except MyArgumentError as e:
        print(e.message, file=sys.stderr)
        sys.exit(e.errno)

    infile_names = [f.name for f in args.infile]

    # Labels default to the input file names; explicit labels must pair up
    # one-to-one with the input files.
    if args.labels is None:
        labels = infile_names
    elif len(args.labels) != len(args.infile):
        msg = ("argument labels and inhandles: \n"
               "When specified, the number of labels must be the same as the "
               "number of input files. Exiting.\n")
        print(msg, file=sys.stderr)
        sys.exit(EXIT_CLI)
    else:
        labels = args.labels

    # A user-supplied model file takes precedence over the packaged model
    # selected by version.
    if args.model_file is not None:
        model = Model.read(args.model_file)
    else:
        with open(data.model_filepath(args.model_version), "rb") as handle:
            model = Model.read(handle)

    try:
        runner(
            args.infile,
            args.outhandle,
            args.pca_handle,
            args.counts_filepath,
            labels,
            args.file_format,
            model,
        )

    except ParseError as e:
        if e.line is not None:
            header = "Failed to parse file <{}> at line {}.\n".format(
                e.filename, e.line)
        else:
            header = "Failed to parse file <{}>.\n".format(e.filename)

        print("{}\n{}".format(header, e.message), file=sys.stderr)
        sys.exit(EXIT_INPUT_FORMAT)

    except HMMError as e:
        msg = ("Encountered an hmm that wasn't present in the training data.\n"
               f"Offending HMMs were: {', '.join(e.hmms)}")
        print(msg, file=sys.stderr)
        sys.exit(EXIT_INPUT_FORMAT)

    except OSError as e:
        msg = (
            "Encountered a system error.\n"
            "We can't control these, and they're usually related to your OS.\n"
            "Try running again.\n")
        print(msg, file=sys.stderr)
        # strerror is None when the OSError was raised without an
        # errno/strerror pair; fall back to the exception text in that case.
        print(e.strerror or e, file=sys.stderr)
        sys.exit(EXIT_SYSERR)

    except MemoryError:
        msg = ("Ran out of memory!\n"
               "Catastrophy shouldn't use much RAM, so check other "
               "processes and try running again.")
        print(msg, file=sys.stderr)
        sys.exit(EXIT_SYSERR)

    except KeyboardInterrupt:
        print("Received keyboard interrupt. Exiting.", file=sys.stderr)
        sys.exit(EXIT_KEYBOARD)

    except Exception as e:
        msg = (
            "I'm so sorry, but we've encountered an unexpected error.\n"
            "This shouldn't happen, so please file a bug report with the "
            "authors.\nWe will be extremely grateful!\n\n"
            "You can email us at {}.\n"
            "Alternatively, you can file the issue directly on the repo "
            "<https://bitbucket.org/ccdm-curtin/catastrophy/issues>\n\n"
            "Please attach a copy of the following message:").format(__email__)
        # The request must be shown before the details it asks the user to
        # attach.
        print(msg, file=sys.stderr)
        print(e, file=sys.stderr)
        traceback.print_exc(file=sys.stderr)
        sys.exit(EXIT_UNKNOWN)

    return
def runner(
    nomenclatures_handle: Optional[TextIO],
    classes: str,
    inhmms: TextIO,
    inhandles: Sequence[TextIO],
    outhandle: BinaryIO,
    labels: Sequence[str],
    file_format: FileType,
):
    """Validate the training inputs, fit a model, and write it out.

    Args:
        nomenclatures_handle: JSON handle holding the nomenclatures; when
            ``None`` the result of ``data.nomenclatures()`` is used instead.
        classes: Path to the classes TSV read by
            ``NomenclatureClass.from_tsv``.
        inhmms: Handle passed to ``HMMLengths.read_hmm``.
        inhandles: Handles that are each passed to ``parsers.parse``.
        outhandle: Destination passed to ``model.write``.
        labels: Labels for the inputs; must match the labels in the classes
            file exactly.
        file_format: Format selector forwarded to ``parsers.parse``.

    Raises:
        ValueError: If the nomenclatures do not have exactly the keys
            ``nomenclature1``..``nomenclature3``, if the input labels and
            class labels differ, or if a nomenclature's classes differ
            between the nomenclatures and the classes file.
    """
    nomenclatures = (
        data.nomenclatures()
        if nomenclatures_handle is None
        else json.load(nomenclatures_handle)
    )

    expected_keys = {"nomenclature1", "nomenclature2", "nomenclature3"}
    if set(nomenclatures.keys()) != expected_keys:
        raise ValueError("The nomenclatures json file has invalid keys.")

    with open(classes, newline='') as handle:
        class_labels = NomenclatureClass.from_tsv(handle)

    # Input labels and class labels must correspond one-to-one: nothing may
    # appear on one side only.
    labels_in_classes = {entry.label for entry in class_labels}
    if labels_in_classes.symmetric_difference(labels):
        raise ValueError("The file labels and the class labels are different.")

    # For each nomenclature, gather the classes used in the classes file.
    classes_by_nomenclature = {
        name: {getattr(entry, name) for entry in class_labels}
        for name in ("nomenclature1", "nomenclature2", "nomenclature3")
    }

    for nom, seen_classes in classes_by_nomenclature.items():
        if seen_classes.symmetric_difference(nomenclatures[nom]):
            raise ValueError(f"The nomenclatures and class files for {nom} "
                             "don't have the same classes.")

    hmm_lengths = HMMLengths.read_hmm(inhmms)

    parsed_inputs = []
    for handle in inhandles:
        parsed_inputs.append(
            parsers.parse(handle, format=file_format, hmm_lens=hmm_lengths)
        )

    # Sorted HMM names are the columns of the counts table.
    column_names = sorted(hmm_lengths.keys())
    counts = cazy_counts_multi(parsed_inputs, labels, column_names)

    fitted = Model.fit(counts, class_labels, nomenclatures, hmm_lengths)
    fitted.write(outhandle)