def run_inference(dataset_obj: SingleCellRNACountsDataset,
                  args) -> RemoveBackgroundPyroModel:
    """Run a full inference procedure, training a latent variable model.

    Args:
        dataset_obj: Input data in the form of a SingleCellRNACountsDataset
            object.
        args: Input command line parsed arguments.

    Returns:
        model: cellbender.model.RemoveBackgroundPyroModel that has had
            inference run.

    """

    # Get the trimmed count matrix (transformed if called for).
    count_matrix = dataset_obj.get_count_matrix()

    # Configure pyro options (skip validations to improve speed).
    pyro.enable_validation(False)
    pyro.distributions.enable_validation(False)
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    # Set up the variational autoencoder:

    # Encoder.
    encoder_z = EncodeZ(input_dim=count_matrix.shape[1],
                        hidden_dims=args.z_hidden_dims,
                        output_dim=args.z_dim,
                        input_transform='normalize')
    encoder_other = EncodeNonZLatents(
        n_genes=count_matrix.shape[1],
        z_dim=args.z_dim,
        hidden_dims=consts.ENC_HIDDEN_DIMS,
        log_count_crossover=dataset_obj.priors['log_counts_crossover'],
        prior_log_cell_counts=np.log1p(dataset_obj.priors['cell_counts']),
        input_transform='normalize')
    encoder = CompositeEncoder({'z': encoder_z,
                                'other': encoder_other})

    # Decoder.
    decoder = Decoder(input_dim=args.z_dim,
                      hidden_dims=args.z_hidden_dims[::-1],
                      output_dim=count_matrix.shape[1])

    # Set up the pyro model for variational inference.
    model = RemoveBackgroundPyroModel(model_type=args.model,
                                      encoder=encoder,
                                      decoder=decoder,
                                      dataset_obj=dataset_obj,
                                      use_cuda=args.use_cuda)

    # Load the dataset into DataLoaders.
    frac = args.training_fraction  # Fraction of barcodes to use for training
    batch_size = int(min(300,
                         frac * dataset_obj.analyzed_barcode_inds.size / 2))
    train_loader, test_loader = \
        prep_data_for_training(dataset=count_matrix,
                               empty_drop_dataset=
                               dataset_obj.get_count_matrix_empties(),
                               random_state=dataset_obj.random,
                               batch_size=batch_size,
                               training_fraction=frac,
                               fraction_empties=args.fraction_empties,
                               shuffle=True,
                               use_cuda=args.use_cuda)

    # Set up the optimizer.
    optimizer = pyro.optim.clipped_adam.ClippedAdam
    optimizer_args = {'lr': args.learning_rate, 'clip_norm': 10.}

    # Set up a learning rate scheduler.
    minibatches_per_epoch = int(np.ceil(len(train_loader)
                                        / train_loader.batch_size).item())
    scheduler_args = {'optimizer': optimizer,
                      'max_lr': args.learning_rate * 10,
                      'steps_per_epoch': minibatches_per_epoch,
                      'epochs': args.epochs,
                      'optim_args': optimizer_args}
    scheduler = pyro.optim.OneCycleLR(scheduler_args)

    # Determine the loss function.
    if args.use_jit:

        # Call guide() once as a warm-up.
        model.guide(torch.zeros([10, dataset_obj.analyzed_gene_inds.size])
                    .to(model.device))

        if args.model == "simple":
            loss_function = JitTrace_ELBO()
        else:
            loss_function = JitTraceEnum_ELBO(max_plate_nesting=1,
                                              strict_enumeration_warning=False)
    else:

        if args.model == "simple":
            loss_function = Trace_ELBO()
        else:
            loss_function = TraceEnum_ELBO(max_plate_nesting=1)

    # Set up the inference process.
    svi = SVI(model.model, model.guide, scheduler, loss=loss_function)

    # Run training.
    run_training(model, svi, train_loader, test_loader,
                 epochs=args.epochs, test_freq=5)

    return model
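
# A minimal sketch (an assumption for illustration, not part of the library)
# of how run_inference() might be invoked: argparse.Namespace stands in for
# the real parsed command line arguments, and the attribute names are taken
# from the reads performed by run_inference() above. Values are placeholders,
# not recommendations.
import argparse

_example_args = argparse.Namespace(
    model='full',            # 'simple' skips the empty-droplet machinery
    z_hidden_dims=[100],     # encoder hidden sizes; reversed for the decoder
    z_dim=10,                # dimension of the latent z
    learning_rate=1e-3,      # base LR; OneCycleLR peaks at 10x this value
    epochs=150,              # placeholder value
    training_fraction=0.9,   # fraction of barcodes used for training
    fraction_empties=0.5,    # fraction of each minibatch drawn from empties
    use_jit=False,
    use_cuda=False,
)
# With a SingleCellRNACountsDataset in hand (construction not shown here):
# inferred_model = run_inference(dataset_obj, _example_args)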
def test_inference(self):
    """Run a basic test doing inference on a synthetic dataset.

    Runs the inference procedure on CPU.

    """

    try:

        n_cells = 100

        # Generate a simulated dataset with ambient RNA.
        csr_barcode_gene_synthetic, _, _, _ = \
            simulate_dataset_with_ambient_rna(n_cells=n_cells,
                                              n_empty=3 * n_cells,
                                              clusters=1, n_genes=1000,
                                              d_cell=2000, d_empty=100,
                                              ambient_different=False)

        # Fake some parsed command line inputs.
        args = ObjectWithAttributes()
        args.use_cuda = False
        args.z_hidden_dims = [100]
        args.d_hidden_dims = [10, 2]
        args.p_hidden_dims = [100, 10]
        args.z_dim = 10
        args.learning_rate = 0.001
        args.epochs = 10
        args.model = "full"
        args.fraction_empties = 0.5
        args.use_jit = True
        args.training_fraction = 0.9
        args.expected_cell_count = n_cells

        # Wrap simulated count matrix in a Dataset object.
        dataset_obj = SingleCellRNACountsDataset()
        dataset_obj.data = \
            {'matrix': csr_barcode_gene_synthetic,
             'gene_names':
                 np.array([f'g{n}' for n in
                           range(csr_barcode_gene_synthetic.shape[1])]),
             'barcodes':
                 np.array([f'bc{n}' for n in
                           range(csr_barcode_gene_synthetic.shape[0])])}
        dataset_obj.priors['n_cells'] = n_cells
        dataset_obj._trim_dataset_for_analysis()
        dataset_obj._estimate_priors()

        # Run inference on this simulated dataset.
        inferred_model = run_inference(dataset_obj, args)

        # Get encodings from the trained model.
        z, d, p = cellbender.remove_background.model.\
            get_encodings(inferred_model, dataset_obj)

        # Make the background-subtracted dataset.
        inferred_count_matrix = cellbender.remove_background.model.\
            generate_maximum_a_posteriori_count_matrix(z, d, p,
                                                       inferred_model,
                                                       dataset_obj)

        # Get the inferred background RNA expression from the model.
        ambient_expression = cellbender.remove_background.model.\
            get_ambient_expression_from_pyro_param_store()

        return 1

    except TestConsole.failureException:

        return 0
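
# The test above assigns arbitrary attributes to an ObjectWithAttributes
# instance to mimic parsed command line arguments. Its definition is not
# shown in this section; a minimal sketch consistent with that usage (an
# assumption, not necessarily the test suite's actual helper) would be:
class ObjectWithAttributes:
    """Bare container: attributes are attached ad hoc to mimic the
    argparse.Namespace produced by the real command line parser."""
    pass

# types.SimpleNamespace would serve the same purpose, e.g.
# from types import SimpleNamespace
# args = SimpleNamespace(use_cuda=False, z_dim=10)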
def run_remove_background(args):
    """The full script for the command line tool to remove background RNA.

    Args:
        args: Inputs from the command line, already parsed using argparse.

    Note: Returns nothing, but writes output file(s) as specified on the
        command line.

    """

    # Load dataset, run inference, and write the output to a file.

    # Send logging messages to stdout as well as a log file.
    file_dir, file_base = os.path.split(args.output_file)
    file_name = os.path.splitext(os.path.basename(file_base))[0]
    log_file = os.path.join(file_dir, file_name + ".log")
    logging.basicConfig(level=logging.INFO,
                        format="cellbender:remove-background: %(message)s",
                        filename=log_file,
                        filemode="w")
    console = logging.StreamHandler()
    formatter = logging.Formatter("cellbender:remove-background: "
                                  "%(message)s")
    console.setFormatter(formatter)  # Use the same format for stdout.
    logging.getLogger('').addHandler(console)  # Log to stdout and a file.

    # Log the command as typed by user.
    logging.info("Command:\n"
                 + ' '.join(['cellbender', 'remove-background']
                            + sys.argv[2:]))

    # Log the start time.
    logging.info(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

    logging.info("Running remove-background")

    # Load data from file and choose barcodes and genes to analyze.
    try:
        dataset_obj = \
            SingleCellRNACountsDataset(input_file=args.input_file,
                                       expected_cell_count=
                                       args.expected_cell_count,
                                       total_droplet_barcodes=
                                       args.total_droplets,
                                       fraction_empties=args.fraction_empties,
                                       model_name=args.model,
                                       gene_blacklist=args.blacklisted_genes,
                                       low_count_threshold=
                                       args.low_count_threshold)
    except OSError:
        logging.error(f"OSError: Unable to open file {args.input_file}.")
        sys.exit(1)

    # Instantiate latent variable model and run full inference procedure.
    inferred_model = run_inference(dataset_obj, args)

    # Write outputs to file.
    try:
        dataset_obj.save_to_output_file(args.output_file,
                                        inferred_model,
                                        save_plots=True)

        logging.info("Completed remove-background.")
        logging.info(datetime.now().strftime('%Y-%m-%d %H:%M:%S\n'))

    # The exception allows user to end inference prematurely with CTRL-C.
    except KeyboardInterrupt:

        # If partial output has been saved, delete it.
        full_file = args.output_file

        # Name of the filtered (cells only) file.
        file_dir, file_base = os.path.split(full_file)
        file_name = os.path.splitext(os.path.basename(file_base))[0]
        filtered_file = os.path.join(file_dir, file_name + "_filtered.h5")

        if os.path.exists(full_file):
            os.remove(full_file)

        if os.path.exists(filtered_file):
            os.remove(filtered_file)

        logging.info("Keyboard interrupt. Terminated without saving.\n")
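
# For context, a sketch of the kind of parser that could produce the args
# consumed by run_remove_background(). The flag spellings below are
# illustrative assumptions; only the dest attribute names (input_file,
# output_file, etc.) are taken from the code above.
import argparse

_parser = argparse.ArgumentParser(prog='cellbender remove-background')
_parser.add_argument('--input', dest='input_file', required=True)
_parser.add_argument('--output', dest='output_file', required=True)
_parser.add_argument('--expected-cells', dest='expected_cell_count', type=int)
_parser.add_argument('--total-droplets-included', dest='total_droplets',
                     type=int)
_parser.add_argument('--fraction-empties', dest='fraction_empties',
                     type=float, default=0.5)
_parser.add_argument('--model', dest='model', default='full')
_parser.add_argument('--low-count-threshold', dest='low_count_threshold',
                     type=int, default=15)
_parser.add_argument('--blacklist-genes', dest='blacklisted_genes',
                     nargs='*', default=[])
# run_remove_background(_parser.parse_args())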
def run_inference(dataset_obj: SingleCellRNACountsDataset,
                  args) -> RemoveBackgroundPyroModel:
    """Run a full inference procedure, training a latent variable model.

    Args:
        dataset_obj: Input data in the form of a SingleCellRNACountsDataset
            object.
        args: Input command line parsed arguments.

    Returns:
        model: cellbender.model.RemoveBackgroundPyroModel that has had
            inference run.

    """

    # Get the trimmed count matrix (transformed if called for).
    count_matrix = dataset_obj.get_count_matrix()

    # Configure pyro options (skip validations to improve speed).
    pyro.enable_validation(False)
    pyro.distributions.enable_validation(False)
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    # Set up the variational autoencoder:

    # Encoder.
    encoder_z = EncodeZ(input_dim=count_matrix.shape[1],
                        hidden_dims=args.z_hidden_dims,
                        output_dim=args.z_dim,
                        input_transform='normalize')
    encoder_d = EncodeD(input_dim=count_matrix.shape[1],
                        hidden_dims=args.d_hidden_dims,
                        output_dim=1,
                        log_count_crossover=
                        dataset_obj.priors['log_counts_crossover'])

    if args.model == "simple":

        # If using the simple model, there is no need for p.
        encoder = CompositeEncoder({'z': encoder_z,
                                    'd_loc': encoder_d})

    else:

        # Models that include empty droplets.
        encoder_p = EncodeNonEmptyDropletLogitProb(
            input_dim=count_matrix.shape[1],
            hidden_dims=args.p_hidden_dims,
            output_dim=1,
            input_transform='normalize',
            log_count_crossover=dataset_obj.priors['log_counts_crossover'])
        encoder = CompositeEncoder({'z': encoder_z,
                                    'd_loc': encoder_d,
                                    'p_y': encoder_p})

    # Decoder.
    decoder = Decoder(input_dim=args.z_dim,
                      hidden_dims=args.z_hidden_dims[::-1],
                      output_dim=count_matrix.shape[1])

    # Set up the pyro model for variational inference.
    model = RemoveBackgroundPyroModel(model_type=args.model,
                                      encoder=encoder,
                                      decoder=decoder,
                                      dataset_obj=dataset_obj,
                                      use_cuda=args.use_cuda)

    # Set up the optimizer.
    adam_args = {"lr": args.learning_rate}
    optimizer = ClippedAdam(adam_args)

    # Determine the loss function.
    if args.use_jit:
        loss_function = JitTraceEnum_ELBO(max_plate_nesting=1,
                                          strict_enumeration_warning=False)
    else:
        loss_function = TraceEnum_ELBO(max_plate_nesting=1)
    if args.model == "simple":
        if args.use_jit:
            loss_function = JitTrace_ELBO()
        else:
            loss_function = Trace_ELBO()

    # Set up the inference process.
    svi = SVI(model.model, model.guide, optimizer, loss=loss_function)

    # Load the dataset into DataLoaders.
    frac = args.training_fraction  # Fraction of barcodes to use for training
    batch_size = int(min(500,
                         frac * dataset_obj.analyzed_barcode_inds.size / 2))
    train_loader, test_loader = \
        prep_data_for_training(dataset=count_matrix,
                               empty_drop_dataset=
                               dataset_obj.get_count_matrix_empties(),
                               random_state=dataset_obj.random,
                               batch_size=batch_size,
                               training_fraction=frac,
                               fraction_empties=args.fraction_empties,
                               shuffle=True,
                               use_cuda=args.use_cuda)

    # Run training.
    run_training(model, svi, train_loader, test_loader,
                 epochs=args.epochs, test_freq=10)

    return model
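
# Both versions of run_inference() above rely on the same external names.
# A sketch of a plausible import block follows; the pyro imports are known
# to exist under these paths, while the cellbender-internal module paths
# (commented out) are assumptions inferred from context, not verified
# locations.
import numpy as np
import torch
import pyro
from pyro.infer import (SVI, Trace_ELBO, TraceEnum_ELBO,
                        JitTrace_ELBO, JitTraceEnum_ELBO)
from pyro.optim import ClippedAdam  # PyroOptim wrapper used just above.
# The first version instead passes the raw torch optimizer class,
# pyro.optim.clipped_adam.ClippedAdam, into the OneCycleLR scheduler args.

# Assumed locations for the cellbender-internal pieces:
# from cellbender.remove_background.data.dataset import \
#     SingleCellRNACountsDataset
# from cellbender.remove_background.model import RemoveBackgroundPyroModel
# from cellbender.remove_background.vae.encoder import (
#     EncodeZ, EncodeD, EncodeNonZLatents, EncodeNonEmptyDropletLogitProb,
#     CompositeEncoder)
# from cellbender.remove_background.vae.decoder import Decoder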