Ejemplo n.º 1
0
        def cov_model(kernel_size=(3, 1), pool_size=(21 - 3 + 1, 1), levels=config.cnn_levels):
            """Build a sequential 2-D CNN feature extractor.

            Layout: Conv/BN/ReLU/MaxPool with ``levels[0]`` filters, one more
            Conv/BN/ReLU/MaxPool stage per entry of ``levels[1:-1]``, then a
            strided Conv/ReLU with ``levels[-1]`` filters, a final MaxPool over
            the remaining sequence length, and Flatten.

            The input shape ``(input_len, vec_dim, 1)`` comes from the
            enclosing scope's ``input_len`` and ``vec_dim``.

            Args:
                kernel_size: ``(k_len, k_dim)``. ``k_len`` is the kernel length
                    of every conv along the sequence axis. ``k_dim`` is both the
                    kernel width and the stride of the final conv along the
                    embedding axis.
                pool_size: only ``pool_size[0]`` is used. It is treated as the
                    sequence length before pooling and halved (rounding up) once
                    per pooling stage to size the final pooling window.
                    NOTE(review): the default 19 presumably equals ``input_len``
                    for 21-mers with 3-mer tokens; confirm against callers.
                levels: filter count of each conv stage. ``levels[-1]`` is used
                    for the final conv.

            Returns:
                The uncompiled ``Sequential`` model.
            """
            conv_len = kernel_size[0]

            model = Sequential()
            model.add(Conv2D(levels[0], input_shape=(input_len, vec_dim, 1),
                             kernel_size=(conv_len, 1), padding='same'))
            model.add(BatchNormalization())
            model.add(Activation('relu'))
            # 'same' pooling with window/stride 2 maps sequence length n -> ceil(n / 2).
            model.add(MaxPooling2D(pool_size=(2, 1), padding='same'))
            n_pools = 1

            for n_filters in levels[1:-1]:
                model.add(Conv2D(n_filters, kernel_size=(conv_len, 1), padding='same'))
                model.add(BatchNormalization())
                model.add(Activation('relu'))
                model.add(MaxPooling2D(pool_size=(2, 1), padding='same'))
                n_pools += 1

            # 'same' padding keeps the sequence length. The embedding axis shrinks
            # to ceil(vec_dim / kernel_size[1]), i.e. 1 when kernel_size[1] == vec_dim
            # (presumably the intended use; confirm).
            last_kernel_size = (3, kernel_size[1])
            model.add(Conv2D(levels[-1], kernel_size=last_kernel_size,
                             strides=(1, kernel_size[1]), padding='same'))
            model.add(Activation('relu'))

            # Halve the reference length once per pooling layer actually added,
            # rounding up like 'same' pooling. Integer division keeps the pooling
            # window an int (true division would make it a float).
            last_pool_len = pool_size[0]
            for _ in range(n_pools):
                last_pool_len = (last_pool_len + 1) // 2
            model.add(MaxPooling2D(pool_size=(last_pool_len, 1), padding='valid'))
            model.add(Flatten())
            utils.output_model_info(model)
            return model
Ejemplo n.º 2
0
    def __fully_connected(self, nodes_unit_nums, input_len, name_suffix=''):
        """Build a stack of fully-connected blocks, one per entry of ``nodes_unit_nums``.

        Each block is Dense(units, max-norm kernel constraint), then
        BatchNormalization(momentum=0) when ``config.add_norm`` is set, then an
        Activation taken cyclically from ``config.activation_method``, then
        Dropout(``config.dropout``).

        Args:
            nodes_unit_nums: unit count of each Dense layer, in order.
            input_len: width of the flat input. Only the first Dense declares
                ``input_shape=(input_len,)``.
            name_suffix: appended to the model name ``'FC_'``.

        Returns:
            The ``Sequential`` model, after logging it via ``utils.output_model_info``.
        """
        model = Sequential(name='FC_' + name_suffix)
        activations = config.activation_method

        for layer_idx, units in enumerate(nodes_unit_nums):
            # A fresh max-norm constraint per layer, as each Dense owns its own.
            dense_kwargs = {'kernel_constraint': maxnorm(config.maxnorm)}
            if layer_idx == 0:
                dense_kwargs['input_shape'] = (input_len, )
            model.add(Dense(units, **dense_kwargs))

            if config.add_norm:
                model.add(BatchNormalization(momentum=0))
            # Cycle through the configured activations when there are more layers.
            model.add(Activation(activations[layer_idx % len(activations)]))
            model.add(Dropout(rate=config.dropout))

        utils.output_model_info(model)
        return model
