def fit(self, X: np.ndarray, **fit_params: Mapping[str, object]):
        """
        :param X: raw count array of UMIs. Must not be pre-processed, except for
                  optional filtering of bad cells/genes.
        :param fit_params: Additional parameters passed to the denoiser.
        """

        rng = check_random_state(self.random_state)
        param_grid = ParameterGrid(self.param_grid)

        losses = defaultdict(list)
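        # For each split, partition the observed molecules into a training
        # half (umis_X) and a held-out half (umis_Y); every parameter setting
        # is then scored by denoising X and comparing the converted output
        # against Y.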

        for i in range(self.n_splits):
            umis_X, umis_Y = ut.split_molecules(X,
                                                self.data_split,
                                                self.overlap,
                                                random_state=rng)

            umis_X = self.transformation(umis_X)
            umis_Y = self.transformation(umis_Y)

            for params in param_grid:
                denoised_umis = self.denoiser(umis_X, **fit_params, **params)
                converted_denoised_umis = self.conversion(denoised_umis)
                losses[i].append(self.loss(converted_denoised_umis, umis_Y))

        losses = [np.mean(s) for s in zip(*losses.values())]

        best_index_ = np.argmin(losses)
        self.best_params_ = param_grid[best_index_]
        self.best_loss_ = losses[best_index_]

        self.cv_results_ = defaultdict(list)
        self.cv_results_["mcv_loss"] = losses

        for params in param_grid:
            for k in params:
                self.cv_results_[k].append(params[k])

        return self
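# A minimal sketch of binomial molecule splitting in the spirit of
# ut.split_molecules as used above (illustrative only: the real helper also
# supports an `overlap` fraction shared between the two halves, which this
# sketch ignores). Binomial thinning is only valid on integer counts, which is
# why `fit` requires a raw, unprocessed UMI matrix.
import numpy as np


def split_molecules_sketch(umis, data_split, random_state):
    """Assign each molecule to split X with probability `data_split`."""
    umis_X = random_state.binomial(umis, data_split)
    return umis_X, umis - umis_X


# every molecule lands in exactly one of the two splits
_rng = np.random.RandomState(0)
_counts = _rng.poisson(2.0, size=(100, 50))
_x, _y = split_molecules_sketch(_counts, 0.9, _rng)
assert (_x + _y == _counts).all()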
def main():
    parser = argparse.ArgumentParser()

    run_group = parser.add_argument_group("run",
                                          description="Per-run parameters")
    run_group.add_argument("--seed", type=int, required=True)
    run_group.add_argument("--data_split",
                           type=float,
                           default=0.9,
                           help="Split for self-supervision")
    run_group.add_argument("--gpu", type=int, required=True)

    data_group = parser.add_argument_group(
        "data", description="Input and output parameters")
    data_group.add_argument("--dataset", type=pathlib.Path, required=True)
    data_group.add_argument("--output_dir", type=pathlib.Path, required=True)

    model_group = parser.add_argument_group("model",
                                            description="Model parameters")

    loss_group = model_group.add_mutually_exclusive_group(required=True)
    loss_group.add_argument(
        "--mse",
        action="store_const",
        const="mse",
        dest="loss",
        help="mean squared error",
    )
    loss_group.add_argument(
        "--pois",
        action="store_const",
        const="pois",
        dest="loss",
        help="poisson likelihood",
    )

    model_group.add_argument(
        "--layers",
        nargs="+",
        type=int,
        metavar="L",
        default=[128],
        help="Layers in the input/output networks",
    )
    model_group.add_argument(
        "--max_bottleneck",
        type=int,
        default=7,
        metavar="B",
        help="max bottleneck (log2)",
    )
    model_group.add_argument("--learning_rate",
                             type=float,
                             default=0.1,
                             metavar="LR",
                             help="learning rate")
    model_group.add_argument("--dropout",
                             type=float,
                             default=0.0,
                             metavar="P",
                             help="dropout probability")

    args = parser.parse_args()

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(logging.StreamHandler())

    logger.info(f"torch version {torch.__version__}")

    dataset_name = args.dataset.parent.name
    output_file = (
        args.output_dir /
        f"{dataset_name}_autoencoder_{args.loss}_{args.seed}.pickle")

    logger.info(f"writing output to {output_file}")

    seed = sum(map(ord, f"biohub_{args.seed}"))
    random_state = np.random.RandomState(seed)

    device = torch.device(f"cuda:{args.gpu}")

    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)

    with open(args.dataset, "rb") as f:
        true_means, true_counts, umis = pickle.load(f)

    n_features = umis.shape[-1]

    bottlenecks = [2**i for i in range(args.max_bottleneck + 1)]
    bottlenecks.extend(3 * b // 2 for b in bottlenecks[1:-1])
    bottlenecks.sort()
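    # sweep powers of two up to 2**max_bottleneck plus the midpoints between
    # consecutive powers, e.g. 1, 2, 3, 4, 6, 8, 12, ... for the default
    # max_bottleneck of 7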

    logger.info(f"testing bottlenecks {bottlenecks}")

    if max(bottlenecks) > max(args.layers):
        raise ValueError(
            "Max bottleneck width is larger than your network layers")

    rec_loss = np.empty(len(bottlenecks), dtype=float)
    mcv_loss = np.empty_like(rec_loss)
    gt0_loss = np.empty_like(rec_loss)
    gt1_loss = np.empty_like(rec_loss)

    data_split, data_split_complement, overlap = ut.overlap_correction(
        args.data_split,
        umis.sum(1, keepdims=True) / true_counts)
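    # data_split and data_split_complement are the capture-adjusted fractions
    # of molecules expected in each half, and overlap is the fraction expected
    # to appear in both (interpretation inferred from usage below; see
    # ut.overlap_correction for the exact definition)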

    if args.loss == "mse":
        exp_means = ut.expected_sqrt(true_means * umis.sum(1, keepdims=True))
        exp_split_means = ut.expected_sqrt(true_means * data_split_complement *
                                           umis.sum(1, keepdims=True))

        exp_means = torch.from_numpy(exp_means).to(torch.float)
        exp_split_means = torch.from_numpy(exp_split_means).to(torch.float)

        loss_fn = nn.MSELoss()
        normalization = "sqrt"
        input_t = nn.Identity()
        eval0_fn = mse_loss_cpu
        eval1_fn = adjusted_mse_loss_cpu
    else:
        assert args.loss == "pois"
        exp_means = true_means * umis.sum(1, keepdims=True)
        exp_split_means = data_split_complement * exp_means

        exp_means = torch.from_numpy(exp_means).to(torch.float)
        exp_split_means = torch.from_numpy(exp_split_means).to(torch.float)

        loss_fn = nn.PoissonNLLLoss()
        normalization = "log1p"
        input_t = torch.log1p
        eval0_fn = poisson_nll_loss_cpu
        eval1_fn = adjusted_poisson_nll_loss_cpu

    model_factory = lambda bottleneck: CountAutoencoder(
        n_input=n_features,
        n_latent=bottleneck,
        layers=args.layers,
        use_cuda=True,
        dropout_rate=args.dropout,
    )

    optimizer_factory = lambda m: AggMo(m.parameters(),
                                        lr=args.learning_rate,
                                        betas=[0.0, 0.9, 0.99],
                                        weight_decay=1e-7)

    scheduler_kw = {
        "t_max": 256,
        "eta_min": args.learning_rate / 100.0,
        "factor": 1.0
    }

    train_losses = []
    val_losses = []

    full_train_losses = []
    full_val_losses = []

    batch_size = min(1024, umis.shape[0])

    with torch.cuda.device(device):
        umis_X, umis_Y = ut.split_molecules(umis, data_split, overlap,
                                            random_state)

        if args.loss == "mse":
            umis = np.sqrt(umis)
            umis_X = np.sqrt(umis_X)
            umis_Y = np.sqrt(umis_Y)

        umis = torch.from_numpy(umis).to(torch.float).to(device)
        umis_X = torch.from_numpy(umis_X).to(torch.float).to(device)
        umis_Y = torch.from_numpy(umis_Y).to(torch.float)
        data_split = torch.from_numpy(
            np.broadcast_to(data_split, (umis.shape[0], 1))).to(torch.float)
        data_split_complement = torch.from_numpy(
            np.broadcast_to(data_split_complement,
                            (umis.shape[0], 1))).to(torch.float)

        sample_indices = random_state.permutation(umis.size(0))
        n_train = int(0.875 * umis.size(0))

        train_dl, val_dl = mcv.train.split_dataset(
            umis_X,
            umis_Y,
            exp_split_means,
            data_split,
            data_split_complement,
            batch_size=batch_size,
            indices=sample_indices,
            n_train=n_train,
        )

        full_train_dl, full_val_dl = mcv.train.split_dataset(
            umis,
            exp_means,
            batch_size=batch_size,
            indices=sample_indices,
            n_train=n_train,
        )

        t0 = time.time()

        for j, b in enumerate(bottlenecks):
            logger.info(f"testing bottleneck width {b}")
            model = model_factory(b)
            optimizer = optimizer_factory(model)

            train_loss, val_loss = mcv.train.train_until_plateau(
                model,
                loss_fn,
                optimizer,
                train_dl,
                val_dl,
                input_t=input_t,
                min_cycles=3,
                threshold=0.001,
                scheduler_kw=scheduler_kw,
            )
            train_losses.append(train_loss)
            val_losses.append(val_loss)

            rec_loss[j] = train_loss[-1]
            mcv_loss[j] = mcv.train.evaluate_epoch(model,
                                                   eval1_fn,
                                                   train_dl,
                                                   input_t,
                                                   eval_i=[1, 3, 4])
            gt1_loss[j] = mcv.train.evaluate_epoch(model,
                                                   eval1_fn,
                                                   train_dl,
                                                   input_t,
                                                   eval_i=[2, 3, 4])

            model = model_factory(b)
            optimizer = optimizer_factory(model)

            full_train_loss, full_val_loss = mcv.train.train_until_plateau(
                model,
                loss_fn,
                optimizer,
                full_train_dl,
                full_val_dl,
                input_t=input_t,
                min_cycles=3,
                threshold=0.001,
                scheduler_kw=scheduler_kw,
            )

            full_train_losses.append(full_train_loss)
            full_val_losses.append(full_val_loss)

            logger.debug(f"finished {b} after {time.time() - t0} seconds")

            gt0_loss[j] = eval0_fn(model(input_t(umis)), exp_means)

    results = {
        "dataset": dataset_name,
        "method": "autoencoder",
        "loss": args.loss,
        "normalization": normalization,
        "param_range": bottlenecks,
        "rec_loss": rec_loss,
        "mcv_loss": mcv_loss,
        "gt0_loss": gt0_loss,
        "gt1_loss": gt1_loss,
        "train_losses": train_losses,
        "val_losses": val_losses,
        "full_train_losses": full_train_losses,
        "full_val_losses": full_val_losses,
    }

    with open(output_file, "wb") as out:
        pickle.dump(results, out)


def main():
    parser = argparse.ArgumentParser()

    run_group = parser.add_argument_group("run",
                                          description="Per-run parameters")
    run_group.add_argument("--seed", type=int, required=True)
    run_group.add_argument("--data_split",
                           type=float,
                           default=0.9,
                           help="Split for self-supervision")
    run_group.add_argument("--n_trials",
                           type=int,
                           default=10,
                           help="Number of times to resample")
    run_group.add_argument("--median_scale", action="store_true")

    data_group = parser.add_argument_group(
        "data", description="Input and output parameters")
    data_group.add_argument("--dataset", type=pathlib.Path, required=True)
    data_group.add_argument("--output_dir", type=pathlib.Path, required=True)
    data_group.add_argument("--genes",
                            type=int,
                            nargs="+",
                            required=True,
                            help="Genes to smooth (indices)")

    model_group = parser.add_argument_group(
        "model",
        description=
        "Model parameters. [max] or [min, max] or [min, max, interval]",
    )

    model_group.add_argument(
        "--neighbors",
        type=int,
        nargs="+",
        default=(1, 11),
        metavar="K",
        help="Number of neighbors in kNN graph",
    )
    model_group.add_argument(
        "--components",
        type=int,
        nargs="+",
        default=(5, 51, 5),
        metavar="PC",
        help="Maximum number of components to compute",
    )
    model_group.add_argument(
        "--time",
        type=int,
        nargs="+",
        default=(1, 6),
        metavar="T",
        help="Number of time steps for diffusion",
    )

    args = parser.parse_args()

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(logging.StreamHandler())

    dataset_name = args.dataset.parent.name
    output_file = args.output_dir / f"{dataset_name}_magic_mse_{args.seed}.pickle"

    logger.info(f"writing output to {output_file}")

    seed = sum(map(ord, f"biohub_{args.seed}"))
    random_state = np.random.RandomState(seed)

    with open(args.dataset, "rb") as f:
        true_means, true_counts, umis = pickle.load(f)

    k_range = np.arange(*args.neighbors)
    pc_range = np.arange(*args.components)
    t_range = np.arange(*args.time)

    rec_loss = dict()
    mcv_loss = dict()

    # run n_trials for self-supervised sweep
    for i in range(args.n_trials):
        umis_X, umis_Y = ut.split_molecules(umis, args.data_split, 0.0,
                                            random_state)

        if args.median_scale:
            median_count = np.median(umis.sum(axis=1))

            umis_X = umis_X / umis_X.sum(axis=1, keepdims=True) * median_count
            umis_Y = umis_Y / umis_Y.sum(axis=1, keepdims=True) * median_count
        else:
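            # rescale the held-out split so its expected depth matches the
            # training split, keeping the MSE comparison on the same scale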
            umis_Y = umis_Y * args.data_split / (1 - args.data_split)

        for n_pcs in pc_range:
            for k in k_range:
                for t in t_range:
                    magic_op = magic.MAGIC(n_pca=n_pcs, verbose=0)
                    magic_op.set_params(knn=k, t=t)
                    denoised = magic_op.fit_transform(umis_X, genes=args.genes)
                    denoised = np.maximum(denoised, 0)

                    rec_loss[i, n_pcs, k, t] = mean_squared_error(
                        denoised, umis_X[:, args.genes])
                    mcv_loss[i, n_pcs, k, t] = mean_squared_error(
                        denoised, umis_Y[:, args.genes])

    results = {
        "dataset": dataset_name,
        "method": "magic",
        "loss": "mse",
        "normalization": "sqrt",
        "param_range": [pc_range, k_range, t_range],
        "rec_loss": rec_loss,
        "mcv_loss": mcv_loss,
    }

    with open(output_file, "wb") as out:
        pickle.dump(results, out)


def main():
    parser = argparse.ArgumentParser()

    run_group = parser.add_argument_group("run", description="Per-run parameters")
    run_group.add_argument("--seed", type=int, required=True)
    run_group.add_argument(
        "--data_split", type=float, default=0.9, help="Split for self-supervision"
    )
    run_group.add_argument(
        "--n_trials", type=int, default=10, help="Number of times to resample"
    )

    data_group = parser.add_argument_group(
        "data", description="Input and output parameters"
    )
    data_group.add_argument("--dataset", type=pathlib.Path, required=True)
    data_group.add_argument("--output_dir", type=pathlib.Path, required=True)

    model_group = parser.add_argument_group("model", description="Model parameters")
    model_group.add_argument(
        "--max_time", type=int, default=10, help="Maximum diffusion time"
    )

    loss_group = model_group.add_mutually_exclusive_group(required=True)
    loss_group.add_argument(
        "--mse",
        action="store_const",
        const="mse",
        dest="loss",
        help="mean-squared error",
    )
    loss_group.add_argument(
        "--pois",
        action="store_const",
        const="pois",
        dest="loss",
        help="poisson likelihood",
    )

    diff_op_group = model_group.add_argument_group(
        "diff_op", description="Parameters for computing the diffusion operator"
    )
    diff_op_group.add_argument(
        "--n_components",
        type=int,
        default=30,
        metavar="N",
        help="Number of components to compute",
    )
    diff_op_group.add_argument(
        "--n_neighbors",
        type=int,
        default=15,
        metavar="N",
        help="Neighbors for kNN graph",
    )
    diff_op_group.add_argument(
        "--tr_prob",
        type=float,
        default=0.5,
        help="Transition probability in lazy random walk",
    )

    args = parser.parse_args()

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(logging.StreamHandler())

    dataset_name = args.dataset.parent.name
    output_file = (
        args.output_dir / f"{dataset_name}_diffusion_{args.loss}_{args.seed}.pickle"
    )

    logger.info(f"writing output to {output_file}")

    seed = sum(map(ord, f"biohub_{args.seed}"))
    random_state = np.random.RandomState(seed)

    with open(args.dataset, "rb") as f:
        true_means, true_counts, umis = pickle.load(f)

    t_range = np.arange(args.max_time + 1)

    rec_loss = np.empty((args.n_trials, t_range.shape[0]), dtype=float)
    mcv_loss = np.empty_like(rec_loss)
    gt0_loss = np.empty(t_range.shape[0], dtype=float)
    gt1_loss = np.empty_like(rec_loss)

    data_split, data_split_complement, overlap = ut.overlap_correction(
        args.data_split, umis.sum(1, keepdims=True) / true_counts
    )

    if args.loss == "mse":
        exp_means = ut.expected_sqrt(true_means * umis.sum(1, keepdims=True))
        exp_split_means = ut.expected_sqrt(
            true_means * data_split_complement * umis.sum(1, keepdims=True)
        )

        loss = mean_squared_error
        normalization = "sqrt"
    else:
        assert args.loss == "pois"
        exp_means = true_means * umis.sum(1, keepdims=True)
        exp_split_means = data_split_complement * exp_means

        loss = lambda y_true, y_pred: (y_pred - y_true * np.log(y_pred + 1e-6)).mean()
        normalization = "none"

    # calculate gt loss for sweep using full data
    diff_op = compute_diff_op(
        umis, args.n_components, args.n_neighbors, args.tr_prob, random_state
    )

    if args.loss == "mse":
        diff = np.sqrt(umis)
    else:
        diff = umis.copy().astype(float)

    for t in t_range:
        gt0_loss[t] = loss(exp_means, diff)
        diff = diff_op.dot(diff)

    # run n_trials for self-supervised sweep
    for i in range(args.n_trials):
        umis_X, umis_Y = ut.split_molecules(umis, data_split, overlap, random_state)

        diff_op = compute_diff_op(
            umis_X, args.n_components, args.n_neighbors, args.tr_prob, random_state
        )

        if args.loss == "mse":
            umis_X = np.sqrt(umis_X)
            umis_Y = np.sqrt(umis_Y)

        diff_X = umis_X.copy().astype(float)

        # perform diffusion over the knn graph
        for t in t_range:
            if args.loss == "mse":
                conv_exp = ut.convert_expectations(
                    diff_X, data_split, data_split_complement
                )
            else:
                conv_exp = diff_X / data_split * data_split_complement

            rec_loss[i, t] = loss(umis_X, diff_X)
            mcv_loss[i, t] = loss(umis_Y, conv_exp)
            gt1_loss[i, t] = loss(exp_split_means, conv_exp)

            diff_X = diff_op.dot(diff_X)

    results = {
        "dataset": dataset_name,
        "method": "diffusion",
        "loss": args.loss,
        "normalization": normalization,
        "param_range": t_range,
        "rec_loss": rec_loss,
        "mcv_loss": mcv_loss,
        "gt0_loss": gt0_loss,
        "gt1_loss": gt1_loss,
    }

    with open(output_file, "wb") as out:
        pickle.dump(results, out)
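# A minimal sketch of a lazy random-walk diffusion operator in the spirit of
# compute_diff_op above (illustrative only: the real helper also reduces the
# data to `n_components` principal components before building the kNN graph,
# and its `tr_prob` semantics may differ from the stay-probability assumed
# here).
import numpy as np
import scipy.sparse as sp
from sklearn.neighbors import kneighbors_graph


def diff_op_sketch(X, n_neighbors, tr_prob):
    """Row-stochastic operator: stay with prob tr_prob, else step to a neighbor."""
    adj = kneighbors_graph(X, n_neighbors, include_self=True)  # 0/1 kNN adjacency
    P = adj.multiply(1.0 / adj.sum(axis=1))  # row-normalize to transition probabilities
    return tr_prob * sp.identity(X.shape[0]) + (1.0 - tr_prob) * P


# one diffusion step smooths each cell toward its kNN neighborhood, e.g.:
# smoothed = diff_op_sketch(np.sqrt(umis), 15, 0.5).dot(np.sqrt(umis))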


def main():
    parser = argparse.ArgumentParser()

    run_group = parser.add_argument_group("run",
                                          description="Per-run parameters")
    run_group.add_argument("--seed", type=int, required=True)
    run_group.add_argument("--data_split",
                           type=float,
                           default=0.9,
                           help="Split for self-supervision")
    run_group.add_argument("--n_trials",
                           type=int,
                           default=10,
                           help="Number of times to resample")

    data_group = parser.add_argument_group(
        "data", description="Input and output parameters")
    data_group.add_argument("--dataset", type=pathlib.Path, required=True)
    data_group.add_argument("--output_dir", type=pathlib.Path, required=True)

    model_group = parser.add_argument_group("model",
                                            description="Model parameters")
    model_group.add_argument(
        "--max_components",
        type=int,
        default=50,
        metavar="K",
        help="Number of components to compute",
    )

    args = parser.parse_args()

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(logging.StreamHandler())

    dataset_name = args.dataset.parent.name
    output_file = args.output_dir / f"{dataset_name}_pca_mse_{args.seed}.pickle"

    logger.info(f"writing output to {output_file}")

    seed = sum(map(ord, f"biohub_{args.seed}"))
    random_state = np.random.RandomState(seed)

    with open(args.dataset, "rb") as f:
        true_means, true_counts, umis = pickle.load(f)

    k_range = np.arange(1, args.max_components + 1)

    rec_loss = np.empty((args.n_trials, k_range.shape[0]), dtype=float)
    mcv_loss = np.empty_like(rec_loss)
    gt0_loss = np.empty(k_range.shape[0], dtype=float)
    gt1_loss = np.empty_like(rec_loss)

    data_split, data_split_complement, overlap = ut.overlap_correction(
        args.data_split,
        umis.sum(1, keepdims=True) / true_counts)

    exp_means = ut.expected_sqrt(true_means * umis.sum(1, keepdims=True))
    exp_split_means = ut.expected_sqrt(true_means * data_split_complement *
                                       umis.sum(1, keepdims=True))

    # calculate gt loss for sweep using full data
    U, S, V = randomized_svd(np.sqrt(umis),
                             n_components=args.max_components,
                             random_state=random_state)

    for j, k in enumerate(k_range):
        pca_X = U[:, :k].dot(np.diag(S[:k])).dot(V[:k, :])
        gt0_loss[j] = mean_squared_error(exp_means, pca_X)

    # run n_trials for self-supervised sweep
    for i in range(args.n_trials):
        umis_X, umis_Y = ut.split_molecules(umis, data_split, overlap,
                                            random_state)

        umis_X = np.sqrt(umis_X)
        umis_Y = np.sqrt(umis_Y)

        U, S, V = randomized_svd(umis_X, n_components=args.max_components)
        US = U.dot(np.diag(S))

        for j, k in enumerate(k_range):
            pca_X = US[:, :k].dot(V[:k, :])
            conv_exp = ut.convert_expectations(pca_X, data_split,
                                               data_split_complement)

            rec_loss[i, j] = mean_squared_error(umis_X, pca_X)
            mcv_loss[i, j] = mean_squared_error(umis_Y, conv_exp)
            gt1_loss[i, j] = mean_squared_error(exp_split_means, conv_exp)

    results = {
        "dataset": dataset_name,
        "method": "pca",
        "loss": "mse",
        "normalization": "sqrt",
        "param_range": k_range,
        "rec_loss": rec_loss,
        "mcv_loss": mcv_loss,
        "gt0_loss": gt0_loss,
        "gt1_loss": gt1_loss,
    }

    with open(output_file, "wb") as out:
        pickle.dump(results, out)