Ejemplo n.º 1
0
def feature_importance(crispr_model, x_len, xs, y, unique_test_index):
    """Rank input features by perturbation importance and persist the result.

    Builds one name per forward-sequence position ("for_<i>") followed by the
    extra categorical/numerical feature names from the config, runs an input
    perturbation ranking of the trained model on the unique test rows, and
    writes the resulting table to ``config.feature_importance_path`` as CSV.
    """
    logger.debug("Getting features ranks")
    # Positional names for the forward sequence come first, then the extras.
    names = ["for_" + str(pos) for pos in range(x_len)]
    names.extend(config.extra_categorical_features)
    names.extend(config.extra_numerical_features)
    ranker = feature_imp.InputPerturbationRank(names)
    subset_inputs = [block[unique_test_index] for block in xs]
    feature_ranks = ranker.rank(20, y[unique_test_index], crispr_model,
                                subset_inputs)
    # Persist the ranking so it can be inspected outside this run.
    pd.DataFrame(feature_ranks).to_csv(config.feature_importance_path,
                                       index=False)
    logger.debug("Get features ranks successfully")
Ejemplo n.º 2
0

    # NOTE(review): fragment — the matching `if` branch for this `elif` lies
    # above the visible region; code below is kept byte-identical.
    elif attention_setting.output_FF_layers[-1] == 1:
        # A single final output unit => train the model as a regressor.
        crispr_model = attention_model.get_OT_model(data_pre.feature_length_map)
        regressor_training(crispr_model, X, y, cv_splitter)
    else:
        # Otherwise build a classifier head and train on the binarized labels.
        crispr_model_classifier = attention_model.get_OT_model(data_pre.feature_length_map, classifier=True)
        best_crispr_model, train_index = classifier_training(crispr_model_classifier, X, y_binary, cv_splitter)
    if config.check_feature_importance:

        logger.debug("Getting features ranks")
        # Feature names: one per source-sequence position, then (optionally)
        # one per target-sequence position, then the configured extras.
        names = []
        names += ["src_" + str(i) for i in range(data_pre.feature_length_map[0][1])]
        if data_pre.feature_length_map[1] is not None: names += ["trg_" + str(i)
                                                                 for i in range(
                data_pre.feature_length_map[1][1] - data_pre.feature_length_map[1][0])]
        if data_pre.feature_length_map[
            2] is not None: names += config.extra_categorical_features + config.extra_numerical_features
        ranker = feature_imp.InputPerturbationRank(names)
        # Rank with 2 perturbation repeats over the training rows.
        # NOTE(review): `best_crispr_model`/`train_index` are only bound in the
        # `else` branch above — confirm this path is unreachable after the
        # regression branch.
        feature_ranks = ranker.rank(2, y_binary[train_index], best_crispr_model,
                                    [X[train_index, :]])
        feature_ranks_df = pd.DataFrame(feature_ranks)
        feature_ranks_df.to_csv(config.feature_importance_path, index=False)
        logger.debug("Get features ranks successfully")