Ejemplo n.º 3
0
    def __embedding_cnn(self, name_suffix='', nt=3):
        """Build an Embedding followed by four stride-2 Conv2D stages and a Dense head.

        With ``nt == 1`` the embedding uses a 5-token vocabulary, 8-dimensional
        vectors, no preset weights and ``config.seq_len`` inputs (presumably
        single-nucleotide tokens). Otherwise it uses ``self.weight_matrix``,
        ``config.embedding_voca_size``, ``config.embedding_vec_dim`` and
        ``self.seq_input_len``.

        Args:
            name_suffix: appended to the model name ``'embedding_and_cnn_'``.
            nt: token size selector; only ``1`` changes the embedding setup.

        Returns:
            The ``Sequential`` model ending in ``Dense(config.cnn_levels[-1])``.
        """
        # Read the default settings first (as before), then override for nt == 1.
        weights, voca_size, vec_dim, input_len = (
            self.weight_matrix,
            config.embedding_voca_size,
            config.embedding_vec_dim,
            self.seq_input_len,
        )
        if nt == 1:
            weights, voca_size, vec_dim, input_len = None, 5, 8, config.seq_len

        model = Sequential(name='embedding_and_cnn_' + name_suffix)
        model.add(Embedding(voca_size, vec_dim, weights=weights,
                            input_length=input_len, trainable=True))
        # View the embedded sequence as a 1-row image; with Keras' default
        # channels_last format the embedding dimensions act as channels.
        model.add(Reshape((1, input_len, vec_dim)))

        # (filters, kernel width, padding) of each stride-2 conv stage, in order.
        conv_stages = (
            (32, 4, 'same'),
            (64, 4, 'same'),
            (128, 4, 'same'),
            (256, 3, 'valid'),
        )
        for n_filters, width, pad in conv_stages:
            model.add(Conv2D(n_filters, kernel_size=(1, width), strides=2, padding=pad))
            if config.add_norm:
                model.add(BatchNormalization(momentum=0))
            model.add(Activation('swish'))

        model.add(Flatten())
        model.add(Dense(units=config.cnn_levels[-1]))
        utils.output_model_info(model)
        return model
Ejemplo n.º 4
0
def built_model(x_train, extra_crispr_df):
    """Build the CRISPR model sized from the training feature widths.

    Args:
        x_train: training sequence features; only ``x_train.shape[1]`` is read.
        extra_crispr_df: extra per-guide features; only ``shape[1]`` is read.

    Returns:
        ``models.CrisprCasModel(...).get_model()``. When
        ``config.transfer_learning`` is set, it instead returns the result of
        ``_transfer_learning_model()``. The fresh model is still built first,
        which may matter for Keras' auto-generated layer names.
    """
    logger.debug("Building the RNN graph")
    if config.word2vec_weight_matrix:
        weight_matrix = [utils.get_weight_matrix()]
    else:
        weight_matrix = None

    seq_width = x_train.shape[1]
    extra_width = extra_crispr_df.shape[1]
    for_seq_input = Input(shape=(seq_width, ))
    bio_features = Input(shape=(extra_width, ))

    builder = models.CrisprCasModel(bio_features=bio_features,
                                    for_seq_input=for_seq_input,
                                    weight_matrix=weight_matrix)
    crispr_model = builder.get_model()
    if config.transfer_learning:
        crispr_model = _transfer_learning_model()

    utils.output_model_info(crispr_model)
    logger.debug("Built the RNN model successfully")
    return crispr_model
