Example #1
def run_inference(dataset_obj: SingleCellRNACountsDataset,
                  args) -> RemoveBackgroundPyroModel:
    """Run a full inference procedure, training a latent variable model.

    Args:
        dataset_obj: Input data in the form of a SingleCellRNACountsDataset
            object.
        args: Input command line parsed arguments.

    Returns:
        model: cellbender.model.RemoveBackgroundPyroModel on which
            inference has been run.

    """

    # Get the trimmed count matrix (transformed if called for).
    count_matrix = dataset_obj.get_count_matrix()

    # Configure pyro options (skip validations to improve speed).
    pyro.enable_validation(False)
    pyro.distributions.enable_validation(False)
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    # Set up the variational autoencoder:

    # Encoder.
    encoder_z = EncodeZ(input_dim=count_matrix.shape[1],
                        hidden_dims=args.z_hidden_dims,
                        output_dim=args.z_dim,
                        input_transform='normalize')

    encoder_other = EncodeNonZLatents(
        n_genes=count_matrix.shape[1],
        z_dim=args.z_dim,
        hidden_dims=consts.ENC_HIDDEN_DIMS,
        log_count_crossover=dataset_obj.priors['log_counts_crossover'],
        prior_log_cell_counts=np.log1p(dataset_obj.priors['cell_counts']),
        input_transform='normalize')

    encoder = CompositeEncoder({'z': encoder_z, 'other': encoder_other})

    # Decoder.
    decoder = Decoder(input_dim=args.z_dim,
                      hidden_dims=args.z_hidden_dims[::-1],
                      output_dim=count_matrix.shape[1])

    # Set up the pyro model for variational inference.
    model = RemoveBackgroundPyroModel(model_type=args.model,
                                      encoder=encoder,
                                      decoder=decoder,
                                      dataset_obj=dataset_obj,
                                      use_cuda=args.use_cuda)

    # Load the dataset into DataLoaders.
    frac = args.training_fraction  # Fraction of barcodes to use for training
    batch_size = int(
        min(300, frac * dataset_obj.analyzed_barcode_inds.size / 2))
    train_loader, test_loader = prep_data_for_training(
        dataset=count_matrix,
        empty_drop_dataset=dataset_obj.get_count_matrix_empties(),
        random_state=dataset_obj.random,
        batch_size=batch_size,
        training_fraction=frac,
        fraction_empties=args.fraction_empties,
        shuffle=True,
        use_cuda=args.use_cuda)

    # Set up the optimizer.
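    # Note: this is the torch-style optimizer *class* from the
    # pyro.optim.clipped_adam module, not the PyroOptim wrapper; the
    # Pyro LR scheduler below instantiates it from 'optim_args'.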
    optimizer = pyro.optim.clipped_adam.ClippedAdam
    optimizer_args = {'lr': args.learning_rate, 'clip_norm': 10.}

    # Set up a learning rate scheduler.
    minibatches_per_epoch = int(
        np.ceil(len(train_loader) / train_loader.batch_size).item())
    scheduler_args = {
        'optimizer': optimizer,
        'max_lr': args.learning_rate * 10,
        'steps_per_epoch': minibatches_per_epoch,
        'epochs': args.epochs,
        'optim_args': optimizer_args
    }
    scheduler = pyro.optim.OneCycleLR(scheduler_args)
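    # OneCycleLR ramps the learning rate up to max_lr and back down
    # over one cycle spanning the full training run.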

    # Determine the loss function.
    if args.use_jit:

        # Call guide() once as a warm-up.
        model.guide(torch.zeros([10, dataset_obj.analyzed_gene_inds.size])
                    .to(model.device))

        if args.model == "simple":
            loss_function = JitTrace_ELBO()
        else:
            loss_function = JitTraceEnum_ELBO(max_plate_nesting=1,
                                              strict_enumeration_warning=False)
    else:

        if args.model == "simple":
            loss_function = Trace_ELBO()
        else:
            loss_function = TraceEnum_ELBO(max_plate_nesting=1)

    # Set up the inference process.
    svi = SVI(model.model, model.guide, scheduler, loss=loss_function)

    # Run training.
    run_training(model,
                 svi,
                 train_loader,
                 test_loader,
                 epochs=args.epochs,
                 test_freq=5)

    return model
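
For reference, run_training is imported from elsewhere in the package and is
not shown here. A minimal sketch of the epoch loop such a helper presumably
wraps (the name and the print-based reporting are hypothetical, not
CellBender's actual implementation), assuming each DataLoader yields
minibatch tensors:

def run_training_sketch(model, svi, train_loader, test_loader,
                        epochs: int, test_freq: int = 5):
    """Hypothetical stand-in for run_training: a plain Pyro SVI loop."""
    for epoch in range(1, epochs + 1):
        # svi.step() takes one gradient step per minibatch and returns
        # the (positive) loss, i.e. the negative ELBO contribution.
        train_loss = sum(svi.step(x) for x in train_loader)
        print(f"[epoch {epoch:03d}] train ELBO: {-train_loss:.2f}")

        if epoch % test_freq == 0:
            # evaluate_loss() computes the loss without updating parameters.
            test_loss = sum(svi.evaluate_loss(x) for x in test_loader)
            print(f"[epoch {epoch:03d}]  test ELBO: {-test_loss:.2f}")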
Example #2
    def test_inference(self):
        """Run a basic tests doing inference on a synthetic dataset.

        Runs the inference procedure on CPU.

        """

        try:

            n_cells = 100

            # Generate a simulated dataset with ambient RNA.
            csr_barcode_gene_synthetic, _, _, _ = \
                simulate_dataset_with_ambient_rna(n_cells=n_cells,
                                                  n_empty=3 * n_cells,
                                                  clusters=1, n_genes=1000,
                                                  d_cell=2000, d_empty=100,
                                                  ambient_different=False)

            # Fake some parsed command line inputs.
            args = ObjectWithAttributes()
            args.use_cuda = False
            args.z_hidden_dims = [100]
            args.d_hidden_dims = [10, 2]
            args.p_hidden_dims = [100, 10]
            args.z_dim = 10
            args.learning_rate = 0.001
            args.epochs = 10
            args.model = "full"
            args.fraction_empties = 0.5
            args.use_jit = True
            args.training_fraction = 0.9

            args.expected_cell_count = n_cells

            # Wrap simulated count matrix in a Dataset object.
            dataset_obj = SingleCellRNACountsDataset()
            dataset_obj.data = {
                'matrix': csr_barcode_gene_synthetic,
                'gene_names': np.array(
                    [f'g{n}' for n in
                     range(csr_barcode_gene_synthetic.shape[1])]),
                'barcodes': np.array(
                    [f'bc{n}' for n in
                     range(csr_barcode_gene_synthetic.shape[0])])}
            dataset_obj.priors['n_cells'] = n_cells
            dataset_obj._trim_dataset_for_analysis()
            dataset_obj._estimate_priors()

            # Run inference on this simulated dataset.
            inferred_model = run_inference(dataset_obj, args)

            # Get encodings from the trained model.
            z, d, p = cellbender.remove_background.model.\
                get_encodings(inferred_model, dataset_obj)

            # Make the background-subtracted dataset.
            inferred_count_matrix = cellbender.remove_background.model.\
                generate_maximum_a_posteriori_count_matrix(z, d, p,
                                                           inferred_model,
                                                           dataset_obj)

            # Get the inferred background RNA expression from the model.
            ambient_expression = cellbender.remove_background.model.\
                get_ambient_expression_from_pyro_param_store()

            return 1

        except TestConsole.failureException:

            return 0
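
The ObjectWithAttributes helper used above to fake parsed command-line
arguments is not shown in this snippet. Presumably it is just an empty
attribute container; a minimal stand-in, assuming nothing beyond attribute
assignment is needed (types.SimpleNamespace from the standard library would
serve equally well):

class ObjectWithAttributes(object):
    """Hypothetical stand-in: an empty container for fake parsed args."""
    pass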
Example #3
def run_remove_background(args):
    """The full script for the command line tool to remove background RNA.

    Args:
        args: Inputs from the command line, already parsed using argparse.

    Note: Returns nothing, but writes output to the file(s) specified on the
        command line.

    """

    # Load dataset, run inference, and write the output to a file.

    # Send logging messages to the console as well as a log file.
    file_dir, file_base = os.path.split(args.output_file)
    file_name = os.path.splitext(os.path.basename(file_base))[0]
    log_file = os.path.join(file_dir, file_name + ".log")
    logging.basicConfig(level=logging.INFO,
                        format="cellbender:remove-background: %(message)s",
                        filename=log_file,
                        filemode="w")
    console = logging.StreamHandler()  # Writes to stderr by default.
    formatter = logging.Formatter("cellbender:remove-background: "
                                  "%(message)s")
    console.setFormatter(formatter)  # Use the same format for the console.
    logging.getLogger('').addHandler(console)  # Log to the console and a file.

    # Log the command as typed by user.
    logging.info("Command:\n" +
                 ' '.join(['cellbender', 'remove-background'] + sys.argv[2:]))

    # Log the start time.
    logging.info(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

    logging.info("Running remove-background")

    # Load data from file and choose barcodes and genes to analyze.
    try:
        dataset_obj = SingleCellRNACountsDataset(
            input_file=args.input_file,
            expected_cell_count=args.expected_cell_count,
            total_droplet_barcodes=args.total_droplets,
            fraction_empties=args.fraction_empties,
            model_name=args.model,
            gene_blacklist=args.blacklisted_genes,
            low_count_threshold=args.low_count_threshold)
    except OSError:
        logging.error(f"OSError: Unable to open file {args.input_file}.")
        sys.exit(1)

    # Instantiate latent variable model and run full inference procedure.
    inferred_model = run_inference(dataset_obj, args)

    # Write outputs to file.
    try:
        dataset_obj.save_to_output_file(args.output_file,
                                        inferred_model,
                                        save_plots=True)

        logging.info("Completed remove-background.")
        logging.info(datetime.now().strftime('%Y-%m-%d %H:%M:%S\n'))

    # The exception allows user to end inference prematurely with CTRL-C.
    except KeyboardInterrupt:

        # If partial output has been saved, delete it.
        full_file = args.output_file

        # Name of the filtered (cells only) file.
        file_dir, file_base = os.path.split(full_file)
        file_name = os.path.splitext(os.path.basename(file_base))[0]
        filtered_file = os.path.join(file_dir, file_name + "_filtered.h5")

        if os.path.exists(full_file):
            os.remove(full_file)

        if os.path.exists(filtered_file):
            os.remove(filtered_file)

        logging.info("Keyboard interrupt.  Terminated without saving.\n")
Example #4
def run_inference(dataset_obj: SingleCellRNACountsDataset,
                  args) -> RemoveBackgroundPyroModel:
    """Run a full inference procedure, training a latent variable model.

    Args:
        dataset_obj: Input data in the form of a SingleCellRNACountsDataset
            object.
        args: Input command line parsed arguments.

    Returns:
        model: cellbender.model.RemoveBackgroundPyroModel on which
            inference has been run.

    """

    # Get the trimmed count matrix (transformed if called for).
    count_matrix = dataset_obj.get_count_matrix()

    # Configure pyro options (skip validations to improve speed).
    pyro.enable_validation(False)
    pyro.distributions.enable_validation(False)
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    # Set up the variational autoencoder:

    # Encoder.
    encoder_z = EncodeZ(input_dim=count_matrix.shape[1],
                        hidden_dims=args.z_hidden_dims,
                        output_dim=args.z_dim,
                        input_transform='normalize')

    encoder_d = EncodeD(
        input_dim=count_matrix.shape[1],
        hidden_dims=args.d_hidden_dims,
        output_dim=1,
        log_count_crossover=dataset_obj.priors['log_counts_crossover'])

    if args.model == "simple":

        # If using the simple model, there is no need for p.
        encoder = CompositeEncoder({'z': encoder_z, 'd_loc': encoder_d})

    else:

        # Models that include empty droplets.
        encoder_p = EncodeNonEmptyDropletLogitProb(
            input_dim=count_matrix.shape[1],
            hidden_dims=args.p_hidden_dims,
            output_dim=1,
            input_transform='normalize',
            log_count_crossover=dataset_obj.priors['log_counts_crossover'])
        encoder = CompositeEncoder({
            'z': encoder_z,
            'd_loc': encoder_d,
            'p_y': encoder_p
        })

    # Decoder.
    decoder = Decoder(input_dim=args.z_dim,
                      hidden_dims=args.z_hidden_dims[::-1],
                      output_dim=count_matrix.shape[1])

    # Set up the pyro model for variational inference.
    model = RemoveBackgroundPyroModel(model_type=args.model,
                                      encoder=encoder,
                                      decoder=decoder,
                                      dataset_obj=dataset_obj,
                                      use_cuda=args.use_cuda)

    # Set up the optimizer.
    adam_args = {"lr": args.learning_rate}
    optimizer = ClippedAdam(adam_args)

    # Determine the loss function.
    if args.model == "simple":
        loss_function = JitTrace_ELBO() if args.use_jit else Trace_ELBO()
    elif args.use_jit:
        loss_function = JitTraceEnum_ELBO(max_plate_nesting=1,
                                          strict_enumeration_warning=False)
    else:
        loss_function = TraceEnum_ELBO(max_plate_nesting=1)

    # Set up the inference process.
    svi = SVI(model.model, model.guide, optimizer, loss=loss_function)

    # Load the dataset into DataLoaders.
    frac = args.training_fraction  # Fraction of barcodes to use for training
    batch_size = int(
        min(500, frac * dataset_obj.analyzed_barcode_inds.size / 2))
    train_loader, test_loader = prep_data_for_training(
        dataset=count_matrix,
        empty_drop_dataset=dataset_obj.get_count_matrix_empties(),
        random_state=dataset_obj.random,
        batch_size=batch_size,
        training_fraction=frac,
        fraction_empties=args.fraction_empties,
        shuffle=True,
        use_cuda=args.use_cuda)

    # Run training.
    run_training(model,
                 svi,
                 train_loader,
                 test_loader,
                 epochs=args.epochs,
                 test_freq=10)

    return model
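
Compared with Example #1, this earlier variant drives SVI with a fixed-rate
ClippedAdam optimizer rather than a OneCycleLR schedule. A self-contained toy
example (deliberately unrelated to CellBender's model) showing the same
ClippedAdam/SVI wiring:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam


def model(x):
    # Point-estimate a location parameter by maximizing the ELBO.
    loc = pyro.param("loc", torch.zeros(1))
    with pyro.plate("data", x.shape[0]):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=x)


def guide(x):
    pass  # no latent sample sites; only the pyro.param above is learned


optimizer = ClippedAdam({"lr": 1e-2, "clip_norm": 10.0})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

data = torch.randn(64) + 3.0
for step in range(200):
    loss = svi.step(data)  # one clipped-gradient step; returns the loss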