Ejemplo n.º 3
0
def run():
    """End-to-end training pipeline for the attention-based CRISPR model.

    Prepares features and labels, splits and deduplicates train/test rows,
    optionally scales the targets, trains the model while tracking the best
    epoch by validation Spearman correlation, persists the best model, runs
    the test-set evaluation and, if configured, computes feature-importance
    ranks via input perturbation.
    """
    data_pre = OT_crispr_attn.data_preparer()
    print(data_pre.get_crispr_preview())
    X = data_pre.prepare_x()
    y = data_pre.get_labels()
    data_pre.persist_data()
    print(X.head())
    print(data_pre.feature_length_map)
    train_index, test_index = data_pre.train_test_split()
    if config.log2fc_filter:
        # Clip extreme log2 fold-change targets to [-10, 10].
        # NOTE(review): the boolean masks are computed on the training rows
        # only, but assigned into the whole frame `y` — confirm intended.
        train_filter_1 = (y.loc[train_index, :] > 10)
        y[train_filter_1] = 10
        train_filter_2 = (y.loc[train_index, :] < -10)
        y[train_filter_2] = -10
        # train_filter = (y.loc[train_index, :] > 10)
        # train_index = train_index[train_filter.iloc[:,0]]
    print(train_index, test_index)
    # Fixed seed for reproducible weight init / shuffling on the torch side.
    torch.manual_seed(0)
    # X = X.values.astype(np.float32)
    # y = y.values.astype(np.float32)
    # y_binary = (y > np.quantile(y, 0.8)).astype(int).reshape(-1, )

    #y[y.loc[train_index, config.y] < 0] = 0
    # essen_filter = list(data_pre.crispr[data_pre.crispr['log2fc'] > 0].index)
    # train_index = list(set(train_index).intersection(essen_filter))
    # test_index = list(set(test_index).intersection(essen_filter))
    std_scaler = StandardScaler()
    m_m = MinMaxScaler((0, 100))
    if config.y_transform:
        # Fit both scalers on training rows only, then transform the whole
        # target frame: standardize, scale by 100, and min-max into [0, 100].
        std_scaler.fit(y.loc[train_index, :])
        new_y = std_scaler.transform(y) * 100
        y = pd.DataFrame(new_y, columns=y.columns, index=y.index)
        m_m.fit(y.loc[train_index, :])
        new_y = m_m.transform(y)
        y = pd.DataFrame(new_y, columns=y.columns, index=y.index)

    # Optionally restrict test/train rows to configured cell lines.
    if config.test_cellline is not None:
        test_cellline_index = data_pre.crispr[data_pre.crispr['cellline'].isin(
            config.test_cellline)].index
        test_index = test_cellline_index.intersection(test_index)

    if config.train_cellline is not None:
        train_cellline_index = data_pre.crispr[
            data_pre.crispr['cellline'].isin(config.train_cellline)].index
        train_index = train_cellline_index.intersection(train_index)

    logger.debug("training data amounts: %s, testing data amounts: %s" %
                 (len(train_index), len(test_index)))
    x_train, x_test, y_train, y_test = \
        X.loc[train_index, :], X.loc[test_index, :], \
        y.loc[train_index, :], y.loc[test_index, :]

    # Deduplicate (X, y) rows; np.unique returns positions of first
    # occurrences, which are then used to subset the index labels.
    _, unique_train_index = np.unique(pd.concat([x_train, y_train], axis=1),
                                      return_index=True,
                                      axis=0)
    _, unique_test_index = np.unique(pd.concat([x_test, y_test], axis=1),
                                     return_index=True,
                                     axis=0)
    logger.debug(
        "after deduplication, training data amounts: %s, testing data amounts: %s"
        % (len(unique_train_index), len(unique_test_index)))
    train_index = train_index[unique_train_index]
    test_index = test_index[unique_test_index]
    x_concat = pd.concat([X, y], axis=1)
    _, unique_index = np.unique(x_concat, return_index=True, axis=0)
    logger.debug("{0!r}, {1!r}".format(
        (len(x_concat.loc[train_index, :].drop_duplicates())),
        str(len(x_concat.loc[train_index, :]))))
    logger.debug("Splitted dataset successfully")

    # Shuffle and split training rows 90/10.
    # NOTE(review): `train_index_list` / `eval_index_list` are computed but
    # never used — the partition below trains on the full shuffled list and
    # evaluates on the test rows. Confirm that is intended.
    train_eval_index_list = list(train_index)
    random.shuffle(train_eval_index_list)
    sep = int(len(train_eval_index_list) * 0.9)
    train_index_list = train_eval_index_list[:sep]
    eval_index_list = train_eval_index_list[sep:]
    partition = {
        'train': train_eval_index_list,
        'eval': list(test_index),
        'test': list(test_index)
    }

    # NOTE(review): labels are keyed by position 0..len(y)-1 while the
    # partition lists hold index labels — verify MyDataset expects positions.
    labels = {
        key: value
        for key, value in zip(range(len(y)), list(y.values.reshape(-1)))
    }

    train_params = {'batch_size': config.batch_size, 'shuffle': True}
    # Eval/test run as a single full-size batch each.
    eval_params = {'batch_size': len(test_index), 'shuffle': False}
    test_params = {'batch_size': len(test_index), 'shuffle': False}

    logger.debug("Preparing datasets ... ")
    training_set = my_data.MyDataset(partition['train'], labels)
    training_generator = data.DataLoader(training_set, **train_params)

    # Background generator iterates the training set in ~6 large, unshuffled
    # batches for the per-epoch training-set evaluation below.
    train_bg_params = {
        'batch_size': len(train_eval_index_list) // 6 + 5,
        'shuffle': False
    }
    training_bg_set = my_data.MyDataset(partition['train'], labels)
    training_bg_generator = data.DataLoader(training_bg_set, **train_bg_params)

    validation_set = my_data.MyDataset(partition['eval'], labels)
    validation_generator = data.DataLoader(validation_set, **eval_params)

    test_set = my_data.MyDataset(partition['test'], labels)
    test_generator = data.DataLoader(test_set, **test_params)

    logger.debug("I might need to augment data")

    logger.debug("Building the scaled dot product attention model")
    # Input widths derived from the (start, end) spans in feature_length_map.
    for_input_len = data_pre.feature_length_map[0][
        1] - data_pre.feature_length_map[0][0]
    extra_input_len = 0 if not data_pre.feature_length_map[2] \
        else data_pre.feature_length_map[2][1] - data_pre.feature_length_map[2][0]

    # `best_crispr_model` receives the state dict of the best epoch below.
    crispr_model = attention_model.get_OT_model(data_pre.feature_length_map)
    best_crispr_model = attention_model.get_OT_model(
        data_pre.feature_length_map)
    crispr_model.to(device2)
    best_crispr_model.to(device2)
    # crispr_model = attention_model.get_model(d_input=for_input_len)
    # best_crispr_model = attention_model.get_model(d_input=for_input_len)
    # #crispr_model = attention_model.get_OT_model(data_pre.feature_length_map, classifier=True)
    # crispr_model.to(device2)
    # best_crispr_model.to(device2)
    best_cv_spearman_score = 0

    if config.retraining:

        # Warm-start from a previously persisted model + state dict.
        logger.debug("I need to load a old trained model")
        crispr_model = load(config.retraining_model)
        crispr_model.load_state_dict(load(config.retraining_model_state))
        crispr_model.to(device2)
        logger.debug("I might need to freeze some of the weights")

    logger.debug("Built the Crispr model successfully")

    # Adam with the transformer-style betas/eps; weight_decay comes from the
    # config's lr_decay setting.
    optimizer = torch.optim.Adam(crispr_model.parameters(),
                                 lr=config.start_lr,
                                 weight_decay=config.lr_decay,
                                 betas=(0.9, 0.98),
                                 eps=1e-9)

    try:
        if config.training:

            logger.debug("Training the model")
            mse_visualizer = torch_visual.VisTorch(env_name='MSE')
            pearson_visualizer = torch_visual.VisTorch(env_name='Pearson')
            spearman_visualizer = torch_visual.VisTorch(env_name='Spearman')

            for epoch in range(config.n_epochs):

                start = time()
                cur_epoch_train_loss = []
                train_total_loss = 0
                i = 0

                # Training
                for local_batch, local_labels in training_generator:
                    i += 1
                    # Transfer to GPU
                    local_batch, local_labels = local_batch.float().to(
                        device2), local_labels.float().to(device2)
                    # seq_local_batch = local_batch.narrow(dim=1, start=0, length=for_input_len).long()
                    # extra_local_batch = local_batch.narrow(dim=1, start=for_input_len, length=extra_input_len)
                    # Model computations
                    preds = crispr_model(local_batch).contiguous().view(-1)
                    ys = local_labels.contiguous().view(-1)
                    optimizer.zero_grad()
                    assert preds.size(-1) == ys.size(-1)
                    #loss = F.nll_loss(preds, ys)
                    # NOTE(review): train() is toggled after the forward pass;
                    # dropout/batch-norm behavior of the forward above follows
                    # the previous mode — confirm intended.
                    crispr_model.train()
                    loss = F.mse_loss(preds, ys)
                    loss.backward()
                    optimizer.step()

                    train_total_loss += loss.item()

                    # Log a progress bar and the running loss every 2 batches.
                    n_iter = 2
                    if i % n_iter == 0:
                        sample_size = len(train_index)
                        p = int(100 * i * config.batch_size / sample_size)
                        avg_loss = train_total_loss / n_iter
                        if config.y_inverse_transform:
                            avg_loss = \
                            std_scaler.inverse_transform(np.array(avg_loss / 100).reshape(-1, 1)).reshape(-1)[0]
                        logger.debug("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" % \
                                                 ((time() - start) // 60, epoch, "".join('#' * (p // 5)),
                                                  "".join(' ' * (20 - (p // 5))), p, avg_loss))
                        train_total_loss = 0
                        cur_epoch_train_loss.append(avg_loss)

                ### Evaluation
                # First pass: evaluate on the training set itself (via the
                # unshuffled background generator) for the epoch metrics.
                val_i = 0
                val_total_loss = 0
                val_loss = []
                val_preds = []
                val_ys = []
                val_pearson = 0
                val_spearman = 0

                with torch.set_grad_enabled(False):

                    crispr_model.eval()
                    for local_batch, local_labels in training_bg_generator:
                        val_i += 1
                        local_labels_on_cpu = np.array(local_labels).reshape(
                            -1)
                        sample_size = local_labels_on_cpu.shape[-1]
                        local_labels_on_cpu = local_labels_on_cpu[:sample_size]
                        # Transfer to GPU
                        local_batch, local_labels = local_batch.float().to(
                            device2), local_labels.float().to(device2)
                        # seq_local_batch = local_batch.narrow(dim=1, start=0, length=for_input_len).long()
                        # extra_local_batch = local_batch.narrow(dim=1, start=for_input_len, length=extra_input_len)
                        preds = crispr_model(local_batch).contiguous().view(-1)
                        assert preds.size(-1) == local_labels.size(-1)
                        prediction_on_cpu = preds.cpu().numpy().reshape(-1)
                        # mean_prediction_on_cpu = np.mean([prediction_on_cpu[:sample_size],
                        #                                   prediction_on_cpu[sample_size:]], axis=0)
                        mean_prediction_on_cpu = prediction_on_cpu[:
                                                                   sample_size]
                        if config.y_inverse_transform:
                            # Undo the *100 standardization before computing MSE.
                            local_labels_on_cpu, mean_prediction_on_cpu = \
                                std_scaler.inverse_transform(local_labels_on_cpu.reshape(-1, 1) / 100), \
                                std_scaler.inverse_transform(mean_prediction_on_cpu.reshape(-1, 1) / 100)
                        loss = mean_squared_error(local_labels_on_cpu,
                                                  mean_prediction_on_cpu)
                        val_preds.append(mean_prediction_on_cpu)
                        val_ys.append(local_labels_on_cpu)
                        val_total_loss += loss

                        n_iter = 1
                        if val_i % n_iter == 0:
                            avg_loss = val_total_loss / n_iter
                            val_loss.append(avg_loss)
                            val_total_loss = 0

                    # Correlations over all concatenated batches.
                    mean_prediction_on_cpu = np.concatenate(tuple(val_preds))
                    local_labels_on_cpu = np.concatenate(tuple(val_ys))
                    val_pearson = pearsonr(mean_prediction_on_cpu.reshape(-1),
                                           local_labels_on_cpu.reshape(-1))[0]
                    val_spearman = spearmanr(
                        mean_prediction_on_cpu.reshape(-1),
                        local_labels_on_cpu.reshape(-1))[0]

                logger.debug(
                    "Validation mse is {0}, Validation pearson correlation is {1!r} and Validation "
                    "spearman correlation is {2!r}".format(
                        np.mean(val_loss), val_pearson, val_spearman))
                mse_visualizer.plot_loss(epoch,
                                         np.mean(cur_epoch_train_loss),
                                         np.mean(val_loss),
                                         loss_type='mse')
                pearson_visualizer.plot_loss(epoch,
                                             val_pearson,
                                             loss_type='pearson_loss',
                                             ytickmin=0,
                                             ytickmax=1)
                spearman_visualizer.plot_loss(epoch,
                                              val_spearman,
                                              loss_type='spearman_loss',
                                              ytickmin=0,
                                              ytickmax=1)

                ### Evaluation
                # Second pass: evaluate on the held-out rows (the partition's
                # 'eval' split) and keep the best model by Spearman score.
                val_i = 0
                val_total_loss = 0
                val_loss = []
                val_preds = []
                val_ys = []
                val_pearson = 0
                val_spearman = 0

                with torch.set_grad_enabled(False):

                    crispr_model.eval()
                    for local_batch, local_labels in validation_generator:
                        val_i += 1
                        local_labels_on_cpu = np.array(local_labels).reshape(
                            -1)
                        sample_size = local_labels_on_cpu.shape[-1]
                        local_labels_on_cpu = local_labels_on_cpu[:sample_size]
                        # Transfer to GPU
                        local_batch, local_labels = local_batch.float().to(
                            device2), local_labels.float().to(device2)
                        # seq_local_batch = local_batch.narrow(dim=1, start=0, length=for_input_len).long()
                        # extra_local_batch = local_batch.narrow(dim=1, start=for_input_len, length=extra_input_len)
                        preds = crispr_model(local_batch).contiguous().view(-1)
                        assert preds.size(-1) == local_labels.size(-1)
                        prediction_on_cpu = preds.cpu().numpy().reshape(-1)
                        # mean_prediction_on_cpu = np.mean([prediction_on_cpu[:sample_size],
                        #                                   prediction_on_cpu[sample_size:]], axis=0)
                        mean_prediction_on_cpu = prediction_on_cpu[:
                                                                   sample_size]
                        if config.y_inverse_transform:
                            local_labels_on_cpu, mean_prediction_on_cpu = \
                                std_scaler.inverse_transform(local_labels_on_cpu.reshape(-1, 1) / 100), \
                                std_scaler.inverse_transform(mean_prediction_on_cpu.reshape(-1, 1) / 100)
                        loss = mean_squared_error(local_labels_on_cpu,
                                                  mean_prediction_on_cpu)
                        val_preds.append(mean_prediction_on_cpu)
                        val_ys.append(local_labels_on_cpu)
                        val_total_loss += loss

                        n_iter = 1
                        if val_i % n_iter == 0:
                            avg_loss = val_total_loss / n_iter
                            val_loss.append(avg_loss)
                            val_total_loss = 0

                    mean_prediction_on_cpu = np.concatenate(tuple(val_preds))
                    local_labels_on_cpu = np.concatenate(tuple(val_ys))
                    val_pearson = pearsonr(mean_prediction_on_cpu.reshape(-1),
                                           local_labels_on_cpu.reshape(-1))[0]
                    val_spearman = spearmanr(
                        mean_prediction_on_cpu.reshape(-1),
                        local_labels_on_cpu.reshape(-1))[0]

                    # Snapshot the weights of the best epoch so far.
                    if best_cv_spearman_score < val_spearman:
                        best_cv_spearman_score = val_spearman
                        best_crispr_model.load_state_dict(
                            crispr_model.state_dict())

                logger.debug(
                    "Test mse is {0}, Test pearson correlation is {1!r} and Test "
                    "spearman correlation is {2!r}".format(
                        np.mean(val_loss), val_pearson, val_spearman))
                mse_visualizer.plot_loss(epoch,
                                         np.mean(cur_epoch_train_loss),
                                         np.mean(val_loss),
                                         loss_type='mse')
                pearson_visualizer.plot_loss(epoch,
                                             val_pearson,
                                             loss_type='pearson_loss',
                                             ytickmin=0,
                                             ytickmax=1)
                spearman_visualizer.plot_loss(epoch,
                                              val_spearman,
                                              loss_type='spearman_loss',
                                              ytickmin=0,
                                              ytickmax=1)

            logger.debug("Saving training history")

            logger.debug("Saved training history successfully")

            logger.debug("Trained crispr model successfully")

        else:
            logger.debug("loading in old model")

            logger.debug("Load in model successfully")

    except KeyboardInterrupt as e:

        # Allow manual interruption of training; fall through to persistence.
        logger.debug("Loading model")
        logger.debug("loading some intermediate step's model")
        logger.debug("Load in model successfully")

    logger.debug("Persisting model")
    # serialize weights to HDF5
    save(best_crispr_model, config.hdf5_path)
    save(best_crispr_model.state_dict(), config.hdf5_path_state)
    logger.debug("Saved model to disk")

    # Optionally carry essentiality alongside the saved test predictions.
    save_output = []
    if 'essentiality' in config.extra_numerical_features:
        save_output.append(data_pre.crispr.loc[test_index, 'essentiality'])
    test_model(best_crispr_model, test_generator, save_output, std_scaler)

    if config.check_feature_importance:

        logger.debug("Getting features ranks")
        # Feature names: one per source-sequence position, then (optionally)
        # one per target-sequence position, then the configured extras.
        names = []
        names += [
            "src_" + str(i) for i in range(data_pre.feature_length_map[0][1])
        ]
        if data_pre.feature_length_map[1] is not None:
            names += [
                "trg_" + str(i)
                for i in range(data_pre.feature_length_map[1][1] -
                               data_pre.feature_length_map[1][0])
            ]
        if data_pre.feature_length_map[2] is not None:
            names += config.extra_categorical_features + config.extra_numerical_features
        ranker = feature_imp.InputPerturbationRank(names)
        # Perturbation ranking (2 repeats) over the shuffled training rows.
        feature_ranks = ranker.rank(
            2,
            y.loc[train_eval_index_list, :].values,
            best_crispr_model,
            [torch.FloatTensor(X.loc[train_eval_index_list, :].values)],
            torch=True)
        feature_ranks_df = pd.DataFrame(feature_ranks)
        feature_ranks_df.to_csv(config.feature_importance_path, index=False)
        logger.debug("Get features ranks successfully")
Ejemplo n.º 4
0
def run():

    logger.debug("Reading in the crispr dataset %s" % config.input_dataset)
    crispr = pd.read_csv(config.input_dataset)
    crispr['PAM'] = crispr['sequence'].str[-3:]
    if config.log_cen:
        crispr['essentiality'] = np.log(crispr['essentiality'] * 100 + 1)
    if config.with_pam:
        pam_code = 8
    else:
        pam_code = 0
    # scale_features
    process_features.scale_features(crispr)
    process_features.scale_output(crispr)
    logger.debug("Read in data successfully")

    logger.debug("Transforming data")
    X_for = crispr.loc[:, 'sequence'].apply(
        lambda seq: utils.split_seqs(seq[:config.seq_len]))
    X_rev = crispr.loc[:, 'sequence'].apply(
        lambda seq: utils.split_seqs(seq[config.seq_len - 1::-1]))
    X_cnn = crispr.loc[:, 'sequence'].apply(
        lambda seq: utils.split_seqs(seq[:config.seq_len], nt=1))
    X = pd.concat([X_for, X_rev, X_cnn], axis=1)
    logger.debug("Get sequence sucessfully")
    off_target_X = pd.DataFrame(np.empty(shape=[X_for.shape[0], 0]))
    # off_target_X = crispr.loc[:, 'sequence'].apply(lambda seq: utils.map_to_matrix(seq, 1, 22))
    # y = pd.DataFrame(np.abs(crispr[config.y].copy()) * 10)
    y = pd.DataFrame(crispr[config.y].copy() * 8)
    logger.debug("Transformed data successfully")

    logger.debug(
        "Starting to prepare for splitting dataset to training dataset and testing dataset based on genes"
    )
    logger.debug("Generating groups based on gene names")
    if config.group:
        crispr.loc[:, "group"] = pd.Categorical(crispr.loc[:, config.group])
    logger.debug("Generated groups information successfully")

    logger.debug("Splitting dataset")
    if os.path.exists(config.train_index) and os.path.exists(
            config.test_index):
        train_index = pickle.load(open(config.train_index, "rb"))
        test_index = pickle.load(open(config.test_index, "rb"))
    else:
        train_test_split = getattr(process_features,
                                   config.split_method + "_split",
                                   process_features.regular_split)
        train_index, test_index = train_test_split(crispr,
                                                   group_col=config.group_col,
                                                   n_split=max(
                                                       len(crispr) / 100, 10),
                                                   rd_state=7)

        with open(config.train_index, 'wb') as train_file:
            pickle.dump(train_index, train_file)
        with open(config.test_index, 'wb') as test_file:
            pickle.dump(test_index, test_file)

    if config.test_cellline:
        test_cellline_index = crispr[crispr['cellline'] ==
                                     config.test_cellline].index
        test_index = test_cellline_index.intersection(test_index)

    test_index_list = [
        x.index
        for _, x in crispr.loc[test_index, :].reset_index().groupby('group')
        if len(x)
    ] if config.test_method == 'group' else []
    logger.debug("Splitted data successfully")

    logger.debug("training data amounts: %s, testing data amounts: %s" %
                 (len(train_index), len(test_index)))
    x_train, x_test, y_train, y_test, off_target_X_train, off_target_X_test = \
                                       X.loc[train_index, :], X.loc[test_index, :], \
                                       y.loc[train_index, :], y.loc[test_index, :], \
                                       off_target_X.loc[train_index, :], off_target_X.loc[test_index, :]

    _, unique_train_index = np.unique(pd.concat([x_train, y_train], axis=1),
                                      return_index=True,
                                      axis=0)
    _, unique_test_index = np.unique(pd.concat([x_test, y_test], axis=1),
                                     return_index=True,
                                     axis=0)
    logger.debug(
        "after deduplication, training data amounts: %s, testing data amounts: %s"
        % (len(unique_train_index), len(unique_test_index)))
    logger.debug("Splitted dataset successfully")

    logger.debug("Generating one hot vector for categorical data")

    extra_crispr_df = crispr[config.extra_categorical_features +
                             config.extra_numerical_features]

    n_values = [pam_code] + ([2] * (len(config.extra_categorical_features) - 1)
                             ) if config.with_pam else [2] * len(
                                 config.extra_categorical_features)
    process_features.process_categorical_features(extra_crispr_df, n_values)
    extra_x_train, extra_x_test = extra_crispr_df.loc[
        train_index, :].values, extra_crispr_df.loc[test_index, :].values
    logger.debug("Generating on hot vector for categorical data successfully")

    logger.debug("Seperate forward and reverse seq")
    x_train = x_train.values
    for_input_len = config.seq_len - config.word_len + 1
    for_input, rev_input, for_cnn = x_train[:, :
                                            for_input_len], x_train[:,
                                                                    for_input_len:
                                                                    2 *
                                                                    for_input_len], x_train[:,
                                                                                            2
                                                                                            *
                                                                                            for_input_len:]
    x_test = x_test.values
    for_x_test, rev_x_test, for_cnn_test = x_test[:, :
                                                  for_input_len], x_test[:,
                                                                         for_input_len:
                                                                         2 *
                                                                         for_input_len], x_test[:,
                                                                                                2
                                                                                                *
                                                                                                for_input_len:]
    off_target_X_train = off_target_X_train.values
    off_target_X_test = off_target_X_test.values
    if not config.off_target:
        off_target_X_train, off_target_X_test = np.empty(
            shape=[off_target_X_train.shape[0], 0]), np.empty(
                shape=[off_target_X_test.shape[0], 0])

    if (not config.rev_seq) or (config.model_type == 'mixed'):
        rev_input, rev_x_test = np.empty(
            shape=[rev_input.shape[0], 0]), np.empty(
                shape=[rev_x_test.shape[0], 0])

    y_train = y_train.values
    filter = y_train.flatten() > 0
    y_test = y_test.values

    if config.ml_train:

        try:
            ml_train(X, extra_crispr_df, y, train_index, test_index)

        except:
            logger.debug("Fail to use random forest")
        finally:
            h2o.cluster().shutdown()
        return

    logger.debug("Building the RNN graph")
    weight_matrix = [utils.get_weight_matrix()
                     ] if config.word2vec_weight_matrix else None
    for_seq_input = Input(shape=(for_input.shape[1], ))
    rev_seq_input = Input(shape=(rev_input.shape[1], ))
    for_cnn_input = Input(shape=(for_cnn.shape[1], ))
    bio_features = Input(shape=(extra_x_train.shape[1], ))
    off_target_features = Input(shape=(off_target_X_train.shape[1], ))
    all_features = Input(shape=(for_input.shape[1] + rev_input.shape[1] +
                                extra_x_train.shape[1] +
                                off_target_X_train.shape[1], ))
    if not config.ensemble:
        crispr_model = models.CrisprCasModel(
            bio_features=bio_features,
            for_seq_input=for_seq_input,
            rev_seq_input=rev_seq_input,
            weight_matrix=weight_matrix,
            off_target_features=off_target_features,
            all_features=all_features).get_model()
    else:
        crispr_model = models.CrisprCasModel(
            bio_features=bio_features,
            for_seq_input=for_seq_input,
            rev_seq_input=rev_seq_input,
            for_cnn_input=for_cnn_input,
            weight_matrix=weight_matrix,
            off_target_features=off_target_features,
            all_features=all_features).get_model()

    if config.retraining:
        loaded_model = load_model(config.retraining_model,
                                  custom_objects={
                                      'revised_mse_loss':
                                      utils.revised_mse_loss,
                                      'tf': tf
                                  })
        for layer in loaded_model.layers:
            print(layer.name)

        if config.model_type == 'cnn':

            for_layer = loaded_model.get_layer(name='embedding_1')
            for_layer.trainable = config.fine_tune_trainable

            full_connected = loaded_model.get_layer(name='sequential_6')

        elif (config.model_type == 'mixed') or (config.model_type
                                                == 'ensemble'):

            for_layer = loaded_model.get_layer(name='sequential_5')
            if config.frozen_embedding_only:
                for_layer = for_layer.get_layer(name='embedding_1')
            for_layer.trainable = config.fine_tune_trainable

            cnn_layer = loaded_model.get_layer(name='embedding_2')
            cnn_layer.trainable = config.fine_tune_trainable
            if not config.frozen_embedding_only:
                cnn_layer_1 = loaded_model.get_layer(name='sequential_3')
                cnn_layer_2 = loaded_model.get_layer(name='sequential_4')
                cnn_layer_1.trainable = config.fine_tune_trainable
                cnn_layer_2.trainable = config.fine_tune_trainable

            full_connected = loaded_model.get_layer(name='sequential_6')

        else:
            for_layer = loaded_model.get_layer(name='sequential_5')
            if config.frozen_embedding_only:

                for_layer = for_layer.get_layer(name='embedding_1')
            for_layer.trainable = config.fine_tune_trainable
            if config.rev_seq:
                rev_layer = loaded_model.get_layer(name='sequential_2')
                if config.frozen_embedding_only:
                    rev_layer = rev_layer.get_layer(name='embedding_2')
                rev_layer.trainable = config.fine_tune_trainable
                full_connected = loaded_model.get_layer(name='sequential_3')
            else:
                full_connected = loaded_model.get_layer(name='sequential_6')

        for i in range(
                int((len(full_connected.layers) / 4) *
                    (1 - config.fullly_connected_train_fraction))):

            dense_layer = full_connected.get_layer(name='dense_' + str(i + 1))
            dense_layer.trainable = config.fine_tune_trainable

        crispr_model = models.CrisprCasModel.compile_transfer_learning_model(
            loaded_model)

    utils.output_model_info(crispr_model)
    logger.debug("Built the RNN model successfully")

    try:
        if config.training:
            logger.debug("Training the model")
            # x_train = x_train.values.astype('int32').reshape((-1, 21, 200))
            checkpoint = ModelCheckpoint(config.temp_hdf5_path,
                                         verbose=1,
                                         save_best_only=True,
                                         period=1)
            reduce_lr = LearningRateScheduler(utils.cosine_decay_lr)

            logger.debug("augmenting data")
            processed_for_input = utils.augment_data(
                for_input, filter=filter,
                is_seq=True) if config.augment_data else for_input

            if config.augment_data:
                if rev_input.shape[0] and rev_input.shape[1]:
                    processed_rev_input = utils.augment_data(rev_input,
                                                             filter=filter,
                                                             is_seq=True,
                                                             is_rev=True)
                else:
                    processed_rev_input = utils.augment_data(rev_input,
                                                             filter=filter)
            else:
                processed_rev_input = rev_input

            processed_off_target_X_train = utils.augment_data(
                off_target_X_train,
                filter=filter) if config.augment_data else off_target_X_train
            processed_extra_x_train = utils.augment_data(
                extra_x_train,
                filter=filter) if config.augment_data else extra_x_train
            processed_y_train = utils.augment_data(
                y_train, filter=filter) if config.augment_data else y_train
            logger.debug("augmented data successfully")

            logger.debug("selecting %d data for training" %
                         (config.retraining_datasize * len(processed_y_train)))
            index_range = list(range(len(processed_y_train)))
            np.random.shuffle(index_range)
            selected_index = index_range[:int(config.retraining_datasize *
                                              len(processed_y_train))]
            logger.debug("selecting %d data for training" %
                         (config.retraining_datasize * len(processed_y_train)))

            features_list = [
                processed_for_input[selected_index],
                processed_rev_input[selected_index],
                processed_off_target_X_train[selected_index],
                processed_extra_x_train[selected_index]
            ]

            if config.ensemble:
                processed_for_cnn = utils.augment_data(
                    for_cnn, filter=filter,
                    is_seq=True) if config.augment_data else for_cnn
                features_list.append(processed_for_cnn[selected_index])
                print("ensemble")
                print(len(features_list))

            training_history = utils.print_to_training_log(crispr_model.fit)(
                x=features_list,
                validation_split=0.05,
                y=processed_y_train[selected_index],
                epochs=config.n_epochs,
                batch_size=config.batch_size,
                verbose=2,
                callbacks=[checkpoint, reduce_lr])

            logger.debug("Saving history")
            with open(config.training_history, 'wb') as history_file:
                pickle.dump(training_history.history, history_file)
            logger.debug("Saved training history successfully")

            logger.debug("Trained crispr model successfully")

        else:
            logger.debug("Logging in old model")
            loaded_model = load_model(config.old_model_hdf5,
                                      custom_objects={
                                          'revised_mse_loss':
                                          utils.revised_mse_loss,
                                          'tf': tf
                                      })
            crispr_model = models.CrisprCasModel.compile_transfer_learning_model(
                loaded_model)
            crispr_model.save(config.temp_hdf5_path)
            logger.debug("Load in model successfully")

    except KeyboardInterrupt as e:

        logger.debug("Loading model")
        loaded_model = load_model(config.temp_hdf5_path,
                                  custom_objects={
                                      'revised_mse_loss':
                                      utils.revised_mse_loss,
                                      'tf': tf
                                  })
        crispr_model = models.CrisprCasModel.compile_transfer_learning_model(
            loaded_model)
        logger.debug("Load in model successfully")

    logger.debug("Persisting model")
    # serialize weights to HDF5
    crispr_model.save(config.hdf5_path)
    print("Saved model to disk")

    logger.debug("Loading best model for testing")
    loaded_model = load_model(config.temp_hdf5_path,
                              custom_objects={
                                  'revised_mse_loss': utils.revised_mse_loss,
                                  'tf': tf
                              })
    crispr_model = models.CrisprCasModel.compile_transfer_learning_model(
        loaded_model)
    logger.debug("Load in model successfully")

    logger.debug("Predicting data with best model")
    train_list = [
        for_input[unique_train_index], rev_input[unique_train_index],
        off_target_X_train[unique_train_index],
        extra_x_train[unique_train_index]
    ]
    if config.ensemble:
        train_list.append(for_cnn[unique_train_index])
    train_prediction = crispr_model.predict(x=train_list)
    train_performance = spearmanr(train_prediction,
                                  y_train[unique_train_index])
    logger.debug(
        "GRU model spearman correlation coefficient for training dataset is: %s"
        % str(train_performance))

    get_prediction = getattr(sys.modules[__name__],
                             "get_prediction_" + config.test_method,
                             get_prediction_group)
    test_list = [for_x_test, rev_x_test, off_target_X_test, extra_x_test]
    if config.ensemble:
        test_list.append(for_cnn_test)
    performance, prediction = get_prediction(crispr_model, test_index_list,
                                             unique_test_index, y_test,
                                             test_list)
    logger.debug("GRU model spearman correlation coefficient: %s" %
                 str(performance))

    logger.debug("Loading last model for testing")
    loaded_model = load_model(config.hdf5_path,
                              custom_objects={
                                  'revised_mse_loss': utils.revised_mse_loss,
                                  'tf': tf
                              })
    crispr_model = models.CrisprCasModel.compile_transfer_learning_model(
        loaded_model)
    logger.debug("Load in model successfully")

    logger.debug("Predicting data with last model")
    last_train_prediction = crispr_model.predict(x=train_list)
    last_train_performance = spearmanr(last_train_prediction,
                                       y_train[unique_train_index])
    utils.output_config_info()
    logger.debug(
        "GRU model spearman correlation coefficient for training dataset is: %s"
        % str(last_train_performance))

    last_performance, last_prediction = get_prediction(crispr_model,
                                                       test_index_list,
                                                       unique_test_index,
                                                       y_test, test_list)
    logger.debug("GRU model spearman correlation coefficient: %s" %
                 str(last_performance))

    logger.debug("Saving test and prediction data plot")
    if last_performance > performance:
        prediction = last_prediction
    utils.ytest_and_prediction_output(y_test[unique_test_index], prediction)
    logger.debug("Saved test and prediction data plot successfully")

    if config.check_feature_importance:
        if performance > last_performance:
            loaded_model = load_model(config.temp_hdf5_path,
                                      custom_objects={
                                          'revised_mse_loss':
                                          utils.revised_mse_loss,
                                          'tf': tf
                                      })
            crispr_model = models.CrisprCasModel.compile_transfer_learning_model(
                loaded_model)
        logger.debug("Getting features ranks")
        names = []
        names += ["for_" + str(i) for i in range(for_input.shape[1])]
        names += ["rev_" + str(i) for i in range(rev_input.shape[1])]
        names += ["off_" + str(i) for i in range(off_target_X_train.shape[1])]
        names += config.extra_categorical_features + config.extra_numerical_features
        ranker = feature_imp.InputPerturbationRank(names)
        feature_ranks = ranker.rank(
            20, y_test[unique_test_index], crispr_model,
            [data[unique_test_index] for data in test_list])
        feature_ranks_df = pd.DataFrame(feature_ranks)
        feature_ranks_df.to_csv(config.feature_importance_path, index=False)
        logger.debug("Get features ranks successfully")
Ejemplo n.º 5
0
def run():
    """Train and evaluate the attention-based off-target binary classifier.

    Pipeline:
      1. Build mismatch features ``X`` and binary labels ``y`` via
         ``OT_crispr_attn.data_preparer`` and split into train/test indices.
      2. Optionally restrict the test set to one cell line
         (``config.test_cellline``) and oversample the training split with
         ``RandomOverSampler`` to balance classes.
      3. Train ``attention_model.get_OT_model(...)`` for ``config.n_epochs``
         with Adam + CrossEntropyLoss, tracking NLL loss / ROC AUC / PR AUC
         on both the un-oversampled training split and the test split, and
         keeping a copy of the weights with the best test ROC AUC.
      4. Persist the best model (object and state dict), run ``test_model``,
         and optionally rank input features by perturbation
         (``config.check_feature_importance``).

    Side effects: writes model files to ``config.hdf5_path`` /
    ``config.hdf5_path_state``, optionally a CSV to
    ``config.feature_importance_path``, and plots to visdom environments.
    """

    # ---- Data preparation -------------------------------------------------
    data_pre = OT_crispr_attn.data_preparer()
    print(data_pre.get_crispr_preview())
    X = data_pre.prepare_x(mismatch=True, trg_seq_col='Target sequence')
    y = data_pre.get_labels(binary=True)
    data_pre.persist_data()
    print(X.head())
    logger.debug("{0!r}".format(data_pre.feature_length_map))
    train_index, test_index = data_pre.train_test_split(n_split=5)
    # Log the guide sequences in each split so train/test leakage can be
    # inspected by eye (the hard assert below is disabled).
    logger.debug("{0!r}".format(
        set(data_pre.crispr.loc[test_index, :]['sequence'])))
    logger.debug("{0!r}".format(
        set(data_pre.crispr.loc[train_index, :]['sequence'])))
    #assert len(set(data_pre.crispr.loc[test_index, :]['sequence']) & set(data_pre.crispr.loc[train_index, :]['sequence'])) == 0
    logger.debug("{0!r}".format(train_index))
    logger.debug("{0!r}".format(test_index))
    logger.debug("training data amounts: %s, testing data amounts: %s" %
                 (len(train_index), len(test_index)))
    # Fixed seed for reproducible weight init / shuffling on the torch side.
    torch.manual_seed(0)

    # Optionally evaluate only on samples from one configured cell line.
    if config.test_cellline:
        test_cellline_index = data_pre.crispr[data_pre.crispr['cellline'] ==
                                              config.test_cellline].index
        test_index = test_cellline_index.intersection(test_index)

    # ---- Class balancing --------------------------------------------------
    # Oversample the minority class in the training split; only the resampled
    # row indices are kept (the resampled X/y themselves are discarded).
    ros = RandomOverSampler(random_state=42)
    _ = ros.fit_resample(X.loc[train_index, :], y.loc[train_index, :])
    new_train_index = train_index[ros.sample_indices_]
    oversample_train_index = list(new_train_index)
    random.shuffle(oversample_train_index)

    # sep = int(len(train_eval_index_list) * 0.9)
    # train_index_list = train_eval_index_list[:sep]
    # eval_index_list = train_eval_index_list[sep:]

    # Oversampling must not leak test rows and must cover every train row.
    assert len(set(oversample_train_index) & set(test_index)) == 0
    assert len(set(oversample_train_index) & set(train_index)) == len(
        set(train_index))
    # 'train' is the oversampled set used for gradient updates; 'train_val'
    # is the original (un-oversampled) training split used for monitoring.
    partition = {
        'train': oversample_train_index,
        'train_val': train_index,
        'eval': list(test_index),
        'test': list(test_index)
    }

    # NOTE(review): keys here are positional ints 0..len(y)-1 while the
    # partition lists hold DataFrame index labels — these only line up if
    # y has a default RangeIndex. Confirm against my_data.MyDataset.
    labels = {
        key: value
        for key, value in zip(list(range(len(y))), list(y.values.reshape(-1)))
    }

    train_params = {'batch_size': config.batch_size, 'shuffle': True}
    train_bg_params = {'batch_size': config.batch_size, 'shuffle': True}
    # Evaluation/test run as a single full-size batch, in fixed order.
    eval_params = {'batch_size': len(test_index), 'shuffle': False}
    test_params = {'batch_size': len(test_index), 'shuffle': False}

    logger.debug("Preparing datasets ... ")
    training_set = my_data.MyDataset(partition['train'], labels)
    training_generator = data.DataLoader(training_set, **train_params)

    training_bg_set = my_data.MyDataset(partition['train_val'], labels)
    training_bg_generator = data.DataLoader(training_bg_set, **train_bg_params)

    validation_set = my_data.MyDataset(partition['eval'], labels)
    validation_generator = data.DataLoader(validation_set, **eval_params)

    test_set = my_data.MyDataset(partition['test'], labels)
    test_generator = data.DataLoader(test_set, **test_params)

    logger.debug("I might need to augment data")

    # ---- Model construction -----------------------------------------------
    # Two identical classifier models: crispr_model is trained in place,
    # best_crispr_model receives a state_dict snapshot whenever the test
    # ROC AUC improves.
    logger.debug("Building the scaled dot product attention model")
    crispr_model = attention_model.get_OT_model(data_pre.feature_length_map,
                                                classifier=True)
    best_crispr_model = attention_model.get_OT_model(
        data_pre.feature_length_map, classifier=True)
    crispr_model.to(device2)
    best_crispr_model.to(device2)
    best_cv_roc_auc_scores = 0
    # Transformer-style Adam hyperparameters (betas=(0.9, 0.98), eps=1e-9).
    optimizer = torch.optim.Adam(crispr_model.parameters(),
                                 lr=config.start_lr,
                                 weight_decay=config.lr_decay,
                                 betas=(0.9, 0.98),
                                 eps=1e-9)

    # Retraining/fine-tuning is not implemented here; only logged.
    if config.retraining:

        logger.debug("I need to load a old trained model")
        logger.debug("I might need to freeze some of the weights")

    logger.debug("Built the RNN model successfully")

    try:
        if config.training:

            # ---- Training loop --------------------------------------------
            logger.debug("Training the model")
            nllloss_visualizer = torch_visual.VisTorch(env_name='NLLLOSS')
            roc_auc_visualizer = torch_visual.VisTorch(env_name='ROC_AUC')
            pr_auc_visualizer = torch_visual.VisTorch(env_name='PR_AUC')

            for epoch in range(config.n_epochs):

                crispr_model.train()
                start = time()
                cur_epoch_train_loss = []
                train_total_loss = 0
                i = 0

                # Training
                for local_batch, local_labels in training_generator:
                    i += 1
                    # Transfer to GPU
                    local_batch, local_labels = local_batch.float().to(
                        device2), local_labels.long().to(device2)
                    # seq_local_batch = local_batch.narrow(dim=1, start=0, length=for_input_len).long()
                    # extra_local_batch = local_batch.narrow(dim=1, start=for_input_len, length=extra_input_len)
                    # Model computations
                    preds = crispr_model(local_batch)
                    ys = local_labels.contiguous().view(-1)
                    optimizer.zero_grad()
                    assert preds.size(0) == ys.size(0)
                    #loss = F.nll_loss(preds, ys)
                    # CrossEntropyLoss expects raw (unnormalized) class
                    # scores, so the model's outputs are treated as logits.
                    criterion = torch.nn.CrossEntropyLoss()
                    loss = criterion(preds, ys)
                    loss.backward()
                    optimizer.step()

                    train_total_loss += loss.item()

                    # Every 50 batches: log a progress bar and the average
                    # loss over the window, then reset the accumulator.
                    n_iter = 50
                    if i % n_iter == 0:
                        sample_size = len(train_index)
                        p = int(100 * i * config.batch_size / sample_size)
                        avg_loss = train_total_loss / n_iter
                        logger.debug("   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" % \
                                                 ((time() - start) // 60, epoch, "".join('#' * (p // 5)),
                                                  "".join(' ' * (20 - (p // 5))), p, avg_loss))
                        train_total_loss = 0
                        cur_epoch_train_loss.append(avg_loss)

                ### Evaluation
                # Pass 1: metrics on the un-oversampled training split
                # ('train_val'), used to monitor fit vs. the test metrics.
                train_val_i = 0
                train_val_total_loss = 0
                train_val_loss = []
                # train_val_roc_auc = 0
                # train_val_pr_auc = 0
                train_val_preds = []
                train_val_ys = []
                n_pos, n_neg = 0, 0

                with torch.set_grad_enabled(False):

                    crispr_model.eval()
                    for local_batch, local_labels in training_bg_generator:
                        train_val_i += 1
                        local_labels_on_cpu = np.array(local_labels).reshape(
                            -1)
                        train_val_ys.append(local_labels_on_cpu)
                        # Binary 0/1 labels: sum == positive count.
                        n_pos += sum(local_labels_on_cpu)
                        n_neg += len(local_labels_on_cpu) - sum(
                            local_labels_on_cpu)
                        # Transfer to GPU
                        local_batch, local_labels = local_batch.float().to(
                            device2), local_labels.long().to(device2)
                        preds = crispr_model(local_batch)
                        assert preds.size(0) == local_labels.size(0)
                        criterion = torch.nn.CrossEntropyLoss()
                        nllloss_val = criterion(preds, local_labels).item()
                        # nllloss_val = F.nll_loss(preds, local_labels).item()
                        train_val_total_loss += nllloss_val
                        # Column 1 = score of the positive class; ROC/PR AUC
                        # are rank-based, so raw logits suffice here.
                        prediction_on_cpu = preds.cpu().numpy()[:, 1]
                        train_val_preds.append(prediction_on_cpu)

                        n_iter = 10
                        if train_val_i % n_iter == 0:
                            avg_loss = train_val_total_loss / n_iter
                            train_val_loss.append(avg_loss)
                            train_val_total_loss = 0

                preds = np.concatenate(tuple(train_val_preds))
                ys = np.concatenate(tuple(train_val_ys))
                train_val_roc_auc = roc_auc_score(ys, preds)
                train_val_pr_auc = average_precision_score(ys, preds)
                logger.debug(
                    "{0!r} positive samples and {1!r} negative samples".format(
                        n_pos, n_neg))
                logger.debug(
                    "Validation nllloss is {0}, Validation roc_auc is {1!r} and Validation "
                    "pr_auc correlation is {2!r}".format(
                        np.mean(train_val_loss), train_val_roc_auc,
                        train_val_pr_auc))
                # nllloss_visualizer.plot_loss(epoch, np.mean(cur_epoch_train_loss), np.mean(train_val_loss), loss_type='nllloss')
                # roc_auc_visualizer.plot_loss(epoch, train_val_roc_auc, loss_type='roc_auc', ytickmin=0, ytickmax=1)
                # pr_auc_visualizer.plot_loss(epoch, train_val_pr_auc, loss_type='pr_auc', ytickmin=0, ytickmax=1)

                ### Evaluation
                # Pass 2: metrics on the held-out test split (a single
                # full-size batch); these drive best-model selection.
                val_i = 0
                val_total_loss = 0
                val_loss = []
                val_preds = []
                val_ys = []

                with torch.set_grad_enabled(False):

                    crispr_model.eval()
                    for local_batch, local_labels in validation_generator:
                        val_i += 1
                        local_labels_on_cpu = np.array(local_labels).reshape(
                            -1)
                        val_ys.append(local_labels_on_cpu)
                        n_pos = sum(local_labels_on_cpu)
                        n_neg = len(local_labels_on_cpu) - sum(
                            local_labels_on_cpu)
                        logger.debug(
                            "{0!r} positive samples and {1!r} negative samples"
                            .format(n_pos, n_neg))
                        local_batch, local_labels = local_batch.float().to(
                            device2), local_labels.long().to(device2)
                        preds = crispr_model(local_batch)
                        assert preds.size(0) == local_labels.size(0)
                        criterion = torch.nn.CrossEntropyLoss()
                        nllloss_val = criterion(preds, local_labels).item()
                        #nllloss_val = F.nll_loss(preds, local_labels).item()
                        prediction_on_cpu = preds.cpu().numpy()[:, 1]
                        val_preds.append(prediction_on_cpu)
                        val_total_loss += nllloss_val

                        n_iter = 1
                        if val_i % n_iter == 0:
                            avg_loss = val_total_loss / n_iter
                            val_loss.append(avg_loss)
                            val_total_loss = 0

                preds = np.concatenate(tuple(val_preds))
                ys = np.concatenate(tuple(val_ys))
                val_roc_auc = roc_auc_score(ys, preds)
                val_pr_auc = average_precision_score(ys, preds)
                logger.debug(
                    "Test NLLloss is {0}, Test roc_auc is {1!r} and Test "
                    "pr_auc is {2!r}".format(np.mean(val_loss), val_roc_auc,
                                             val_pr_auc))
                nllloss_visualizer.plot_loss(epoch,
                                             np.mean(cur_epoch_train_loss),
                                             np.mean(val_loss),
                                             loss_type='nllloss')
                roc_auc_visualizer.plot_loss(epoch,
                                             train_val_roc_auc,
                                             val_roc_auc,
                                             loss_type='roc_auc',
                                             ytickmin=0,
                                             ytickmax=1)
                pr_auc_visualizer.plot_loss(epoch,
                                            train_val_pr_auc,
                                            val_pr_auc,
                                            loss_type='pr_auc',
                                            ytickmin=0,
                                            ytickmax=1)

                # Snapshot the weights whenever the test ROC AUC improves.
                if best_cv_roc_auc_scores < val_roc_auc:
                    best_cv_roc_auc_scores = val_roc_auc
                    best_crispr_model.load_state_dict(
                        crispr_model.state_dict())

            logger.debug("Saving training history")

            logger.debug("Saved training history successfully")

            logger.debug("Trained crispr model successfully")

        else:
            # Non-training path is a placeholder: nothing is loaded, so
            # best_crispr_model keeps its freshly initialized weights.
            logger.debug("loading in old model")

            logger.debug("Load in model successfully")

    except KeyboardInterrupt as e:

        # Interrupting training falls through to persist whatever
        # best_crispr_model currently holds (checkpoint loading is a stub).
        logger.debug("Loading model")
        logger.debug("loading some intermediate step's model")
        logger.debug("Load in model successfully")

    # ---- Persistence and final evaluation ---------------------------------
    logger.debug("Persisting model")
    # serialize weights to HDF5
    save(best_crispr_model, config.hdf5_path)
    save(best_crispr_model.state_dict(), config.hdf5_path_state)
    logger.debug("Saved model to disk")

    # Carry per-sample essentiality through to the test output when it was
    # used as an extra feature.
    save_output = []
    if 'essentiality' in config.extra_numerical_features:
        save_output.append(data_pre.crispr.loc[test_index, 'essentiality'])
    test_model(best_crispr_model, test_generator, save_output)

    if config.check_feature_importance:

        # Rank input features by perturbation: feature names follow the
        # layout of feature_length_map (source seq, target seq, extras).
        logger.debug("Getting features ranks")
        names = []
        names += [
            "src_" + str(i) for i in range(data_pre.feature_length_map[0][1])
        ]
        if data_pre.feature_length_map[1] is not None:
            names += [
                "trg_" + str(i)
                for i in range(data_pre.feature_length_map[1][1] -
                               data_pre.feature_length_map[1][0])
            ]
        if data_pre.feature_length_map[2] is not None:
            names += config.extra_categorical_features + config.extra_numerical_features
        ranker = feature_imp.InputPerturbationRank(names)
        # NOTE(review): importance is computed on the TRAINING split here,
        # whereas sibling pipelines in this file rank on the test split —
        # confirm this is intentional.
        feature_ranks = ranker.rank(
            2,
            y.loc[train_index, :].values,
            best_crispr_model,
            [torch.FloatTensor(X.loc[train_index, :].values)],
            torch=True,
            classifier=True)
        feature_ranks_df = pd.DataFrame(feature_ranks)
        feature_ranks_df.to_csv(config.feature_importance_path, index=False)
        logger.debug("Get features ranks successfully")