Пример #1
0
def run_remove_background(args):
    """Run the remove-background command line tool from start to finish.

    Configures logging (a ``.log`` file next to the output file, plus a
    console handler), loads the input dataset, runs inference, and saves the
    result to ``args.output_file``.

    Args:
        args: Inputs from the command line, already parsed using argparse.

    Note: Returns nothing; output is written to the file(s) named on the
        command line.  The process exits with status 1 if loading the
        input file raises OSError.

    """

    message_format = "cellbender:remove-background: %(message)s"

    # The log file sits beside the output file and shares its base name.
    output_dir = os.path.dirname(args.output_file)
    output_stem = os.path.splitext(os.path.basename(args.output_file))[0]
    logging.basicConfig(level=logging.INFO,
                        format=message_format,
                        filename=os.path.join(output_dir,
                                              output_stem + ".log"),
                        filemode="w")

    # Mirror every message on the console, using the same format.
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter(message_format))
    logging.getLogger('').addHandler(console_handler)

    # Record the command (reconstructed from sys.argv) and the start time.
    typed_command = ' '.join(['cellbender', 'remove-background',
                              *sys.argv[2:]])
    logging.info("Command:\n" + typed_command)
    logging.info(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    logging.info("Running remove-background")

    # Read the data and select the barcodes and genes to analyze.
    try:
        dataset_options = dict(
            input_file=args.input_file,
            expected_cell_count=args.expected_cell_count,
            total_droplet_barcodes=args.total_droplets,
            fraction_empties=args.fraction_empties,
            model_name=args.model,
            gene_blacklist=args.blacklisted_genes,
            low_count_threshold=args.low_count_threshold)
        dataset_obj = SingleCellRNACountsDataset(**dataset_options)
    except OSError:
        logging.error(f"OSError: Unable to open file {args.input_file}.")
        sys.exit(1)

    # Build the latent variable model and run the inference procedure.
    inferred_model = run_inference(dataset_obj, args)

    # Save the results.  A CTRL-C while saving removes any partial output.
    try:
        dataset_obj.save_to_output_file(args.output_file,
                                        inferred_model,
                                        save_plots=True)

        logging.info("Completed remove-background.")
        logging.info(datetime.now().strftime('%Y-%m-%d %H:%M:%S\n'))

    except KeyboardInterrupt:

        # The filtered (cells only) file is named after the full output.
        filtered_output = os.path.join(
            os.path.dirname(args.output_file),
            os.path.splitext(os.path.basename(args.output_file))[0]
            + "_filtered.h5")

        for partial_output in (args.output_file, filtered_output):
            if os.path.exists(partial_output):
                os.remove(partial_output)

        logging.info("Keyboard interrupt.  Terminated without saving.\n")
Пример #2
0
    def test_inference(self):
        """Smoke-test the inference procedure on a small synthetic dataset.

        Inference is run with ``use_cuda`` disabled.

        Returns:
            1 when the whole procedure runs through, 0 when a
            ``TestConsole.failureException`` is raised along the way.

        NOTE(review): because the failure exception is turned into a return
        value, a test runner that ignores return values may report this test
        as passing even when it fails -- confirm this is intended.

        """

        try:

            num_cells = 100

            # Simulate counts that include ambient RNA; only the count
            # matrix (first of the four returned values) is used here.
            simulated = simulate_ambient_dataset(n_cells=num_cells,
                                                 n_empty=3 * num_cells,
                                                 clusters=1, n_genes=1000,
                                                 d_cell=2000, d_empty=100,
                                                 ambient_different=False)
            count_matrix, _, _, _ = simulated

            # Stand-in for the parsed command line inputs.
            fake_cli_inputs = {
                'use_cuda': False,
                'z_hidden_dims': [100],
                'd_hidden_dims': [10, 2],
                'p_hidden_dims': [100, 10],
                'z_dim': 10,
                'learning_rate': 0.001,
                'epochs': 3,
                'model': ["full"],
                'use_decaying_average_baseline': False,
                'fraction_empties': 0.2,
                'training_fraction': 0.8,
                'expected_cell_count': num_cells,
            }
            args = ObjectWithAttributes()
            for option, value in fake_cli_inputs.items():
                setattr(args, option, value)

            # Put the simulated matrix into a Dataset, with made-up gene
            # names and barcodes sized to the matrix.
            dataset_obj = Dataset(transformation=transform.IdentityTransform())
            gene_names = np.array(
                [f'g{idx}' for idx in range(count_matrix.shape[1])])
            barcodes = np.array(
                [f'bc{idx}' for idx in range(count_matrix.shape[0])])
            dataset_obj.data = {'matrix': count_matrix,
                                'gene_names': gene_names,
                                'barcodes': barcodes}
            dataset_obj.priors['n_cells'] = num_cells
            dataset_obj._trim_dataset_for_analysis()
            dataset_obj._estimate_priors()

            # Train on the simulated dataset.
            inferred_model = run_inference(dataset_obj, args)

            model_module = cellbender.remove_background.model

            # Encodings from the trained model.
            z, d, p = model_module.get_encodings(inferred_model, dataset_obj)

            # Background-subtracted count matrix.
            inferred_count_matrix = \
                model_module.get_count_matrix_from_encodings(
                    z, d, p, inferred_model, dataset_obj)

            # Inferred background RNA expression.
            ambient_expression = model_module.get_ambient_expression()

            # TODO: save the model, clear it from memory, and load it again.

            return 1

        except TestConsole.failureException:

            return 0
Пример #3
0
def _sibling_path(output_file, suffix):
    """Return the path next to `output_file` with its extension swapped for
    `suffix` (e.g. ``out/run.h5`` + ``".log"`` -> ``out/run.log``)."""
    file_dir, file_base = os.path.split(output_file)
    return os.path.join(file_dir, os.path.splitext(file_base)[0] + suffix)


def _attach_log_handlers(log_file):
    """Send root-logger messages to `log_file` (overwritten) and to the
    console, and return the handlers that were attached."""
    formatter = logging.Formatter("cellbender:remove_background: %(message)s")
    # StreamHandler() with no argument writes to sys.stderr.
    handlers = [logging.FileHandler(log_file, mode="w"),
                logging.StreamHandler()]
    root_logger = logging.getLogger('')
    root_logger.setLevel(logging.INFO)
    for handler in handlers:
        handler.setFormatter(formatter)
        root_logger.addHandler(handler)
    return handlers


def _detach_log_handlers(handlers):
    """Remove `handlers` from the root logger and close them."""
    root_logger = logging.getLogger('')
    for handler in handlers:
        root_logger.removeHandler(handler)
        handler.close()


def run_remove_background(args):
    """The full script for the command line tool to remove background RNA.

    For every input file: log to a ``.log`` file next to the matching output
    file (and to the console), load the dataset, run inference, and write
    the output.

    Args:
        args: Inputs from the command line, already parsed using argparse.
            ``args.model``, ``args.expected_cell_count`` and
            ``args.additional_barcodes`` are broadcast in place to one entry
            per input file when a single value was given.

    Raises:
        NotImplementedError: If ``args.transform[0]`` is not one of
            "identity", "log" or "sqrt".

    Note: Returns nothing, but writes output to a file(s) specified from command
    line.  An input file that cannot be opened (OSError) is logged and
    skipped.  A keyboard interrupt while saving deletes the partial output
    for that file, after which processing moves on to the next file.

    """

    # Set up the count data transformation.
    transform_name = args.transform[0]
    if transform_name == "identity":
        trans = transform.IdentityTransform()
    elif transform_name == "log":
        trans = transform.LogTransform(scale_factor=1.)
    elif transform_name == "sqrt":
        trans = transform.SqrtTransform(scale_factor=1.)
    else:
        raise NotImplementedError(f"transform was set to {transform_name}, "
                                  f"which is not implemented.")

    # If one model / cell-count specified for several files, broadcast it.
    n_files = len(args.input_files)
    for option in ("model", "expected_cell_count", "additional_barcodes"):
        values = getattr(args, option)
        if len(values) == 1:
            setattr(args, option, [values[0]] * n_files)

    # Load each dataset, run inference, and write the output to a file.
    for i, input_file in enumerate(args.input_files):
        output_file = args.output_files[i]

        # Each file gets its own log.  logging.basicConfig() cannot be used
        # here: it is a no-op once the root logger has handlers, so every
        # file after the first would keep logging into the first log file.
        # Handlers are detached again in the `finally` below so console
        # messages are not duplicated on later iterations.
        log_handlers = _attach_log_handlers(_sibling_path(output_file, ".log"))
        try:
            logging.info("Running remove_background")

            # Load data from file and choose barcodes and genes to analyze.
            try:
                dataset_obj = Dataset(
                    transformation=trans,
                    input_file=input_file,
                    expected_cell_count=args.expected_cell_count[i],
                    num_transition_barcodes=args.additional_barcodes[i],
                    fraction_empties=args.fraction_empties,
                    model_name=args.model[i],
                    gene_blacklist=args.blacklisted_genes,
                    low_count_threshold=args.low_count_threshold)
            except OSError:
                logging.error(f"OSError: Unable to open file {input_file}.")
                continue

            # Instantiate latent variable model and run full inference
            # procedure.
            inferred_model = run_inference(dataset_obj, args)

            # Write outputs to file.
            try:
                dataset_obj.save_to_output_file(output_file,
                                                inferred_model,
                                                save_plots=True)

                logging.info("Completed remove_background.\n")

            # The exception allows user to end saving prematurely with
            # CTRL-C.
            except KeyboardInterrupt:

                # If partial output has been saved, delete it: both the full
                # file and the filtered (cells only) file.
                filtered_file = _sibling_path(output_file, "_filtered.h5")
                for partial_file in (output_file, filtered_file):
                    if os.path.exists(partial_file):
                        os.remove(partial_file)

                logging.info(
                    "Keyboard interrupt.  Terminated without saving.\n")

            # Drop the references before the next dataset is loaded.
            del dataset_obj
            del inferred_model
        finally:
            _detach_log_handlers(log_handlers)
Пример #4
0
    def test_inference(self):
        """Smoke-test the inference procedure on a small synthetic dataset.

        Inference is run with ``use_cuda`` disabled.

        Returns:
            1 when the whole procedure runs through, 0 when a
            ``TestConsole.failureException`` is raised along the way.

        NOTE(review): because the failure exception is turned into a return
        value, a test runner that ignores return values may report this test
        as passing even when it fails -- confirm this is intended.

        """

        try:

            num_cells = 100

            # Simulate counts that include ambient RNA; only the count
            # matrix (first of the four returned values) is used here.
            simulated = simulate_dataset_with_ambient_rna(
                n_cells=num_cells,
                n_empty=3 * num_cells,
                clusters=1, n_genes=1000,
                d_cell=2000, d_empty=100,
                ambient_different=False)
            count_matrix, _, _, _ = simulated

            # Stand-in for the parsed command line inputs.
            fake_cli_inputs = {
                'use_cuda': False,
                'z_hidden_dims': [100],
                'd_hidden_dims': [10, 2],
                'p_hidden_dims': [100, 10],
                'z_dim': 10,
                'learning_rate': 0.001,
                'epochs': 10,
                'model': "full",
                'fraction_empties': 0.5,
                'use_jit': True,
                'training_fraction': 0.9,
                'expected_cell_count': num_cells,
            }
            args = ObjectWithAttributes()
            for option, value in fake_cli_inputs.items():
                setattr(args, option, value)

            # Put the simulated matrix into a dataset object, with made-up
            # gene names and barcodes sized to the matrix.
            dataset_obj = SingleCellRNACountsDataset()
            gene_names = np.array(
                [f'g{idx}' for idx in range(count_matrix.shape[1])])
            barcodes = np.array(
                [f'bc{idx}' for idx in range(count_matrix.shape[0])])
            dataset_obj.data = {'matrix': count_matrix,
                                'gene_names': gene_names,
                                'barcodes': barcodes}
            dataset_obj.priors['n_cells'] = num_cells
            dataset_obj._trim_dataset_for_analysis()
            dataset_obj._estimate_priors()

            # Train on the simulated dataset.
            inferred_model = run_inference(dataset_obj, args)

            model_module = cellbender.remove_background.model

            # Encodings from the trained model.
            z, d, p = model_module.get_encodings(inferred_model, dataset_obj)

            # Background-subtracted count matrix.
            inferred_count_matrix = \
                model_module.generate_maximum_a_posteriori_count_matrix(
                    z, d, p, inferred_model, dataset_obj)

            # Inferred background RNA expression.
            ambient_expression = \
                model_module.get_ambient_expression_from_pyro_param_store()

            return 1

        except TestConsole.failureException:

            return 0