def run_remove_background(args):
    """Run the remove-background command line tool from start to finish.

    Configures logging (a ``.log`` file next to the output file plus a
    console stream handler), loads the input dataset, runs inference, and
    writes the result to ``args.output_file``.

    Args:
        args: Inputs from the command line, already parsed using argparse.

    Note:
        Returns nothing, but writes output to the file(s) specified on the
        command line.  Calls ``sys.exit(1)`` if the input file cannot be
        opened (``OSError``).  A ``KeyboardInterrupt`` raised while the
        output is being written removes any partially written output files.

    """

    log_format = "cellbender:remove-background: %(message)s"

    # The log file sits beside the output file: same stem, ".log" extension.
    out_dir, out_base = os.path.split(args.output_file)
    out_stem = os.path.splitext(os.path.basename(out_base))[0]
    logging.basicConfig(level=logging.INFO,
                        format=log_format,
                        filename=os.path.join(out_dir, out_stem + ".log"),
                        filemode="w")

    # Mirror every message on the console using the same format.
    # NOTE(review): StreamHandler() with no argument writes to sys.stderr.
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter(log_format))
    logging.getLogger('').addHandler(console_handler)

    # Record the command as typed by the user, then the start time.
    command = ' '.join(['cellbender', 'remove-background'] + sys.argv[2:])
    logging.info("Command:\n" + command)
    logging.info(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    logging.info("Running remove-background")

    # Load data from file and choose barcodes and genes to analyze.
    try:
        dataset_obj = SingleCellRNACountsDataset(
            input_file=args.input_file,
            expected_cell_count=args.expected_cell_count,
            total_droplet_barcodes=args.total_droplets,
            fraction_empties=args.fraction_empties,
            model_name=args.model,
            gene_blacklist=args.blacklisted_genes,
            low_count_threshold=args.low_count_threshold)
    except OSError:
        logging.error(f"OSError: Unable to open file {args.input_file}.")
        sys.exit(1)

    # Instantiate latent variable model and run full inference procedure.
    inferred_model = run_inference(dataset_obj, args)

    # Write outputs to file.
    try:
        dataset_obj.save_to_output_file(args.output_file,
                                        inferred_model,
                                        save_plots=True)
        logging.info("Completed remove-background.")
        logging.info(datetime.now().strftime('%Y-%m-%d %H:%M:%S\n'))

    except KeyboardInterrupt:
        # CTRL-C while saving: discard whatever partial output exists, both
        # the full output file and the filtered (cells only) companion file.
        full_path = args.output_file
        save_dir, save_base = os.path.split(full_path)
        save_stem = os.path.splitext(os.path.basename(save_base))[0]
        filtered_path = os.path.join(save_dir, save_stem + "_filtered.h5")

        for partial_path in (full_path, filtered_path):
            if os.path.exists(partial_path):
                os.remove(partial_path)

        logging.info("Keyboard interrupt.  Terminated without saving.\n")
def test_inference(self):
    """Smoke-test the inference pipeline on a small synthetic dataset.

    Simulates a dataset with ambient RNA, wraps it in a ``Dataset`` object,
    runs inference on CPU (``use_cuda`` is False), and then exercises the
    model helpers that extract encodings, the background-subtracted count
    matrix, and the ambient expression.

    Returns:
        1 if every step ran, 0 if a ``TestConsole.failureException`` was
        raised along the way.

    """

    try:

        n_cells = 100

        # Generate a simulated dataset with ambient RNA.  Only the first of
        # the four returned values (the barcode-by-gene counts) is used.
        count_matrix, _, _, _ = simulate_ambient_dataset(
            n_cells=n_cells,
            n_empty=3 * n_cells,
            clusters=1,
            n_genes=1000,
            d_cell=2000,
            d_empty=100,
            ambient_different=False)

        # Fake some parsed command line inputs.
        fake_cli_settings = {
            'use_cuda': False,
            'z_hidden_dims': [100],
            'd_hidden_dims': [10, 2],
            'p_hidden_dims': [100, 10],
            'z_dim': 10,
            'learning_rate': 0.001,
            'epochs': 3,
            'model': ["full"],
            'use_decaying_average_baseline': False,
            'fraction_empties': 0.2,
            'training_fraction': 0.8,
            'expected_cell_count': n_cells,
        }
        args = ObjectWithAttributes()
        for option, value in fake_cli_settings.items():
            setattr(args, option, value)

        # Wrap simulated count matrix in a Dataset object, labelling columns
        # as genes ("g0", "g1", ...) and rows as barcodes ("bc0", "bc1", ...).
        dataset_obj = Dataset(transformation=transform.IdentityTransform())
        gene_names = np.array(
            [f'g{idx}' for idx in range(count_matrix.shape[1])])
        barcodes = np.array(
            [f'bc{idx}' for idx in range(count_matrix.shape[0])])
        dataset_obj.data = {'matrix': count_matrix,
                            'gene_names': gene_names,
                            'barcodes': barcodes}
        dataset_obj.priors['n_cells'] = n_cells
        dataset_obj._trim_dataset_for_analysis()
        dataset_obj._estimate_priors()

        # Run inference on this simulated dataset.
        inferred_model = run_inference(dataset_obj, args)

        model_module = cellbender.remove_background.model

        # Get encodings from the trained model.
        z, d, p = model_module.get_encodings(inferred_model, dataset_obj)

        # Make the background-subtracted dataset.
        inferred_count_matrix = model_module.get_count_matrix_from_encodings(
            z, d, p, inferred_model, dataset_obj)

        # Get the inferred background RNA expression from the model.
        ambient_expression = model_module.get_ambient_expression()

        # TODO:
        # Save the model.
        # Clear the model from memory.
        # Load the model.

    except TestConsole.failureException:

        return 0

    else:

        return 1
def run_remove_background(args):
    """The full script for the command line tool to remove background RNA.

    For every input file: log to a dedicated ``.log`` file next to the
    corresponding output file (and to the console), load the dataset, run
    inference, and write the output.

    Args:
        args: Inputs from the command line, already parsed using argparse.
            ``args.model``, ``args.expected_cell_count`` and
            ``args.additional_barcodes`` are broadcast in place to one entry
            per input file when a single value was given.

    Raises:
        NotImplementedError: If ``args.transform[0]`` is not one of
            "identity", "log" or "sqrt".

    Note:
        Returns nothing, but writes output to a file(s) specified from
        command line.  An input file that cannot be opened (``OSError``) is
        logged and skipped.  A ``KeyboardInterrupt`` raised while an output
        is being written removes that file's partial output and processing
        moves on to the next input file.  The logging handlers installed
        here are removed again before the function returns.

    """

    # Set up the count data transformation.
    transform_name = args.transform[0]
    if transform_name == "identity":
        trans = transform.IdentityTransform()
    elif transform_name == "log":
        trans = transform.LogTransform(scale_factor=1.)
    elif transform_name == "sqrt":
        trans = transform.SqrtTransform(scale_factor=1.)
    else:
        raise NotImplementedError(f"transform was set to {transform_name}, "
                                  f"which is not implemented.")

    # If one model / cell-count specified for several files, broadcast it.
    n_files = len(args.input_files)
    if len(args.model) == 1:
        args.model = [args.model[0]] * n_files
    if len(args.expected_cell_count) == 1:
        args.expected_cell_count = [args.expected_cell_count[0]] * n_files
    if len(args.additional_barcodes) == 1:
        args.additional_barcodes = [args.additional_barcodes[0]] * n_files

    # Logging is configured with explicit handlers rather than
    # logging.basicConfig: basicConfig is a no-op once the root logger has
    # handlers, so calling it per file would send every file's messages to
    # the first file's log, and adding a console handler per file would
    # duplicate each console line.
    formatter = logging.Formatter("cellbender:remove_background: %(message)s")
    root_logger = logging.getLogger('')
    previous_level = root_logger.level
    root_logger.setLevel(logging.INFO)

    # One console handler for the whole run.
    # NOTE(review): StreamHandler() with no argument writes to sys.stderr.
    console = logging.StreamHandler()
    console.setFormatter(formatter)
    root_logger.addHandler(console)

    try:

        # Load each dataset, run inference, and write the output to a file.
        for i, input_file in enumerate(args.input_files):

            output_file = args.output_files[i]
            file_dir, file_base = os.path.split(output_file)
            file_name = os.path.splitext(file_base)[0]

            # Each output gets its own log file, truncated on open.
            file_handler = logging.FileHandler(
                os.path.join(file_dir, file_name + ".log"), mode="w")
            file_handler.setFormatter(formatter)
            root_logger.addHandler(file_handler)

            try:

                logging.info("Running remove_background")

                # Load data from file and choose barcodes and genes to
                # analyze.
                try:
                    dataset_obj = Dataset(
                        transformation=trans,
                        input_file=input_file,
                        expected_cell_count=args.expected_cell_count[i],
                        num_transition_barcodes=args.additional_barcodes[i],
                        fraction_empties=args.fraction_empties,
                        model_name=args.model[i],
                        gene_blacklist=args.blacklisted_genes,
                        low_count_threshold=args.low_count_threshold)
                except OSError:
                    logging.error(
                        f"OSError: Unable to open file {input_file}.")
                    continue

                # Instantiate latent variable model and run full inference
                # procedure.
                inferred_model = run_inference(dataset_obj, args)

                # Write outputs to file.
                try:
                    dataset_obj.save_to_output_file(output_file,
                                                    inferred_model,
                                                    save_plots=True)
                    logging.info("Completed remove_background.\n")

                except KeyboardInterrupt:
                    # CTRL-C while saving: delete any partial output, both
                    # the full file and the filtered (cells only) file.
                    filtered_file = os.path.join(file_dir,
                                                 file_name + "_filtered.h5")
                    for partial_file in (output_file, filtered_file):
                        if os.path.exists(partial_file):
                            os.remove(partial_file)

                    logging.info("Keyboard interrupt.  "
                                 "Terminated without saving.\n")

                # TODO: save the trained model to a file with the same
                # name as the output file, but with a .model extension.

                # Drop references before loading the next dataset.
                del dataset_obj
                del inferred_model

            finally:
                # Stop writing to this file's log before the next file.
                root_logger.removeHandler(file_handler)
                file_handler.close()

    finally:
        # Leave the root logger the way it was found.
        root_logger.removeHandler(console)
        root_logger.setLevel(previous_level)
def test_inference(self):
    """Smoke-test the inference pipeline on a small synthetic dataset.

    Simulates a dataset with ambient RNA, wraps it in a
    ``SingleCellRNACountsDataset``, runs inference on CPU (``use_cuda`` is
    False), and then exercises the model helpers that extract encodings,
    the background-subtracted count matrix, and the ambient expression.

    Returns:
        1 if every step ran, 0 if a ``TestConsole.failureException`` was
        raised along the way.

    """

    try:

        n_cells = 100

        # Generate a simulated dataset with ambient RNA.  Only the first of
        # the four returned values (the barcode-by-gene counts) is used.
        count_matrix, _, _, _ = simulate_dataset_with_ambient_rna(
            n_cells=n_cells,
            n_empty=3 * n_cells,
            clusters=1,
            n_genes=1000,
            d_cell=2000,
            d_empty=100,
            ambient_different=False)

        # Fake some parsed command line inputs.
        fake_cli_settings = {
            'use_cuda': False,
            'z_hidden_dims': [100],
            'd_hidden_dims': [10, 2],
            'p_hidden_dims': [100, 10],
            'z_dim': 10,
            'learning_rate': 0.001,
            'epochs': 10,
            'model': "full",
            'fraction_empties': 0.5,
            'use_jit': True,
            'training_fraction': 0.9,
            'expected_cell_count': n_cells,
        }
        args = ObjectWithAttributes()
        for option, value in fake_cli_settings.items():
            setattr(args, option, value)

        # Wrap simulated count matrix in a Dataset object, labelling columns
        # as genes ("g0", "g1", ...) and rows as barcodes ("bc0", "bc1", ...).
        dataset_obj = SingleCellRNACountsDataset()
        gene_names = np.array(
            [f'g{idx}' for idx in range(count_matrix.shape[1])])
        barcodes = np.array(
            [f'bc{idx}' for idx in range(count_matrix.shape[0])])
        dataset_obj.data = {'matrix': count_matrix,
                            'gene_names': gene_names,
                            'barcodes': barcodes}
        dataset_obj.priors['n_cells'] = n_cells
        dataset_obj._trim_dataset_for_analysis()
        dataset_obj._estimate_priors()

        # Run inference on this simulated dataset.
        inferred_model = run_inference(dataset_obj, args)

        model_module = cellbender.remove_background.model

        # Get encodings from the trained model.
        z, d, p = model_module.get_encodings(inferred_model, dataset_obj)

        # Make the background-subtracted dataset.
        inferred_count_matrix = \
            model_module.generate_maximum_a_posteriori_count_matrix(
                z, d, p, inferred_model, dataset_obj)

        # Get the inferred background RNA expression from the model.
        ambient_expression = \
            model_module.get_ambient_expression_from_pyro_param_store()

    except TestConsole.failureException:

        return 0

    else:

        return 1