Example #1
import numpy as np
from scipy.special import gammaln, gammaincc
from utils import ilogit

def inc_gamma_loss(logr_logitp, x, c):
    # Map the unconstrained parameters to their domains: r > 0, p in (0, 1)
    logr, logitp = logr_logitp
    r = np.exp(logr)
    p = ilogit(logitp)
    # gammaincc(a, z) is the regularized upper incomplete gamma Q(a, z)
    f = (r * np.log(1 - p) + x.sum() * np.log(p) - gammaln(r)
         + np.log(gammaincc(r + x, c / p)).sum())  # alt: gammaln(r + x) + gamma.logcdf(c/p, r+x)
    print(r, p, x, c, f)  # debug trace
    print(np.log(gammaincc(r + x, c / p)))  # per-observation censored terms
    return f  # log-likelihood; negate for minimization
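
A minimal usage sketch (not from the original source): since `inc_gamma_loss` returns a log-likelihood, the parameters can be recovered by minimizing its negation over the unconstrained `(logr, logitp)` vector. The count data `x` and threshold `c` below are hypothetical.

import numpy as np
from scipy.optimize import minimize

x = np.random.poisson(5.0, size=20)  # hypothetical censored count data
c = 1.0                              # hypothetical censoring threshold
res = minimize(lambda params: -inc_gamma_loss(params, x, c), x0=np.zeros(2))
logr_hat, logitp_hat = res.x
print('r =', np.exp(logr_hat), 'p =', 1 / (1 + np.exp(-logitp_hat)))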
Example #2
def test_hrt_discrete():
    from discrete import fit_classifier, MultinomialModel
    from utils import ilogit
    # Generate the ground truth
    N = 500
    X = (np.random.random(size=(N, 4)) <= ilogit(
        (np.random.normal(size=(N, 4)) + np.random.normal(size=(N, 1))) /
        2.)).astype(int)
    true_logits = (np.array([0.5, 1, 1.5])[np.newaxis, :] * X[:, 0:1] +
                   np.array([-2, 1, -0.5])[np.newaxis, :] * X[:, 1:2] +
                   X[:, 0:1] * X[:, 1:2] * np.array([-2, 1, 2])[np.newaxis, :])
    truth = np.exp(true_logits) / np.exp(true_logits).sum(axis=1,
                                                          keepdims=True)
    true_model = MultinomialModel(truth)

    # Sample some observations
    y = true_model.sample()

    # Fit the model
    print('Fitting predictor')
    split = int(np.round(X.shape[0] * 0.8))
    model = fit_classifier(X[:split], y[:split], nepochs=20)

    # Use the negative log-likelihood as the test statistic
    tstat = lambda X_test: -np.log(model.predict(X_test).pmf(y[split:])).mean()

    p_values = np.zeros(4)
    for j in range(X.shape[1]):
        p_values[j] = hrt(j,
                          tstat,
                          X[:split],
                          X[split:],
                          nperms=1000,
                          nbootstraps=10)['p_value']
        print('Feature {}: p={}'.format(j, p_values[j]))
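
For intuition, here is a simplified sketch of the p-value convention a holdout randomization test such as `hrt` uses. Note the real HRT resamples feature j from an estimated conditional distribution X_j | X_-j; this sketch substitutes a plain permutation of the column, which is only equivalent when feature j is independent of the others.

import numpy as np

def permutation_p_value(tstat, X_test, j, nperms=1000):
    # How often does breaking feature j's association give a test statistic
    # at least as small (i.e. as good) as the observed one?
    t_obs = tstat(X_test)
    X_null = X_test.copy()
    hits = 0
    for _ in range(nperms):
        X_null[:, j] = np.random.permutation(X_test[:, j])
        hits += tstat(X_null) <= t_obs
    return (1 + hits) / (1 + nperms)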
Example #3
def generate_synthetic_z(X,
                         nsignals=10,
                         signal_strength=1,
                         alt_strength=-2,
                         max_nonnull=0.4,
                         **kwargs):
    # Generate responses under the constraint that not too many are nonnull
    h = np.ones(X.shape[0])
    while h.mean() > max_nonnull:
        # Sample the true nonnull features
        signal_indices = np.random.choice(X.shape[1],
                                          replace=False,
                                          size=nsignals)

        # Random coefficients with an average signal strength
        beta = np.random.normal(0,
                                signal_strength /
                                np.sqrt(nsignals + nsignals // 2),
                                size=nsignals + nsignals // 2)
        np.set_printoptions(precision=2, suppress=True)  # debug print settings

        # Quadratic interaction function
        logits = (X[:, signal_indices].dot(beta[:nsignals]) +
                  (X[:, signal_indices[:nsignals // 2]] *
                   X[:, signal_indices[nsignals // 2:nsignals // 2 * 2]]).dot(
                       beta[nsignals:]) - 1)

        # Get whether this was a signal or not
        r = np.random.random(size=logits.shape[0])
        h = ilogit(logits) >= r

        # If it was, get the z-score
        z = np.random.normal(alt_strength * h, 1).clip(-10, 10)

    return z, h, signal_indices, beta
Example #4
def generate_synthetic_X(nsamples=1000, nfeatures=100, r=0.25, **kwargs):
    Sigma = np.array([
        np.exp(-np.abs(np.arange(nfeatures) - i)) * r for i in range(nfeatures)
    ])
    logits = np.random.multivariate_normal(np.zeros(nfeatures),
                                           Sigma,
                                           size=nsamples)
    X = (np.random.random(size=logits.shape) <= ilogit(logits)).astype(int)
    return X
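
A short sketch wiring the two generators above together (assumes `ilogit` from `utils` is in scope):

import numpy as np

np.random.seed(0)
X = generate_synthetic_X(nsamples=200, nfeatures=50)
z, h, signal_indices, beta = generate_synthetic_z(X, nsignals=10)
print('nonnull fraction: {:.2f}'.format(h.mean()))
print('true signal features:', np.sort(signal_indices))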
Example #5
def log_likelihood_fn(proposal_beta, idx):
    # Reject any non-monotone dose-response curve outright
    if np.any(proposal_beta[:-1] > proposal_beta[1:]):
        return -np.inf
    # Present, Y, Lam_grid, Lam_weights, and C are closed over from the
    # enclosing scope
    present = Present[idx]
    y = Y[idx][present][:, np.newaxis]
    tau = ilogit(proposal_beta)[present][:, np.newaxis]
    grid = Lam_grid[idx]
    weights = Lam_weights[idx]
    c = C[idx]
    # Marginalize lambda over the quadrature grid, then sum the log-likelihoods
    return np.log((poisson.pmf(y, grid * tau + c) * weights).clip(
        1e-10, np.inf).sum(axis=1)).sum()
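
The closure above marginalizes the latent Poisson rate multiplier over a quadrature grid: p(y | tau) is approximated by sum_g w_g * Poisson(y; lam_g * tau + c). A self-contained numeric sketch with hypothetical values:

import numpy as np
from scipy.stats import gamma, poisson

a, b = 5.0, 10.0  # hypothetical Gamma(shape, scale) prior on lambda
grid = np.linspace(gamma.ppf(1e-3, a, scale=b),
                   gamma.ppf(1 - 1e-3, a, scale=b), 100)
weights = gamma.pdf(grid, a, scale=b)
weights /= weights.sum()  # normalize the quadrature weights

tau, c, y = 0.7, 2.0, 40  # hypothetical survival rate, offset, observed count
marginal = (poisson.pmf(y, grid * tau + c) * weights).sum()
print('approximate marginal likelihood:', marginal)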
Example #6
def test_ess_Sigma():
    import matplotlib.pyplot as plt
    import seaborn as sns
    from utils import monotone_rejection_sampler
    N = 100
    ndoses = 9
    M = np.random.normal(0, 4, size=(N, ndoses))
    M.sort(axis=1)
    A, B, C = np.random.gamma(5, 10, size=N), np.random.gamma(
        1000, 10, size=N), np.random.gamma(10, 10,
                                           size=N)  # b is a scale param
    n_pos_ctrl = 40
    bandwidth, kernel_scale, noise_var = 1., 2., 0.05
    Sigma = np.array([
        kernel_scale * (np.exp(-0.5 *
                               (i - np.arange(ndoses))**2 / bandwidth**2))
        for i in np.arange(ndoses)
    ]) + noise_var * np.eye(ndoses)  # squared exponential kernel
    Beta = np.array([monotone_rejection_sampler(m, Sigma) for m in M])
    Tau = ilogit(Beta)
    Lam_y = np.array(
        [np.random.gamma(a, b, size=ndoses) for a, b in zip(A, B)])
    Lam_r = np.array(
        [np.random.gamma(a, b, size=n_pos_ctrl) for a, b in zip(A, B)])
    Y = np.random.poisson(Tau * Lam_y + C[:, np.newaxis])
    R = np.random.poisson(Lam_r + C[:, np.newaxis])

    # Add some missing dosages to predict
    for i in range(M.shape[0]):
        for j in range(M.shape[1]):
            if np.random.random() < 0.1:
                Y[i, j] = -1

    colors = ['blue', 'orange', 'green']
    for t, color in zip(Tau, colors):
        plt.plot(t, color=color)
    for y, c, r, color in zip(Y, C, R, colors):
        plt.scatter(np.arange(ndoses)[y >= 0],
                    ((y[y >= 0] - c) / (r.mean() - c)).clip(0, 1),
                    color=color)
    plt.show()
    plt.close()

    Beta_samples, Sigma_samples, Loglikelihood_samples = posterior_ess_Sigma(
        Y, M, A, B, C, Sigma=Sigma)

    from utils import pretty_str
    print('Truth:')
    print(pretty_str(Sigma))
    print('')
    print('Bayes estimate:')
    print(pretty_str(Sigma_samples.mean(axis=0)))
    print('Last sample:')
    print(pretty_str(Sigma_samples[-1]))
Example #7
def log_likelihood(z, idx):
    # Accept either a single z vector or a batch of them
    expanded = len(z.shape) == 1
    if expanded:
        z = z[None]
    z = z[..., None]  # room for lambda grid
    # lam_grid, weights, C, and Y are closed over from the enclosing scope
    lam = lam_grid[idx, None, None]  # room for z grid and multiple doses
    c = C[idx, None, None, None]
    w = weights[idx, None, None]
    y = Y[idx, None, :, None]
    # Marginalize lambda over the grid, then sum over doses (nansum skips
    # missing observations)
    result = np.nansum(np.log(
        (poisson.pmf(y, ilogit(z) * lam + c) * w).clip(1e-10, np.inf).sum(axis=-1)),
                       axis=-1)
    if expanded:
        return result[0]
    return result
Example #8
    def train(self, model_fn,
                    bandwidth=2., kernel_scale=0.35, variance=0.02,
                    mvn_train_samples=5, mvn_validate_samples=105,
                    validation_samples=1000,
                    validation_burn=1000,
                    validation_mcmc_samples=1000,
                    validation_thin=1,
                    lr=3e-4, num_epochs=10, batch_size=100,
                    val_pct=0.1, nfolds=5, folds=None,
                    learning_rate_decay=0.9, weight_decay=0.,
                    clip=None, group_lasso_penalty=0.,
                    save_dir='tmp/',
                    checkpoint=False,
                    target_fold=None):
        print('\tFitting model using {} folds and training for {} epochs each'.format(nfolds, num_epochs))
        torch_Y = autograd.Variable(torch.FloatTensor(self.Y), requires_grad=False)
        torch_lam_grid = autograd.Variable(torch.FloatTensor(self.lam_grid), requires_grad=False)
        torch_lam_weights = autograd.Variable(torch.FloatTensor(self.lam_weights), requires_grad=False)
        torch_c = autograd.Variable(torch.FloatTensor(self.c[:,np.newaxis,np.newaxis]), requires_grad=False)
        torch_obs = autograd.Variable(torch.FloatTensor(self.obs_mask), requires_grad=False)
        torch_dose_idxs = [autograd.Variable(torch.LongTensor(
                                np.arange(d+(d**2 - d)//2, (d+1)+((d+1)**2 - (d+1))//2)), requires_grad=False)
                                for d in range(self.ndoses)]

        # Use a fixed squared exponential kernel
        Sigma = np.array([
            kernel_scale * np.exp(-0.5 * (i - np.arange(self.ndoses))**2 / bandwidth**2)
            for i in np.arange(self.ndoses)
        ]) + variance * np.eye(self.ndoses)
        L = np.linalg.cholesky(Sigma)[np.newaxis,np.newaxis,:,:]

        # Use a fixed set of noise draws for validation
        Z = np.random.normal(size=(self.Y_shape[0], mvn_validate_samples, self.ndoses, 1))
        validate_noise = autograd.Variable(torch.FloatTensor(np.matmul(L, Z)[:,:,:,0]), requires_grad=False)

        self.folds = folds if folds is not None else create_folds(self.Y_shape[0], nfolds)
        nfolds = len(self.folds)
        self.fold_validation_indices = []
        self.prior_mu = np.full(self.Y_shape, np.nan, dtype=float)
        self.prior_Sigma = np.zeros((nfolds, self.ndoses, self.ndoses))
        self.train_losses, self.val_losses = np.zeros((nfolds,num_epochs)), np.zeros((nfolds,num_epochs))
        self.epochs_per_fold = np.zeros(nfolds, dtype=int)
        self.models = [None for _ in range(nfolds)]
        for fold_idx, test_indices in enumerate(self.folds):
            # Create train/validate splits
            mask = np.ones(self.Y_shape[0], dtype=bool)
            mask[test_indices] = False
            indices = np.arange(self.Y_shape[0], dtype=int)[mask]
            np.random.shuffle(indices)
            train_cutoff = int(np.round(len(indices)*(1-val_pct)))
            train_indices = indices[:train_cutoff]
            validate_indices = indices[train_cutoff:]
            torch_test_indices = autograd.Variable(torch.LongTensor(test_indices), requires_grad=False)
            self.fold_validation_indices.append(validate_indices)

            # If we are only training one specific fold, skip all the rest
            if target_fold is not None and target_fold != fold_idx:
                continue

            if checkpoint:
                self.load_checkpoint(save_dir, fold_idx)

            if self.models[fold_idx] is None:
                self.models[fold_idx] = model_fn()

            model = self.models[fold_idx]

            # Setup the optimizers
            # optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=0.9)
            optimizer = optim.RMSprop(model.parameters(), lr=lr, weight_decay=weight_decay)
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3)
            for epoch in range(self.epochs_per_fold[fold_idx], num_epochs):
                print('\t\tFold {} Epoch {}'.format(fold_idx+1,epoch+1))
                train_loss = torch.Tensor([0])
                for batch_idx, batch in enumerate(batches(train_indices, batch_size)):
                    if batch_idx % 100 == 0:
                        print('\t\t\tBatch {}'.format(batch_idx))
                        sys.stdout.flush()

                    tidx = autograd.Variable(torch.LongTensor(batch), requires_grad=False)
                    Z = np.random.normal(size=(len(batch), mvn_train_samples, self.ndoses, 1))
                    noise = autograd.Variable(torch.FloatTensor(np.matmul(L, Z)[:,:,:,0]), requires_grad=False)

                    # Set the model to training mode
                    model.train()

                    # Reset the gradient
                    model.zero_grad()

                    # Run the model and get the prior predictions
                    mu = model(batch, tidx)

                    #### Calculate the loss as the negative log-likelihood of the data ####
                    # Get the MVN draw as mu + L.T.dot(Z)
                    beta = mu.view(-1,1,self.ndoses) + noise

                    # Logistic transform on the log-odds prior sample
                    tau = 1 / (1. + (-beta).exp())

                    # Poisson noise model for observations
                    rates = tau[:,:,:,None] * torch_lam_grid[tidx,None,:,:] + torch_c[tidx,None,:,:]
                    likelihoods = torch.distributions.Poisson(rates)

                    # Get log probabilities of the data and filter out the missing observations
                    loss = -(logsumexp(likelihoods.log_prob(torch_Y[tidx][:,None,:,None]) + torch_lam_weights[tidx][:,None,:,:], dim=-1).mean(dim=1) * torch_obs[tidx]).mean()

                    if group_lasso_penalty > 0:
                        loss += group_lasso_penalty * torch.norm(model.cell_line_features.weight, 2, 0).mean()

                    # Update the model
                    loss.backward()
                    if clip is not None:
                        # Clip gradients, then take a plain SGD step
                        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
                        for p in model.parameters():
                            p.data.add_(p.grad.data, alpha=-lr)
                    else:
                        optimizer.step()

                    train_loss += loss.data

                validate_loss = torch.Tensor([0])
                for batch_idx, batch in enumerate(batches(validate_indices, batch_size, shuffle=False)):
                    if batch_idx % 100 == 0:
                        print('\t\t\tValidation Batch {}'.format(batch_idx))
                        sys.stdout.flush()
                    
                    tidx = autograd.Variable(torch.LongTensor(batch), requires_grad=False)
                    noise = validate_noise[tidx]

                    # Set the model to evaluation mode
                    model.eval()

                    # Reset the gradient
                    model.zero_grad()

                    # Run the model and get the prior predictions
                    mu = model(batch, tidx)

                    #### Calculate the loss as the negative log-likelihood of the data ####
                    # Get the MVN draw as mu + L.T.dot(Z)
                    beta = mu.view(-1,1,self.ndoses) + noise

                    # Logistic transform on the log-odds prior sample
                    tau = 1 / (1. + (-beta).exp())

                    # Poisson noise model for observations
                    rates = tau[:,:,:,None] * torch_lam_grid[tidx,None,:,:] + torch_c[tidx,None,:,:]
                    likelihoods = torch.distributions.Poisson(rates)

                    # Get log probabilities of the data and filter out the missing observations
                    loss = -(logsumexp(likelihoods.log_prob(torch_Y[tidx][:,None,:,None]) + torch_lam_weights[tidx][:,None,:,:], dim=-1).mean(dim=1) * torch_obs[tidx]).sum()
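                    # Note: validation accumulates a sum here (the training
                    # loop above uses a per-batch mean); both totals are
                    # normalized after the epoch finishes.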

                    validate_loss += loss.data

                self.train_losses[fold_idx, epoch] = train_loss.numpy() / float(len(train_indices))
                self.val_losses[fold_idx, epoch] = validate_loss.numpy() / float(len(validate_indices))

                # Adjust the learning rate down if the validation performance is bad
                scheduler.step(self.val_losses[fold_idx, epoch])

                # Check if we currently have the best held-out log-likelihood
                if epoch == 0 or np.argmin(self.val_losses[fold_idx, :epoch+1]) == epoch:
                    print('\t\t\tNew best score: {}'.format(self.val_losses[fold_idx,epoch]))
                    print('\t\t\tSaving test set results.')
                    # If so, use the current model on the test set
                    mu = model(test_indices, torch_test_indices)
                    self.prior_mu[test_indices] = mu.data.numpy()
                    self.save_fold(save_dir, fold_idx)
                
                cur_mu = self.prior_mu[test_indices]
                print('First 10 data points: {}'.format(test_indices[:10]))
                print('First 10 prior means:')
                print(pretty_str(ilogit(cur_mu[:10])))
                print('Prior mean ranges:')
                for dose in range(self.ndoses):
                    print('{}: {} [{}, {}]'.format(dose,
                                                   ilogit(cur_mu[:,dose].mean()),
                                                   np.percentile(ilogit(cur_mu[:,dose]), 5),
                                                   np.percentile(ilogit(cur_mu[:,dose]), 95)))
                print('Best model score: {} (epoch {})'.format(np.min(self.val_losses[fold_idx,:epoch+1]), np.argmin(self.val_losses[fold_idx, :epoch+1])+1))
                print('Current score: {}'.format(self.val_losses[fold_idx, epoch]))
                print('')

                self.epochs_per_fold[fold_idx] += 1
                
                # Update the save point if needed
                if checkpoint:
                    self.save_checkpoint(save_dir, fold_idx, model)
                    sys.stdout.flush()
                
            
            # Reload the best model
            tmp = model.cell_features
            self.load_fold(save_dir, fold_idx)
            self.models[fold_idx].cell_features = tmp

            print('Finished fold {}. Estimating covariance matrix using elliptical slice sampler with max {} samples.'.format(fold_idx+1, validation_samples))
            validate_subset = np.random.choice(validate_indices, validation_samples, replace=False) if len(validate_indices) > validation_samples else validate_indices
            tidx = autograd.Variable(torch.LongTensor(validate_subset), requires_grad=False)
                        
            # Set the model to evaluation mode
            self.models[fold_idx].eval()

            # Reset the gradient
            self.models[fold_idx].zero_grad()

            # Run the model and get the prior predictions
            mu_validate = self.models[fold_idx](validate_subset, tidx).data.numpy()
            
            # Run the slice sampler to get the covariance and data log-likelihoods
            Y_validate = self.Y[validate_subset].astype(int)
            Y_validate[self.obs_mask[validate_subset] == 0] = -1
            (Beta_samples,
                Sigma_samples,
                Loglikelihood_samples) = posterior_ess_Sigma(Y_validate,
                                                             mu_validate,
                                                             self.a[validate_subset],
                                                             self.b[validate_subset],
                                                             self.c[validate_subset],
                                                             Sigma=Sigma,
                                                             nburn=validation_burn,
                                                             nsamples=validation_mcmc_samples,
                                                             nthin=validation_thin,
                                                             print_freq=1)

            # Save the result
            self.prior_Sigma[fold_idx] = Sigma_samples.mean(axis=0)
            print('Last sample:')
            print(pretty_str(Sigma_samples[-1]))
            print('Mean:')
            print(pretty_str(self.prior_Sigma[fold_idx]))

            if checkpoint:
                self.clean_checkpoint(save_dir, fold_idx)

        print('Finished training.')
        
        return {'train_losses': self.train_losses,
                'validation_losses': self.val_losses,
                'mu': self.prior_mu,
                'Sigma': self.prior_Sigma,
                'models': self.models}
Example #9
    ebo = create_predictive_model(model_save_path, **dargs)

    # Load the binarized features and factorization
    print('Loading binarized features')
    df_binarized = pd.read_csv(args.genomic_features.replace(
        '.csv', '_binarized.csv'),
                               header=0,
                               index_col=0)
    W = np.load(
        args.genomic_features.replace('.csv', '_binarized_row_loading.npy'))
    V = np.load(
        args.genomic_features.replace('.csv', '_binarized_col_loading.npy'))

    # Get the features and probs
    X = df_binarized.values.astype(bool)
    X_probs = ilogit(W.dot(V.T))

    # Filter down to the cell lines we have features for
    indices = [
        i for i in range(X.shape[0]) if df_binarized.index[i] in ebo.cell_lines
    ]
    X = X[indices]
    X_probs = X_probs[indices]

    # Filter down further to the cell lines we have responses with this drug
    indices = np.arange(ebo.Y.shape[0])[np.any(
        ebo.obs_mask[:, args.drug].astype(bool), axis=1)]
    X = X[indices]
    X_probs = X_probs[indices]

    # Load the posteriors for this drug
Example #10
def test_posterior_ess():
    import matplotlib.pyplot as plt
    import seaborn as sns
    from utils import monotone_rejection_sampler
    M = np.array([[-3, -2, -0.4, 0, 1, 1.1, 1.8, 3, 4],
                  [-7, -3, -0.1, 1.2, 1.5, 2.5, 3., 3.9, 4],
                  [-1, -0.5, 0, 1., 1.25, 2.5, 3.8, 3.9, 4]])
    N = M.shape[0]
    ndoses = M.shape[1]
    A, B, C = np.random.gamma(5, 10, size=N), np.random.gamma(
        1000, 10, size=N), np.random.gamma(10, 10,
                                           size=N)  # b is a scale param
    n_pos_ctrl = 40
    bandwidth, kernel_scale, noise_var = 1., 2., 0.05
    Sigma = np.array([
        kernel_scale * (np.exp(-0.5 *
                               (i - np.arange(ndoses))**2 / bandwidth**2))
        for i in np.arange(ndoses)
    ]) + noise_var * np.eye(ndoses)  # squared exponential kernel
    Beta = np.array([monotone_rejection_sampler(m, Sigma) for m in M])
    Tau = ilogit(Beta)
    Lam_y = np.array(
        [np.random.gamma(a, b, size=ndoses) for a, b in zip(A, B)])
    Lam_r = np.array(
        [np.random.gamma(a, b, size=n_pos_ctrl) for a, b in zip(A, B)])
    Y = np.random.poisson(Tau * Lam_y + C[:, np.newaxis])
    R = np.random.poisson(Lam_r + C[:, np.newaxis])

    # Add some missing dosages to predict
    for i in range(M.shape[0]):
        for j in range(M.shape[1]):
            if np.random.random() < 0.1:
                Y[i, j] = -1

    colors = ['blue', 'orange', 'green']
    for t, color in zip(Tau, colors):
        plt.plot(t, color=color)
    for y, r, c, color in zip(Y, R, C, colors):
        plt.scatter(np.arange(M.shape[1])[y >= 0],
                    ((y[y >= 0] - c) / (r.mean() - c)).clip(0, 1),
                    color=color)
    plt.show()
    plt.close()

    Beta_hat = posterior_ess(Y, M, Sigma, A, B, C)
    Tau_hat = ilogit(Beta_hat)

    Beta_hat2 = posterior_ess(Y, M, Sigma, A, B, C)
    Tau_hat2 = ilogit(Beta_hat2)

    Beta_hat3 = posterior_ess(Y, M, Sigma, A, B, C)
    Tau_hat3 = ilogit(Beta_hat3)

    with sns.axes_style('white', {'legend.frameon': True}):
        plt.rc('font', weight='bold')
        plt.rc('grid', lw=3)
        plt.rc('lines', lw=2)
        plt.rc('axes', lw=2)

        colors = ['blue', 'orange', 'green']
        fig, axarr = plt.subplots(1, 3, sharex=True, sharey=True)
        for ax, y, t, t_hat, t_hat2, t_hat3, t_lower, t_upper, r, c, color in zip(
                axarr, Y, Tau, Tau_hat.mean(axis=0), Tau_hat2.mean(axis=0),
                Tau_hat3.mean(axis=0), np.percentile(Tau_hat, 5, axis=0),
                np.percentile(Tau_hat, 95, axis=0), R, C, colors):
            ax.scatter(np.arange(M.shape[1])[y >= 0],
                       ((y[y >= 0] - c) / (r.mean() - c)).clip(0, 1),
                       color=color)
            ax.plot(np.arange(M.shape[1]), t, color=color, lw=3, ls='--')
            ax.plot(np.arange(M.shape[1]), t_hat, color=color, lw=3)
            ax.plot(np.arange(M.shape[1]), t_hat2, color=color, lw=3)
            ax.plot(np.arange(M.shape[1]), t_hat3, color=color, lw=3)
            ax.fill_between(np.arange(M.shape[1]),
                            t_lower,
                            t_upper,
                            color=color,
                            alpha=0.5)
            ax.set_xlabel('Dosage level', fontsize=18, weight='bold')
            ax.set_ylabel('Survival percentage', fontsize=18, weight='bold')
    plt.show()
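
For reference, a minimal sketch of a single elliptical slice sampling update (Murray, Adams and MacKay, 2010), the kind of step `posterior_ess` presumably iterates. Here `beta` is the current state expressed as an offset from the prior mean, `L_chol` is the Cholesky factor of Sigma, and `log_lik` is a closure like the `log_likelihood_fn` in the next example.

import numpy as np

def ess_step(beta, L_chol, log_lik):
    # Auxiliary draw from the zero-mean Gaussian prior
    nu = L_chol.dot(np.random.normal(size=beta.shape))
    # Log slice height under the current state
    log_u = log_lik(beta) + np.log(np.random.random())
    # Initial angle and shrinking bracket
    theta = np.random.uniform(0, 2 * np.pi)
    lo, hi = theta - 2 * np.pi, theta
    while True:
        proposal = beta * np.cos(theta) + nu * np.sin(theta)
        if log_lik(proposal) > log_u:
            return proposal
        # Shrink the bracket toward theta = 0 (the current state) and retry
        if theta < 0:
            lo = theta
        else:
            hi = theta
        theta = np.random.uniform(lo, hi)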
Example #11
def log_likelihood_fn(proposal_beta, dummy):
    # Reject any non-monotone dose-response curve outright
    if np.any(proposal_beta[:-1] > proposal_beta[1:]):
        return -np.inf
    # present, y, grid, weights, and c are closed over from the enclosing scope
    tau = ilogit(proposal_beta)[present][:, np.newaxis]
    return np.log((poisson.pmf(y, grid * tau + c) * weights).clip(
        1e-10, np.inf).sum(axis=1)).sum()
Example #12
def run():
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pylab as plt
    import os
    import argparse
    parser = argparse.ArgumentParser(
        description=
        'Estimate the dose-response covariance matrix on a per-drug basis.')

    # Experiment settings
    parser.add_argument(
        'name',
        default='gdsc',
        help='The project name. Will be prepended to plots and saved files.')
    parser.add_argument(
        '--drug',
        type=int,
        help=
        'If specified, fits only on a specific drug. This is useful for parallel/distributed training.'
    )
    parser.add_argument('--drug_responses',
                        default='data/raw_step3.csv',
                        help='The dataset file with all of the experiments.')
    parser.add_argument('--genomic_features',
                        default='data/gdsc_all_features.csv',
                        help='The file with the cell line features.')
    parser.add_argument(
        '--drug_details',
        default='data/gdsc_drug_details.csv',
        help=
        'The data file with all of the drug information (names, targets, etc).'
    )
    parser.add_argument('--plot_path',
                        default='plots',
                        help='The path where plots will be saved.')
    parser.add_argument('--save_path',
                        default='data',
                        help='The path where data and models will be saved.')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help='The pseudo-random number generator seed.')
    parser.add_argument(
        '--torch_threads',
        type=int,
        default=1,
        help='The number of threads that pytorch can use in a fold.')
    parser.add_argument('--no_fix',
                        action='store_true',
                        default=False,
                        help='Do not correct the dosages.')
    parser.add_argument('--verbose',
                        action='store_true',
                        help='If specified, prints progress to terminal.')
    parser.add_argument('--nburn',
                        type=int,
                        default=500,
                        help='Number of MCMC burn-in steps.')
    parser.add_argument('--nsamples',
                        type=int,
                        default=1500,
                        help='Number of MCMC steps to use.')
    parser.add_argument('--nthin',
                        type=int,
                        default=1,
                        help='Number of MCMC steps between sample steps.')
    parser.add_argument('--diagnostic',
                        action='store_true',
                        default=False,
                        help='Run a diagnostic setup to check convergence.')
    parser.add_argument('--ntrace',
                        type=int,
                        default=5,
                        help='Number of curves to show in the diagnostic trace plot.')
    parser.add_argument('--ndiag',
                        type=int,
                        default=30,
                        help='Number of experiments to subsample for diagnostics.')

    # Get the arguments from the command line
    args = parser.parse_args()
    dargs = vars(args)

    # Seed the random number generators so we get reproducible results
    np.random.seed(args.seed)

    print('Running posterior sampler with args:')
    print(args)
    print('Working on project: {}'.format(args.name))

    # Create the model directory
    model_save_path = os.path.join(args.save_path, args.name)
    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)

    # Load the predictor
    ebo = create_predictive_model(model_save_path, **dargs)
    ebo.load()

    # Generate MCMC diagnostics instead of saving results
    if args.diagnostic:
        diagnostics(args, dargs, ebo)
        return

    # Fit the posterior via MCMC
    drug_idx = args.drug
    indices, Beta, Sigma, loglike = beta_mcmc(ebo, drug_idx, **dargs)

    # Calculate the posterior AUC scores
    Tau = ilogit(Beta)
    AUC = (Tau.sum(axis=2) -
           0.5 * Tau[:, :, [0, -1]].sum(axis=2)) / (Tau.shape[2] - 1)
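    # (The AUC above is the trapezoid rule over equally spaced dose levels,
    # rescaled to [0, 1]; it equals np.trapz(Tau, axis=2) / (Tau.shape[2] - 1).)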

    posteriors_path = os.path.join(model_save_path, 'posteriors')
    if not os.path.exists(posteriors_path):
        os.makedirs(posteriors_path)
    np.save(os.path.join(posteriors_path, 'betas{}'.format(drug_idx)), Beta)
    np.save(os.path.join(posteriors_path, 'sigmas{}'.format(drug_idx)), Sigma)
    np.save(os.path.join(posteriors_path, 'taus{}'.format(drug_idx)), Tau)
    np.save(os.path.join(posteriors_path, 'aucs{}'.format(drug_idx)), AUC)

    ### Plot examples
    import matplotlib.pyplot as plt
    import seaborn as sns
    # Get the offsets and grids
    lam_grid = ebo.lam_grid[indices, drug_idx]
    weights = gamma.pdf(lam_grid,
                        ebo.A[indices, drug_idx, None],
                        scale=ebo.B[indices, drug_idx,
                                    None])  #.clip(1e-10, np.inf)
    weights /= weights.sum(axis=-1, keepdims=True)
    Y = ebo.Y[indices, drug_idx]
    C = ebo.C[indices, drug_idx]

    # Get the empirical Bayes predicted mean and back out the logits
    # tau_hat = ebo.mu[indices, drug_idx].clip(1e-4, 1-1e-4)
    tau_hat = ebo.predict_mu(ebo.X[indices])[:, drug_idx].clip(1e-4, 1 - 1e-4)
    Mu = np.log(tau_hat / (1 - tau_hat))

    Tau_unclipped = ((Y - C[..., None]) /
                     lam_grid[..., lam_grid.shape[-1] // 2, None])
    # for idx in range(30):
    #     plt.scatter(np.arange(Y.shape[1])[::-1], Tau_unclipped[idx], color='gray', label='Observed')
    #     plt.plot(np.arange(Y.shape[1])[::-1], ilogit(Mu[idx]), color='orange', label='Prior')
    #     plt.plot(np.arange(Y.shape[1])[::-1], ilogit(Beta[:,idx].mean(axis=0)), color='blue', label='Posterior')
    #     plt.fill_between(np.arange(Y.shape[1])[::-1],
    #                      ilogit(np.percentile(Beta[:,idx], 5, axis=0)),
    #                      ilogit(np.percentile(Beta[:,idx], 95, axis=0)),
    #                      alpha=0.3, color='blue')
    #     plt.legend(loc='lower left')
    #     plt.savefig('plots/posteriors-drug{}-sample{}.pdf'.format(drug_idx, idx), bbox_inches='tight')
    #     plt.close()

    # Fix bugs -- done and saved
    # df_sanger['DRUG_NAME'] = df_sanger['DRUG_NAME'].str.strip()
    # df_sanger[df_sanger['DRUG_NAME'] == 'BX-795'] = 'BX-796'
    # df_sanger[df_sanger['DRUG_NAME'] == 'SB505124'] = 'SB-505124'
    # df_sanger[df_sanger['DRUG_NAME'] == 'Lestaurtinib'] = 'Lestauritinib'

    # Get all the Sanger-processed AUC scores in a way we can handle it
    sanger_auc_path = os.path.join(args.save_path, 'sanger_auc.npy')
    if not os.path.exists(sanger_auc_path):
        import pandas as pd
        from collections import defaultdict
        df_sanger = pd.read_csv(os.path.join(args.save_path, 'gdsc_auc.csv'),
                                header=0,
                                index_col=0)
        cell_map, drug_map = defaultdict(lambda: -1), defaultdict(lambda: -1)
        for idx, c in enumerate(ebo.cell_lines):
            cell_map[c] = idx
        for idx, d in enumerate(ebo.drugs):
            drug_map[d] = idx
        AUC_sanger = np.full(ebo.Y.shape[:2], np.nan)
        for idx, row in df_sanger.iterrows():
            cidx, didx = cell_map[row['CELL_LINE_NAME']], drug_map[
                row['DRUG_NAME']]
            if cidx == -1 or didx == -1:
                continue
            AUC_sanger[cidx, didx] = row['AUC']
        np.save(sanger_auc_path, AUC_sanger)
    else:
        AUC_sanger = np.load(sanger_auc_path)

    import seaborn as sns
    with sns.axes_style('white', {'legend.frameon': True}):
        plt.rc('font', weight='bold')
        plt.rc('grid', lw=3)
        plt.rc('lines', lw=3)
        matplotlib.rcParams['pdf.fonttype'] = 42
        matplotlib.rcParams['ps.fonttype'] = 42
        overlap = ~np.isnan(AUC_sanger[indices, drug_idx])
        x = AUC_sanger[indices[overlap], drug_idx]
        y = AUC[:, overlap].mean(axis=0)
        plt.scatter(x, y, s=4)
        plt.plot([min(x.min(), y.min()), 1], [min(x.min(), y.min()), 1],
                 color='red',
                 lw=2)
        plt.xlabel('Original AUC', fontsize=18)
        plt.ylabel('Bayesian AUC', fontsize=18)
        plt.savefig('plots/auc-compare{}.pdf'.format(drug_idx),
                    bbox_inches='tight')
        plt.close()
Example #13
def diagnostics(args, dargs, ebo):
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pylab as plt

    # Get the index of the target drug
    drug_idx = args.drug

    # Save all samples including burn-in and thinning
    nsamples, nthin, nburn = args.nsamples, args.nthin, args.nburn
    args.nsamples = args.nsamples * args.nthin + args.nburn
    args.nburn = 0
    args.nthin = 1

    # Filter down the drugs to a few randomly chosen subsets
    indices = np.random.choice(np.arange(ebo.Y.shape[0])[np.any(
        ebo.obs_mask[:, drug_idx].astype(bool), axis=1)],
                               size=args.ndiag,
                               replace=False)
    mask = np.zeros(ebo.obs_mask.shape)
    mask[indices, drug_idx] = ebo.obs_mask[indices, drug_idx]
    ebo.obs_mask = mask

    # Fit the posterior via MCMC
    indices, Beta, Sigma, loglike = beta_mcmc(ebo, drug_idx, **dargs)

    # Get all the different survival rates
    Tau = ilogit(Beta)

    # Simple trace plot
    import seaborn as sns
    with sns.axes_style('white', {'legend.frameon': True}):
        plt.rc('font', weight='bold')
        plt.rc('grid', lw=3)
        plt.rc('lines', lw=1)
        matplotlib.rcParams['pdf.fonttype'] = 42
        matplotlib.rcParams['ps.fonttype'] = 42
        plt.plot(Tau[:, :args.ntrace, 0].reshape((Tau.shape[0], -1)))
        plt.xlabel('MCMC iteration', fontsize=18)
        plt.ylabel('Dose-response values ($\\tau$)', fontsize=18)
        plt.savefig(f'{args.plot_path}/trace-{args.drug}.pdf',
                    bbox_inches='tight')
        plt.close()

    # Filter Tau using the burn-in and thinning settings
    Tau = Tau[nburn:]
    Tau = Tau[::nthin]

    # Coverage stats
    credible_intervals = np.array([.50, .75, .85, .90, .95, .99])

    # Get the offsets and grids
    Y = ebo.Y[indices, drug_idx]
    C = ebo.C[indices, drug_idx]
    lams = gamma.rvs(ebo.A[indices, drug_idx],
                     scale=ebo.B[indices, drug_idx],
                     size=(100, Tau.shape[0]) + ebo.A[indices, drug_idx].shape)
    print(Y.shape, C.shape, lams.shape, Tau.shape)
    # Calculate the upper and lower credible interval bands via MC
    Y_samples = poisson.rvs(C[None, None, :, None] + lams[..., None] *
                            Tau[None]).reshape((-1, ) + Tau.shape[-2:])
    Y_upper = np.zeros((len(credible_intervals), ) + Y.shape)
    Y_lower = np.zeros((len(credible_intervals), ) + Y.shape)
    print('Y samples', Y_samples.shape, 'Y_upper', Y_upper.shape)
    for ci_idx, interval in enumerate(credible_intervals):
        Y_upper[ci_idx] = np.percentile(Y_samples,
                                        100 - (1 - interval) / 2 * 100,
                                        axis=0)
        Y_lower[ci_idx] = np.percentile(Y_samples, (1 - interval) / 2 * 100,
                                        axis=0)

    # Check for coverage rates
    coverage = np.array([
        np.nanmean((Y_lower[i] <= Y) & (Y_upper[i] >= Y))
        for i in range(len(credible_intervals))
    ])
    print(coverage)

    import seaborn as sns
    with sns.axes_style('white', {'legend.frameon': True}):
        plt.rc('font', weight='bold')
        plt.rc('grid', lw=3)
        plt.rc('lines', lw=3)
        matplotlib.rcParams['pdf.fonttype'] = 42
        matplotlib.rcParams['ps.fonttype'] = 42
        plt.plot(credible_intervals * 100, coverage * 100, color='blue')
        plt.plot(credible_intervals * 100,
                 credible_intervals * 100,
                 color='black')
        plt.xlabel('Posterior credible interval', fontsize=18)
        plt.ylabel('Coverage', fontsize=18)
        plt.savefig(f'{args.plot_path}/coverage-{args.drug}.pdf',
                    bbox_inches='tight')
        plt.close()