Ejemplo n.º 5
0
def run():
    """Prepare the CRISPR dataset, train (or load) the model, and evaluate it.

    All behaviour is driven by the module-level ``config``:

    1. Read ``config.input_dataset``, add a ``PAM`` column (last 3 bases),
       optionally log-transform ``essentiality``, and scale features/output.
    2. Tokenise each sequence three ways via ``utils.split_seqs``: forward,
       reversed, and ``nt=1`` for the CNN branch.
    3. Load the train/test index split from ``config.train_index`` /
       ``config.test_index`` if both files exist. Otherwise compute it with
       ``process_features.<config.split_method>_split`` and pickle it there.
    4. If ``config.ml_train`` is set, delegate to ``ml_train``, shut down the
       h2o cluster, and return.
    5. Otherwise build the Keras model (optionally fine-tuning
       ``config.retraining_model``). Then either train it, checkpointing the
       best epoch to ``config.temp_hdf5_path``, or load
       ``config.old_model_hdf5``. Save it to ``config.hdf5_path``, then report
       Spearman correlations for both the best checkpoint and the last model.
    6. Optionally write input-perturbation feature ranks.

    NOTE(review): the index files are read with ``pickle``, which can execute
    arbitrary code; only point the config at trusted files.
    """

    def _load(path):
        # Load a saved Keras model with the custom objects it was saved with.
        return load_model(path,
                          custom_objects={
                              'revised_mse_loss': utils.revised_mse_loss,
                              'tf': tf
                          })

    def _load_compiled(path):
        # Load a saved model and recompile it for training/prediction.
        return models.CrisprCasModel.compile_transfer_learning_model(_load(path))

    # ---- read and pre-process the raw table ---------------------------------
    logger.debug("Reading in the crispr dataset %s" % config.input_dataset)
    crispr = pd.read_csv(config.input_dataset)
    crispr['PAM'] = crispr['sequence'].str[-3:]
    if config.log_cen:
        crispr['essentiality'] = np.log(crispr['essentiality'] * 100 + 1)
    pam_code = 8 if config.with_pam else 0
    process_features.scale_features(crispr)
    process_features.scale_output(crispr)
    logger.debug("Read in data successfully")

    # ---- tokenise sequences -------------------------------------------------
    logger.debug("Transforming data")
    sequences = crispr.loc[:, 'sequence']
    X_for = sequences.apply(lambda seq: utils.split_seqs(seq[:config.seq_len]))
    X_rev = sequences.apply(
        lambda seq: utils.split_seqs(seq[config.seq_len - 1::-1]))
    X_cnn = sequences.apply(
        lambda seq: utils.split_seqs(seq[:config.seq_len], nt=1))
    X = pd.concat([X_for, X_rev, X_cnn], axis=1)
    logger.debug("Get sequence sucessfully")
    # Off-target features are currently a zero-width placeholder.
    off_target_X = pd.DataFrame(np.empty(shape=[X_for.shape[0], 0]))
    y = pd.DataFrame(crispr[config.y].copy() * 8)
    logger.debug("Transformed data successfully")

    # ---- train/test split ---------------------------------------------------
    logger.debug(
        "Starting to prepare for splitting dataset to training dataset and testing dataset based on genes"
    )
    logger.debug("Generating groups based on gene names")
    if config.group:
        crispr.loc[:, "group"] = pd.Categorical(crispr.loc[:, config.group])
    logger.debug("Generated groups information successfully")

    logger.debug("Splitting dataset")
    if os.path.exists(config.train_index) and os.path.exists(config.test_index):
        with open(config.train_index, "rb") as train_file:
            train_index = pickle.load(train_file)
        with open(config.test_index, "rb") as test_file:
            test_index = pickle.load(test_file)
    else:
        train_test_split = getattr(process_features,
                                   config.split_method + "_split",
                                   process_features.regular_split)
        # The split count must be an integer: at least 10, or one per ~100 rows.
        train_index, test_index = train_test_split(
            crispr,
            group_col=config.group_col,
            n_split=max(len(crispr) // 100, 10),
            rd_state=7)

        with open(config.train_index, 'wb') as train_file:
            pickle.dump(train_index, train_file)
        with open(config.test_index, 'wb') as test_file:
            pickle.dump(test_index, test_file)

    if config.test_cellline:
        test_cellline_index = crispr[crispr['cellline'] ==
                                     config.test_cellline].index
        test_index = test_cellline_index.intersection(test_index)

    if config.test_method == 'group':
        # Positions (after reset_index) of each group's rows within the test set.
        # Grouping a Categorical can also yield empty groups, hence the len filter.
        test_index_list = [
            x.index
            for _, x in crispr.loc[test_index, :].reset_index().groupby('group')
            if len(x)
        ]
    else:
        test_index_list = []
    logger.debug("Splitted data successfully")

    logger.debug("training data amounts: %s, testing data amounts: %s" %
                 (len(train_index), len(test_index)))
    x_train, x_test = X.loc[train_index, :], X.loc[test_index, :]
    y_train, y_test = y.loc[train_index, :], y.loc[test_index, :]
    off_target_X_train = off_target_X.loc[train_index, :]
    off_target_X_test = off_target_X.loc[test_index, :]

    # Positions of the first occurrence of every distinct (features, target) row;
    # used below to evaluate on de-duplicated data.
    _, unique_train_index = np.unique(pd.concat([x_train, y_train], axis=1),
                                      return_index=True,
                                      axis=0)
    _, unique_test_index = np.unique(pd.concat([x_test, y_test], axis=1),
                                     return_index=True,
                                     axis=0)
    logger.debug(
        "after deduplication, training data amounts: %s, testing data amounts: %s"
        % (len(unique_train_index), len(unique_test_index)))
    logger.debug("Splitted dataset successfully")

    # ---- extra (categorical + numerical) features ---------------------------
    logger.debug("Generating one hot vector for categorical data")

    extra_crispr_df = crispr[config.extra_categorical_features +
                             config.extra_numerical_features]

    n_categorical = len(config.extra_categorical_features)
    if config.with_pam:
        # The first categorical feature (presumably PAM) gets pam_code values;
        # the remaining categorical features are treated as binary.
        n_values = [pam_code] + [2] * (n_categorical - 1)
    else:
        n_values = [2] * n_categorical
    process_features.process_categorical_features(extra_crispr_df, n_values)
    extra_x_train = extra_crispr_df.loc[train_index, :].values
    extra_x_test = extra_crispr_df.loc[test_index, :].values
    logger.debug("Generating on hot vector for categorical data successfully")

    # ---- split token matrix into its three parts ----------------------------
    logger.debug("Seperate forward and reverse seq")
    # Each row of X is [forward tokens | reverse tokens | nt=1 CNN tokens]; the
    # code assumes the first two parts are each for_input_len wide.
    for_input_len = config.seq_len - config.word_len + 1
    x_train = x_train.values
    for_input = x_train[:, :for_input_len]
    rev_input = x_train[:, for_input_len:2 * for_input_len]
    for_cnn = x_train[:, 2 * for_input_len:]
    x_test = x_test.values
    for_x_test = x_test[:, :for_input_len]
    rev_x_test = x_test[:, for_input_len:2 * for_input_len]
    for_cnn_test = x_test[:, 2 * for_input_len:]

    off_target_X_train = off_target_X_train.values
    off_target_X_test = off_target_X_test.values
    if not config.off_target:
        off_target_X_train = np.empty(shape=[off_target_X_train.shape[0], 0])
        off_target_X_test = np.empty(shape=[off_target_X_test.shape[0], 0])

    # A disabled reverse branch is fed zero-width arrays.
    if (not config.rev_seq) or (config.model_type == 'mixed'):
        rev_input = np.empty(shape=[rev_input.shape[0], 0])
        rev_x_test = np.empty(shape=[rev_x_test.shape[0], 0])

    y_train = y_train.values
    # Rows with a positive target; passed as `filter` to utils.augment_data.
    positive_mask = y_train.flatten() > 0
    y_test = y_test.values

    if config.ml_train:
        try:
            ml_train(X, extra_crispr_df, y, train_index, test_index)
        except Exception:
            # Best effort, but keep the traceback instead of swallowing it.
            logger.debug("Fail to use random forest", exc_info=True)
        finally:
            h2o.cluster().shutdown()
        return

    # ---- build the model ----------------------------------------------------
    # Input/model creation order is kept as-is: Keras auto-names layers by
    # creation order and the retraining branch looks layers up by those names.
    logger.debug("Building the RNN graph")
    weight_matrix = ([utils.get_weight_matrix()]
                     if config.word2vec_weight_matrix else None)
    for_seq_input = Input(shape=(for_input.shape[1], ))
    rev_seq_input = Input(shape=(rev_input.shape[1], ))
    for_cnn_input = Input(shape=(for_cnn.shape[1], ))
    bio_features = Input(shape=(extra_x_train.shape[1], ))
    off_target_features = Input(shape=(off_target_X_train.shape[1], ))
    all_features = Input(shape=(for_input.shape[1] + rev_input.shape[1] +
                                extra_x_train.shape[1] +
                                off_target_X_train.shape[1], ))
    model_kwargs = dict(bio_features=bio_features,
                        for_seq_input=for_seq_input,
                        rev_seq_input=rev_seq_input,
                        weight_matrix=weight_matrix,
                        off_target_features=off_target_features,
                        all_features=all_features)
    if config.ensemble:
        # Only the ensemble model consumes the nt=1 CNN input.
        model_kwargs['for_cnn_input'] = for_cnn_input
    crispr_model = models.CrisprCasModel(**model_kwargs).get_model()

    if config.retraining:
        loaded_model = _load(config.retraining_model)
        for layer in loaded_model.layers:
            logger.debug("loaded layer: %s", layer.name)

        # NOTE(review): the layer names below are Keras auto-generated names and
        # depend on the construction order of the saved model; confirm them
        # against the logged layer names when the architecture changes.
        if config.model_type == 'cnn':
            for_layer = loaded_model.get_layer(name='embedding_1')
            for_layer.trainable = config.fine_tune_trainable

            full_connected = loaded_model.get_layer(name='sequential_6')

        elif config.model_type in ('mixed', 'ensemble'):
            for_layer = loaded_model.get_layer(name='sequential_5')
            if config.frozen_embedding_only:
                for_layer = for_layer.get_layer(name='embedding_1')
            for_layer.trainable = config.fine_tune_trainable

            cnn_layer = loaded_model.get_layer(name='embedding_2')
            cnn_layer.trainable = config.fine_tune_trainable
            if not config.frozen_embedding_only:
                cnn_layer_1 = loaded_model.get_layer(name='sequential_3')
                cnn_layer_2 = loaded_model.get_layer(name='sequential_4')
                cnn_layer_1.trainable = config.fine_tune_trainable
                cnn_layer_2.trainable = config.fine_tune_trainable

            full_connected = loaded_model.get_layer(name='sequential_6')

        else:
            for_layer = loaded_model.get_layer(name='sequential_5')
            if config.frozen_embedding_only:
                for_layer = for_layer.get_layer(name='embedding_1')
            for_layer.trainable = config.fine_tune_trainable
            if config.rev_seq:
                rev_layer = loaded_model.get_layer(name='sequential_2')
                if config.frozen_embedding_only:
                    rev_layer = rev_layer.get_layer(name='embedding_2')
                rev_layer.trainable = config.fine_tune_trainable
                full_connected = loaded_model.get_layer(name='sequential_3')
            else:
                full_connected = loaded_model.get_layer(name='sequential_6')

        # Apply fine_tune_trainable to the leading Dense layers of the FC stack.
        # The divisor 4 presumably counts Dense/BatchNorm/Activation/Dropout per
        # unit (TODO confirm).
        n_tuned_dense = int((len(full_connected.layers) / 4) *
                            (1 - config.fullly_connected_train_fraction))
        for i in range(n_tuned_dense):
            dense_layer = full_connected.get_layer(name='dense_' + str(i + 1))
            dense_layer.trainable = config.fine_tune_trainable

        crispr_model = models.CrisprCasModel.compile_transfer_learning_model(
            loaded_model)

    utils.output_model_info(crispr_model)
    logger.debug("Built the RNN model successfully")

    # ---- train or load ------------------------------------------------------
    try:
        if config.training:
            logger.debug("Training the model")
            checkpoint = ModelCheckpoint(config.temp_hdf5_path,
                                         verbose=1,
                                         save_best_only=True,
                                         period=1)
            reduce_lr = LearningRateScheduler(utils.cosine_decay_lr)

            def _maybe_augment(data, **kwargs):
                # Augment only when configured; call order below is unchanged.
                if config.augment_data:
                    return utils.augment_data(data, filter=positive_mask, **kwargs)
                return data

            logger.debug("augmenting data")
            processed_for_input = _maybe_augment(for_input, is_seq=True)
            if rev_input.shape[0] and rev_input.shape[1]:
                processed_rev_input = _maybe_augment(rev_input,
                                                     is_seq=True,
                                                     is_rev=True)
            else:
                # Zero-width reverse input: augment without sequence handling.
                processed_rev_input = _maybe_augment(rev_input)
            processed_off_target_X_train = _maybe_augment(off_target_X_train)
            processed_extra_x_train = _maybe_augment(extra_x_train)
            processed_y_train = _maybe_augment(y_train)
            logger.debug("augmented data successfully")

            # Train on a random retraining_datasize fraction of the rows.
            n_selected = int(config.retraining_datasize *
                             len(processed_y_train))
            index_range = list(range(len(processed_y_train)))
            np.random.shuffle(index_range)
            selected_index = index_range[:n_selected]
            logger.debug("selecting %d data for training" % n_selected)

            features_list = [
                processed_for_input[selected_index],
                processed_rev_input[selected_index],
                processed_off_target_X_train[selected_index],
                processed_extra_x_train[selected_index]
            ]

            if config.ensemble:
                processed_for_cnn = _maybe_augment(for_cnn, is_seq=True)
                features_list.append(processed_for_cnn[selected_index])
                logger.debug("ensemble: %d input arrays", len(features_list))

            training_history = utils.print_to_training_log(crispr_model.fit)(
                x=features_list,
                validation_split=0.05,
                y=processed_y_train[selected_index],
                epochs=config.n_epochs,
                batch_size=config.batch_size,
                verbose=2,
                callbacks=[checkpoint, reduce_lr])

            logger.debug("Saving history")
            with open(config.training_history, 'wb') as history_file:
                pickle.dump(training_history.history, history_file)
            logger.debug("Saved training history successfully")

            logger.debug("Trained crispr model successfully")

        else:
            logger.debug("Logging in old model")
            crispr_model = _load_compiled(config.old_model_hdf5)
            # Make the old model the "best" checkpoint used for testing below.
            crispr_model.save(config.temp_hdf5_path)
            logger.debug("Load in model successfully")

    except KeyboardInterrupt:
        # Interrupted training: fall back to the best checkpoint saved so far.
        logger.debug("Loading model")
        crispr_model = _load_compiled(config.temp_hdf5_path)
        logger.debug("Load in model successfully")

    logger.debug("Persisting model")
    # Serialize the whole model (architecture + weights) to HDF5.
    crispr_model.save(config.hdf5_path)
    print("Saved model to disk")

    # ---- evaluate the best checkpoint ---------------------------------------
    logger.debug("Loading best model for testing")
    crispr_model = _load_compiled(config.temp_hdf5_path)
    logger.debug("Load in model successfully")

    logger.debug("Predicting data with best model")
    train_list = [
        for_input[unique_train_index], rev_input[unique_train_index],
        off_target_X_train[unique_train_index],
        extra_x_train[unique_train_index]
    ]
    if config.ensemble:
        train_list.append(for_cnn[unique_train_index])
    train_prediction = crispr_model.predict(x=train_list)
    train_performance = spearmanr(train_prediction,
                                  y_train[unique_train_index])
    logger.debug(
        "GRU model spearman correlation coefficient for training dataset is: %s"
        % str(train_performance))

    get_prediction = getattr(sys.modules[__name__],
                             "get_prediction_" + config.test_method,
                             get_prediction_group)
    test_list = [for_x_test, rev_x_test, off_target_X_test, extra_x_test]
    if config.ensemble:
        test_list.append(for_cnn_test)
    performance, prediction = get_prediction(crispr_model, test_index_list,
                                             unique_test_index, y_test,
                                             test_list)
    logger.debug("GRU model spearman correlation coefficient: %s" %
                 str(performance))

    # ---- evaluate the last saved model --------------------------------------
    logger.debug("Loading last model for testing")
    crispr_model = _load_compiled(config.hdf5_path)
    logger.debug("Load in model successfully")

    logger.debug("Predicting data with last model")
    last_train_prediction = crispr_model.predict(x=train_list)
    last_train_performance = spearmanr(last_train_prediction,
                                       y_train[unique_train_index])
    utils.output_config_info()
    logger.debug(
        "GRU model spearman correlation coefficient for training dataset is: %s"
        % str(last_train_performance))

    last_performance, last_prediction = get_prediction(crispr_model,
                                                       test_index_list,
                                                       unique_test_index,
                                                       y_test, test_list)
    logger.debug("GRU model spearman correlation coefficient: %s" %
                 str(last_performance))

    logger.debug("Saving test and prediction data plot")
    if last_performance > performance:
        prediction = last_prediction
    utils.ytest_and_prediction_output(y_test[unique_test_index], prediction)
    logger.debug("Saved test and prediction data plot successfully")

    # ---- optional feature importance ----------------------------------------
    if config.check_feature_importance:
        if performance > last_performance:
            crispr_model = _load_compiled(config.temp_hdf5_path)
        logger.debug("Getting features ranks")
        # NOTE(review): in ensemble mode test_list also carries for_cnn_test, but
        # no names are generated for those columns; confirm the ranker handles it.
        names = []
        names += ["for_" + str(i) for i in range(for_input.shape[1])]
        names += ["rev_" + str(i) for i in range(rev_input.shape[1])]
        names += ["off_" + str(i) for i in range(off_target_X_train.shape[1])]
        names += config.extra_categorical_features + config.extra_numerical_features
        ranker = feature_imp.InputPerturbationRank(names)
        feature_ranks = ranker.rank(
            20, y_test[unique_test_index], crispr_model,
            [data[unique_test_index] for data in test_list])
        feature_ranks_df = pd.DataFrame(feature_ranks)
        feature_ranks_df.to_csv(config.feature_importance_path, index=False)
        logger.debug("Get features ranks successfully")