# Example #1
    def __trim_dataframes(self):
        """self.__trim_dataframes()

        This function trims three dataframes to make sure that drug response
        dataframe, RNA sequence dataframe, and drug feature dataframe are
        sharing the same list of cell lines and drugs.

        Timing of the set construction, the isin() filtering, and the whole
        operation is reported through logger.debug.

        Returns:
            None
        """
        t_total_time = time.time()

        # Encode the data source and take the data from target source only
        # Note that source could be 'NCI60', 'GDSC', etc. and 'all'
        if self.data_source.lower() != 'all':

            logger.debug('Specifying data source %s ... ', self.data_source)

            data_src_dict = get_label_dict(data_root=self.__data_root,
                                           dict_name='data_src_dict.txt')
            encoded_data_src = data_src_dict[self.data_source]

            # Reduce/trim the drug response dataframe to the target source
            self.__drug_resp_df = self.__drug_resp_df.loc[
                self.__drug_resp_df['SOURCE'] == encoded_data_src]

        # Make sure that all three dataframes share the same drugs/cells
        logger.debug('Trimming dataframes on common cell lines and drugs ... ')

        # Intersect the identifiers present in each pair of dataframes; sets
        # give O(1) membership tests for the isin() filtering below
        t_set_gen_time = time.time()
        cell_set = set(self.__drug_resp_df['CELLNAME'].unique()) \
            & set(self.__rnaseq_df.index.values)
        drug_set = set(self.__drug_resp_df['DRUG_ID'].unique()) \
            & set(self.__drug_feature_df.index.values)
        t_set_gen_time = time.time() - t_set_gen_time
        logger.debug('set gen time : %s s', t_set_gen_time)

        # Keep only the rows whose cell line AND drug survive in all three
        # dataframes
        t_isin_loc = time.time()
        self.__drug_resp_df = self.__drug_resp_df.loc[
            (self.__drug_resp_df['CELLNAME'].isin(cell_set))
            & (self.__drug_resp_df['DRUG_ID'].isin(drug_set))]
        self.__rnaseq_df = self.__rnaseq_df[self.__rnaseq_df.index.isin(
            cell_set)]
        self.__drug_feature_df = self.__drug_feature_df[
            self.__drug_feature_df.index.isin(drug_set)]
        t_isin_loc = time.time() - t_isin_loc
        logger.debug('t_isin_loc op time : %s s', t_isin_loc)

        logger.debug('There are %i drugs and %i cell lines, with %i response '
                     'records after trimming.',
                     len(drug_set), len(cell_set), len(self.__drug_resp_df))
        t_total_time = time.time() - t_total_time
        logger.debug('Time Taken for Trim Operation : %s s', t_total_time)
# Example #2
        except FileExistsError:
            pass
        df.to_pickle(df_path)

    df = df.astype(int_dtype)
    return df


if __name__ == '__main__':

    logging.basicConfig(level=logging.DEBUG)

    print('=' * 80 + '\nRNA sequence dataframe head:')
    print(
        get_rna_seq_df(data_root='../../data/',
                       rnaseq_feature_usage='source_scale',
                       rnaseq_scaling='std').head())

    # Load the cell line metadata once and reuse it for both the head
    # preview and the per-source summary below
    cl_meta_df = get_cl_meta_df(data_root='../../data/')

    print('=' * 80 + '\nCell line metadata dataframe head:')
    print(cl_meta_df.head())

    data_src_dict = get_label_dict('../../data/', 'data_src_dict.txt')

    # Report the number (and percentage) of metadata samples per data source
    for data_src, enc_data_src in data_src_dict.items():
        num_samples = len(
            cl_meta_df.loc[cl_meta_df['data_src'] == enc_data_src])

        print('%6i (%6.3f%%) samples from source %s' %
              (num_samples, 100 * num_samples / len(cl_meta_df), data_src))
# Example #3
def main():
    # Training settings and hyper-parameters
    parser = argparse.ArgumentParser(
        description='Multitasking Neural Network for Genes and Drugs')

    # Dataset parameters ######################################################
    # Training and validation data sources
    parser.add_argument('--trn_src',
                        type=str,
                        required=True,
                        help='training source for drug response')
    parser.add_argument('--val_srcs',
                        type=str,
                        required=True,
                        nargs='+',
                        help='validation list of sources for drug response')

    # Pre-processing for dataframes
    parser.add_argument('--grth_scaling',
                        type=str,
                        default='std',
                        help='scaling method for drug response (growth)',
                        choices=SCALING_METHODS)
    parser.add_argument('--dscptr_scaling',
                        type=str,
                        default='std',
                        help='scaling method for drug feature (descriptor)',
                        choices=SCALING_METHODS)
    parser.add_argument('--rnaseq_scaling',
                        type=str,
                        default='std',
                        help='scaling method for RNA sequence',
                        choices=SCALING_METHODS)
    parser.add_argument('--dscptr_nan_threshold',
                        type=float,
                        default=0.0,
                        help='ratio of NaN values allowed for drug descriptor')
    parser.add_argument('--qed_scaling',
                        type=str,
                        default='none',
                        help='scaling method for drug weighted QED',
                        choices=SCALING_METHODS)

    # Feature usage and partitioning settings
    parser.add_argument('--rnaseq_feature_usage',
                        type=str,
                        default='combat',
                        help='RNA sequence data used',
                        choices=[
                            'source_scale',
                            'combat',
                        ])
    parser.add_argument('--drug_feature_usage',
                        type=str,
                        default='both',
                        help='drug features (fp and/or desc) used',
                        choices=[
                            'fingerprint',
                            'descriptor',
                            'both',
                        ])
    parser.add_argument('--validation_ratio',
                        type=float,
                        default=0.2,
                        help='ratio for validation dataset')
    parser.add_argument('--disjoint_drugs',
                        action='store_true',
                        help='disjoint drugs between train/validation')
    parser.add_argument('--disjoint_cells',
                        action='store_true',
                        help='disjoint cells between train/validation')

    # Network configuration ###################################################
    # Encoders for drug features and RNA sequence (LINCS 1000)
    parser.add_argument('--gene_layer_dim',
                        type=int,
                        default=1024,
                        help='dimension of layers for RNA sequence')
    parser.add_argument('--gene_latent_dim',
                        type=int,
                        default=256,
                        help='dimension of latent variable for RNA sequence')
    parser.add_argument('--gene_num_layers',
                        type=int,
                        default=2,
                        help='number of layers for RNA sequence')

    parser.add_argument('--drug_layer_dim',
                        type=int,
                        default=4096,
                        help='dimension of layers for drug feature')
    parser.add_argument('--drug_latent_dim',
                        type=int,
                        default=1024,
                        help='dimension of latent variable for drug feature')
    parser.add_argument('--drug_num_layers',
                        type=int,
                        default=2,
                        help='number of layers for drug feature')

    # Using autoencoder for drug/sequence encoder initialization
    parser.add_argument('--autoencoder_init',
                        action='store_true',
                        help='indicator of autoencoder initialization for '
                        'drug/RNA sequence feature encoder')

    # Drug response regression network
    parser.add_argument('--resp_layer_dim',
                        type=int,
                        default=1024,
                        help='dimension of layers for drug response block')
    parser.add_argument('--resp_num_layers_per_block',
                        type=int,
                        default=2,
                        help='number of layers for drug response res block')
    parser.add_argument('--resp_num_blocks',
                        type=int,
                        default=2,
                        help='number of residual blocks for drug response')
    parser.add_argument('--resp_num_layers',
                        type=int,
                        default=2,
                        help='number of layers for drug response')
    parser.add_argument('--resp_dropout',
                        type=float,
                        default=0.0,
                        help='dropout of residual blocks for drug response')
    parser.add_argument('--resp_activation',
                        type=str,
                        default='none',
                        help='activation for response prediction output',
                        choices=['sigmoid', 'tanh', 'none'])

    # Cell line classification network(s)
    parser.add_argument('--cl_clf_layer_dim',
                        type=int,
                        default=256,
                        help='layer dimension for cell line classification')
    parser.add_argument('--cl_clf_num_layers',
                        type=int,
                        default=1,
                        help='number of layers for cell line classification')

    # Drug target family classification network
    parser.add_argument('--drug_target_layer_dim',
                        type=int,
                        default=512,
                        help='dimension of layers for drug target prediction')
    parser.add_argument('--drug_target_num_layers',
                        type=int,
                        default=2,
                        help='number of layers for drug target prediction')

    # Drug weighted QED regression network
    parser.add_argument('--drug_qed_layer_dim',
                        type=int,
                        default=512,
                        help='dimension of layers for drug QED prediction')
    parser.add_argument('--drug_qed_num_layers',
                        type=int,
                        default=2,
                        help='number of layers for drug QED prediction')
    parser.add_argument('--drug_qed_activation',
                        type=str,
                        default='none',
                        help='activation for drug QED prediction output',
                        choices=['sigmoid', 'tanh', 'none'])

    # Training and validation parameters ######################################
    # Drug response regression training parameters
    parser.add_argument('--resp_loss_func',
                        type=str,
                        default='mse',
                        help='loss function for drug response regression',
                        choices=['mse', 'l1'])
    parser.add_argument('--resp_opt',
                        type=str,
                        default='SGD',
                        help='optimizer for drug response regression',
                        choices=['SGD', 'RMSprop', 'Adam'])
    parser.add_argument('--resp_lr',
                        type=float,
                        default=1e-5,
                        help='learning rate for drug response regression')

    # Drug response uncertainty quantification parameters
    parser.add_argument('--resp_uq',
                        action='store_true',
                        help='indicator of drug response uncertainty '
                        'quantification using dropouts')
    parser.add_argument('--resp_uq_dropout',
                        type=float,
                        default=0.5,
                        help='dropout rate for uncertainty quantification')
    parser.add_argument('--resp_uq_length_scale',
                        type=float,
                        default=1.0,
                        help='Prior length-scale that captures our belief '
                        'over the function frequency')
    parser.add_argument('--resp_uq_num_runs',
                        type=int,
                        default=100,
                        help='number of predictions (runs) for uncertainty '
                        'quantification')

    # Cell line classification training parameters
    parser.add_argument('--cl_clf_opt',
                        type=str,
                        default='SGD',
                        help='optimizer for cell line classification',
                        choices=['SGD', 'RMSprop', 'Adam'])
    parser.add_argument('--cl_clf_lr',
                        type=float,
                        default=1e-3,
                        help='learning rate for cell line classification')

    # Drug target family classification training parameters
    parser.add_argument('--drug_target_opt',
                        type=str,
                        default='SGD',
                        help='optimizer for drug target classification '
                        'training',
                        choices=['SGD', 'RMSprop', 'Adam'])
    parser.add_argument('--drug_target_lr',
                        type=float,
                        default=1e-3,
                        help='learning rate for drug target classification')

    # Drug weighted QED regression training parameters
    parser.add_argument('--drug_qed_loss_func',
                        type=str,
                        default='mse',
                        help='loss function for drug QED regression',
                        choices=['mse', 'l1'])
    parser.add_argument('--drug_qed_opt',
                        type=str,
                        default='SGD',
                        help='optimizer for drug rQED regression',
                        choices=['SGD', 'RMSprop', 'Adam'])
    parser.add_argument('--drug_qed_lr',
                        type=float,
                        default=1e-3,
                        help='learning rate for drug QED regression')

    # Starting epoch for drug response validation
    parser.add_argument('--resp_val_start_epoch',
                        type=int,
                        default=0,
                        help='starting epoch for drug response validation')

    # Early stopping based on R2 score of drug response prediction
    parser.add_argument('--early_stop_patience',
                        type=int,
                        default=5,
                        help='patience for early stopping based on drug '
                        'response validation R2 scores ')

    # Global/shared training parameters
    parser.add_argument('--l2_regularization',
                        type=float,
                        default=1e-5,
                        help='L2 regularization for nn weights')
    parser.add_argument('--lr_decay_factor',
                        type=float,
                        default=0.95,
                        help='decay factor for learning rate')
    parser.add_argument('--trn_batch_size',
                        type=int,
                        default=32,
                        help='input batch size for training')
    parser.add_argument('--val_batch_size',
                        type=int,
                        default=256,
                        help='input batch size for validation')
    parser.add_argument('--max_num_batches',
                        type=int,
                        default=1000,
                        help='maximum number of batches per epoch')
    parser.add_argument('--max_num_epochs',
                        type=int,
                        default=100,
                        help='maximum number of epochs')

    # Validation results directory
    parser.add_argument('--val_results_dir',
                        type=str,
                        default=None,
                        help='directory for saved validation results. '
                        'Set to None to skip results saving')

    # Miscellaneous settings ##################################################
    parser.add_argument('--multi_gpu',
                        action='store_true',
                        default=False,
                        help='enables multiple GPU process')
    parser.add_argument('--no_cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--rand_state',
                        type=int,
                        default=0,
                        help='random state of numpy/sklearn/pytorch')

    args = parser.parse_args()
    print('Training Arguments:\n' + json.dumps(vars(args), indent=4))

    # Setting up random seed for reproducible and deterministic results
    seed_random_state(args.rand_state)

    # Computation device config (cuda or cpu)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Data loaders for training/validation ####################################
    dataloader_kwargs = {
        'timeout': 1,
        'shuffle': 'True',
        # 'num_workers': multiprocessing.cpu_count() if use_cuda else 0,
        'num_workers': NUM_WORKER if use_cuda else 0,
        'pin_memory': True if use_cuda else False,
    }

    # Drug response dataloaders for training/validation
    drug_resp_dataset_kwargs = {
        'data_root': DATA_ROOT,
        'rand_state': args.rand_state,
        'summary': False,
        'int_dtype': np.int8,
        'float_dtype': np.float16,
        'output_dtype': np.float32,
        'grth_scaling': args.grth_scaling,
        'dscptr_scaling': args.dscptr_scaling,
        'rnaseq_scaling': args.rnaseq_scaling,
        'dscptr_nan_threshold': args.dscptr_nan_threshold,
        'rnaseq_feature_usage': args.rnaseq_feature_usage,
        'drug_feature_usage': args.drug_feature_usage,
        'validation_ratio': args.validation_ratio,
        'disjoint_drugs': args.disjoint_drugs,
        'disjoint_cells': args.disjoint_cells,
    }

    drug_resp_trn_loader = torch.utils.data.DataLoader(
        DrugRespDataset(data_src=args.trn_src,
                        training=True,
                        **drug_resp_dataset_kwargs),
        batch_size=args.trn_batch_size,
        **dataloader_kwargs)

    # List of data loaders for different validation sets
    drug_resp_val_loaders = [
        torch.utils.data.DataLoader(
            DrugRespDataset(data_src=src,
                            training=False,
                            **drug_resp_dataset_kwargs),
            batch_size=args.val_batch_size,
            **dataloader_kwargs) for src in args.val_srcs
    ]

    # Cell line classification dataloaders for training/validation
    cl_clf_dataset_kwargs = {
        'data_root': DATA_ROOT,
        'rand_state': args.rand_state,
        'summary': False,
        'int_dtype': np.int8,
        'float_dtype': np.float16,
        'output_dtype': np.float32,
        'rnaseq_scaling': args.rnaseq_scaling,
        'rnaseq_feature_usage': args.rnaseq_feature_usage,
        'validation_ratio': args.validation_ratio,
    }

    cl_clf_trn_loader = torch.utils.data.DataLoader(
        CLClassDataset(training=True, **cl_clf_dataset_kwargs),
        batch_size=args.trn_batch_size,
        **dataloader_kwargs)

    cl_clf_val_loader = torch.utils.data.DataLoader(
        CLClassDataset(training=False, **cl_clf_dataset_kwargs),
        batch_size=args.val_batch_size,
        **dataloader_kwargs)

    # Drug target family classification dataloaders for training/validation
    drug_target_dataset_kwargs = {
        'data_root': DATA_ROOT,
        'rand_state': args.rand_state,
        'summary': False,
        'int_dtype': np.int8,
        'float_dtype': np.float16,
        'output_dtype': np.float32,
        'dscptr_scaling': args.dscptr_scaling,
        'dscptr_nan_threshold': args.dscptr_nan_threshold,
        'drug_feature_usage': args.drug_feature_usage,
        'validation_ratio': args.validation_ratio,
    }

    drug_target_trn_loader = torch.utils.data.DataLoader(
        DrugTargetDataset(training=True, **drug_target_dataset_kwargs),
        batch_size=args.trn_batch_size,
        **dataloader_kwargs)

    drug_target_val_loader = torch.utils.data.DataLoader(
        DrugTargetDataset(training=False, **drug_target_dataset_kwargs),
        batch_size=args.val_batch_size,
        **dataloader_kwargs)

    # Drug weighted QED regression dataloaders for training/validation
    drug_qed_dataset_kwargs = {
        'data_root': DATA_ROOT,
        'rand_state': args.rand_state,
        'summary': False,
        'int_dtype': np.int8,
        'float_dtype': np.float16,
        'output_dtype': np.float32,
        'qed_scaling': args.qed_scaling,
        'dscptr_scaling': args.dscptr_scaling,
        'dscptr_nan_threshold': args.dscptr_nan_threshold,
        'drug_feature_usage': args.drug_feature_usage,
        'validation_ratio': args.validation_ratio,
    }

    drug_qed_trn_loader = torch.utils.data.DataLoader(
        DrugQEDDataset(training=True, **drug_qed_dataset_kwargs),
        batch_size=args.trn_batch_size,
        **dataloader_kwargs)

    drug_qed_val_loader = torch.utils.data.DataLoader(
        DrugQEDDataset(training=False, **drug_qed_dataset_kwargs),
        batch_size=args.val_batch_size,
        **dataloader_kwargs)

    # Constructing and initializing neural networks ###########################
    # Autoencoder training hyper-parameters
    ae_training_kwarg = {
        'ae_loss_func': 'mse',
        'ae_opt': 'sgd',
        'ae_lr': 2e-1,
        'lr_decay_factor': 1.0,
        'max_num_epochs': 1000,
        'early_stop_patience': 50,
    }

    encoder_kwarg = {
        'model_folder': './models/',
        'data_root': DATA_ROOT,
        'autoencoder_init': args.autoencoder_init,
        'training_kwarg': ae_training_kwarg,
        'device': device,
        'verbose': True,
        'rand_state': args.rand_state,
    }

    # Get RNA sequence encoder
    gene_encoder = get_gene_encoder(
        rnaseq_feature_usage=args.rnaseq_feature_usage,
        rnaseq_scaling=args.rnaseq_scaling,
        layer_dim=args.gene_layer_dim,
        num_layers=args.gene_num_layers,
        latent_dim=args.gene_latent_dim,
        **encoder_kwarg)

    # Get drug feature encoder
    drug_encoder = get_drug_encoder(
        drug_feature_usage=args.drug_feature_usage,
        dscptr_scaling=args.dscptr_scaling,
        dscptr_nan_threshold=args.dscptr_nan_threshold,
        layer_dim=args.drug_layer_dim,
        num_layers=args.drug_num_layers,
        latent_dim=args.drug_latent_dim,
        **encoder_kwarg)

    # Regressor for drug response
    resp_net = RespNet(
        gene_latent_dim=args.gene_latent_dim,
        drug_latent_dim=args.drug_latent_dim,
        gene_encoder=gene_encoder,
        drug_encoder=drug_encoder,
        resp_layer_dim=args.resp_layer_dim,
        resp_num_layers_per_block=args.resp_num_layers_per_block,
        resp_num_blocks=args.resp_num_blocks,
        resp_num_layers=args.resp_num_layers,
        resp_dropout=args.resp_dropout,
        resp_activation=args.resp_activation).to(device)

    print(resp_net)

    # Sequence classifier for category, site, and type
    cl_clf_net_kwargs = {
        'encoder': gene_encoder,
        'input_dim': args.gene_latent_dim,
        'condition_dim': len(get_label_dict(DATA_ROOT, 'data_src_dict.txt')),
        'layer_dim': args.cl_clf_layer_dim,
        'num_layers': args.cl_clf_num_layers,
    }

    category_clf_net = ClfNet(num_classes=len(
        get_label_dict(DATA_ROOT, 'category_dict.txt')),
                              **cl_clf_net_kwargs).to(device)
    site_clf_net = ClfNet(num_classes=len(
        get_label_dict(DATA_ROOT, 'site_dict.txt')),
                          **cl_clf_net_kwargs).to(device)
    type_clf_net = ClfNet(num_classes=len(
        get_label_dict(DATA_ROOT, 'type_dict.txt')),
                          **cl_clf_net_kwargs).to(device)

    # Classifier for drug target family prediction
    drug_target_net = ClfNet(
        encoder=drug_encoder,
        input_dim=args.drug_latent_dim,
        condition_dim=0,
        layer_dim=args.drug_target_layer_dim,
        num_layers=args.drug_target_num_layers,
        num_classes=len(get_label_dict(DATA_ROOT, 'drug_target_dict.txt'))).\
        to(device)

    # Regressor for drug weighted QED prediction
    drug_qed_net = RgsNet(encoder=drug_encoder,
                          input_dim=args.drug_latent_dim,
                          condition_dim=0,
                          layer_dim=args.drug_qed_layer_dim,
                          num_layers=args.drug_qed_num_layers,
                          activation=args.drug_qed_activation).to(device)

    # Multi-GPU settings
    if args.multi_gpu:
        resp_net = nn.DataParallel(resp_net)
        category_clf_net = nn.DataParallel(category_clf_net)
        site_clf_net = nn.DataParallel(site_clf_net)
        type_clf_net = nn.DataParallel(type_clf_net)
        drug_target_net = nn.DataParallel(drug_target_net)
        drug_qed_net = nn.DataParallel(drug_qed_net)

    # Optimizers, learning rate decay, and miscellaneous ######################
    resp_opt = get_optimizer(opt_type=args.resp_opt,
                             networks=resp_net,
                             learning_rate=args.resp_lr,
                             l2_regularization=args.l2_regularization)
    cl_clf_opt = get_optimizer(
        opt_type=args.cl_clf_opt,
        networks=[category_clf_net, site_clf_net, type_clf_net],
        learning_rate=args.cl_clf_lr,
        l2_regularization=args.l2_regularization)
    drug_target_opt = get_optimizer(opt_type=args.drug_target_opt,
                                    networks=drug_target_net,
                                    learning_rate=args.drug_target_lr,
                                    l2_regularization=args.l2_regularization)
    drug_qed_opt = get_optimizer(opt_type=args.drug_qed_opt,
                                 networks=drug_qed_net,
                                 learning_rate=args.drug_qed_lr,
                                 l2_regularization=args.l2_regularization)

    resp_lr_decay = LambdaLR(optimizer=resp_opt,
                             lr_lambda=lambda e: args.lr_decay_factor**e)
    cl_clf_lr_decay = LambdaLR(optimizer=cl_clf_opt,
                               lr_lambda=lambda e: args.lr_decay_factor**e)
    drug_target_lr_decay = LambdaLR(
        optimizer=drug_target_opt, lr_lambda=lambda e: args.lr_decay_factor**e)
    drug_qed_lr_decay = LambdaLR(optimizer=drug_qed_opt,
                                 lr_lambda=lambda e: args.lr_decay_factor**e)

    resp_loss_func = F.l1_loss if args.resp_loss_func == 'l1' \
        else F.mse_loss
    drug_qed_loss_func = F.l1_loss if args.drug_qed_loss_func == 'l1' \
        else F.mse_loss

    # Training/validation loops ###############################################
    val_cl_clf_acc = []
    val_drug_target_acc = []
    val_drug_qed_mse, val_drug_qed_mae, val_drug_qed_r2 = [], [], []
    val_resp_mse, val_resp_mae, val_resp_r2 = [], [], []
    best_r2 = -np.inf
    patience = 0
    start_time = time.time()

    # Create folder for validation results if not exist
    if args.val_results_dir.lower() != 'none':
        try:
            os.makedirs(args.val_results_dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
    else:
        args.val_results_dir = None

    # Early stopping is decided on the validation set with the same
    # data source as the training dataloader
    val_index = 0
    for idx, loader in enumerate(drug_resp_val_loaders):
        if loader.dataset.data_source == args.trn_src:
            val_index = idx

    for epoch in range(args.max_num_epochs):

        print('=' * 80 + '\nTraining Epoch %3i:' % (epoch + 1))
        epoch_start_time = time.time()

        resp_lr_decay.step(epoch)
        cl_clf_lr_decay.step(epoch)
        drug_target_lr_decay.step(epoch)
        drug_qed_lr_decay.step(epoch)

        # Training cell line classifier
        train_cl_clf(device=device,
                     category_clf_net=category_clf_net,
                     site_clf_net=site_clf_net,
                     type_clf_net=type_clf_net,
                     data_loader=cl_clf_trn_loader,
                     max_num_batches=args.max_num_batches,
                     optimizer=cl_clf_opt)

        # Training drug target classifier
        train_drug_target(device=device,
                          drug_target_net=drug_target_net,
                          data_loader=drug_target_trn_loader,
                          max_num_batches=args.max_num_batches,
                          optimizer=drug_target_opt)

        # Training drug weighted QED regressor
        train_drug_qed(device=device,
                       drug_qed_net=drug_qed_net,
                       data_loader=drug_qed_trn_loader,
                       max_num_batches=args.max_num_batches,
                       loss_func=drug_qed_loss_func,
                       optimizer=drug_qed_opt)

        # Training drug response regressor
        train_resp(device=device,
                   resp_net=resp_net,
                   data_loader=drug_resp_trn_loader,
                   max_num_batches=args.max_num_batches,
                   loss_func=resp_loss_func,
                   optimizer=resp_opt)

        print('\nValidation Results:')

        if epoch >= args.resp_val_start_epoch:

            # Validating cell line classifier
            cl_category_acc, cl_site_acc, cl_type_acc = \
                valid_cl_clf(device=device,
                             category_clf_net=category_clf_net,
                             site_clf_net=site_clf_net,
                             type_clf_net=type_clf_net,
                             data_loader=cl_clf_val_loader, )
            val_cl_clf_acc.append([cl_category_acc, cl_site_acc, cl_type_acc])

            # Validating drug target classifier
            drug_target_acc = \
                valid_drug_target(device=device,
                                  drug_target_net=drug_target_net,
                                  data_loader=drug_target_val_loader)
            val_drug_target_acc.append(drug_target_acc)

            # Validating drug weighted QED regressor
            drug_qed_mse, drug_qed_mae, drug_qed_r2 = \
                valid_drug_qed(device=device,
                               drug_qed_net=drug_qed_net,
                               data_loader=drug_qed_val_loader)
            val_drug_qed_mse.append(drug_qed_mse)
            val_drug_qed_mae.append(drug_qed_mae)
            val_drug_qed_r2.append(drug_qed_r2)

            # Validating drug response regressor
            resp_mse, resp_mae, resp_r2 = \
                valid_resp(epoch=epoch,
                           trn_src=args.trn_src,
                           device=device,

                           resp_net=resp_net,
                           data_loaders=drug_resp_val_loaders,

                           resp_uq=args.resp_uq,
                           resp_uq_dropout=args.resp_uq_dropout,
                           resp_uq_num_runs=args.resp_uq_num_runs,

                           val_results_dir=args.val_results_dir)

            # Save the validation results in nested list
            val_resp_mse.append(resp_mse)
            val_resp_mae.append(resp_mae)
            val_resp_r2.append(resp_r2)

            # Record the best R2 score (same data source)
            # and check for early stopping if no improvement for epochs
            if resp_r2[val_index] > best_r2:
                patience = 0
                best_r2 = resp_r2[val_index]
            else:
                patience += 1
            if patience >= args.early_stop_patience:
                print('Validation results does not improve for %d epochs ... '
                      'invoking early stopping.' % patience)
                break

        print('Epoch Running Time: %.1f Seconds.' %
              (time.time() - epoch_start_time))

    val_cl_clf_acc = np.array(val_cl_clf_acc).reshape(-1, 3)
    # val_drug_target_acc = np.array(val_drug_target_acc)
    # val_drug_qed_mse = np.array(val_drug_qed_mse)
    # val_resp_mae = np.array(val_resp_mae)
    # val_resp_r2 = np.array(val_resp_r2)
    val_resp_mse, val_resp_mae, val_resp_r2 = \
        np.array(val_resp_mse).reshape(-1, len(args.val_srcs)), \
        np.array(val_resp_mae).reshape(-1, len(args.val_srcs)), \
        np.array(val_resp_r2).reshape(-1, len(args.val_srcs))

    print('Program Running Time: %.1f Seconds.' % (time.time() - start_time))

    # Print overall validation results
    print('=' * 80)
    print('Overall Validation Results:\n')

    print('\tBest Results from Different Models (Epochs):')
    # Print best accuracy for cell line classifiers
    clf_targets = [
        'Cell Line Categories',
        'Cell Line Sites',
        'Cell Line Types',
    ]
    best_acc = np.amax(val_cl_clf_acc, axis=0)
    best_acc_epochs = np.argmax(val_cl_clf_acc, axis=0)

    for index, clf_target in enumerate(clf_targets):
        print('\t\t%-24s Best Accuracy: %.3f%% (Epoch = %3d)' %
              (clf_target, best_acc[index],
               best_acc_epochs[index] + 1 + args.resp_val_start_epoch))

    # Print best predictions for drug classifiers and regressor
    print('\t\tDrug Target Family \t Best Accuracy: %.3f%% (Epoch = %3d)' %
          (np.max(val_drug_target_acc),
           (np.argmax(val_drug_target_acc) + 1 + args.resp_val_start_epoch)))

    print('\t\tDrug Weighted QED \t Best R2 Score: %+6.4f '
          '(Epoch = %3d, MSE = %8.6f, MAE = %8.6f)' %
          (np.max(val_drug_qed_r2),
           (np.argmax(val_drug_qed_r2) + 1 + args.resp_val_start_epoch),
           val_drug_qed_mse[np.argmax(val_drug_qed_r2)],
           val_drug_qed_mae[np.argmax(val_drug_qed_r2)]))

    # Print best R2 scores for drug response regressor
    val_data_sources = \
        [loader.dataset.data_source for loader in drug_resp_val_loaders]
    best_r2 = np.amax(val_resp_r2, axis=0)
    best_r2_epochs = np.argmax(val_resp_r2, axis=0)

    for index, data_source in enumerate(val_data_sources):
        print('\t\t%-6s \t Best R2 Score: %+6.4f '
              '(Epoch = %3d, MSE = %8.2f, MAE = %6.2f)' %
              (data_source, best_r2[index],
               best_r2_epochs[index] + args.resp_val_start_epoch + 1,
               val_resp_mse[best_r2_epochs[index], index],
               val_resp_mae[best_r2_epochs[index], index]))

    # Print best epoch and all the corresponding validation results
    # Picking the best epoch using R2 score from same data source
    best_epoch = val_resp_r2[:, val_index].argmax()
    print('\n\tBest Results from the Same Model (Epoch = %3d):' %
          (best_epoch + 1 + args.resp_val_start_epoch))
    for index, clf_target in enumerate(clf_targets):
        print('\t\t%-24s Accuracy: %.3f%%' %
              (clf_target, val_cl_clf_acc[best_epoch, index]))

    # Print best predictions for drug classifiers and regressor
    print('\t\tDrug Target Family \t Accuracy: %.3f%% ' %
          (val_drug_target_acc[best_epoch]))

    print('\t\tDrug Weighted QED \t R2 Score: %+6.4f '
          '(MSE = %8.6f, MAE = %6.6f)' %
          (val_drug_qed_r2[best_epoch], val_drug_qed_mse[best_epoch],
           val_drug_qed_mae[best_epoch]))

    for index, data_source in enumerate(val_data_sources):
        print(
            '\t\t%-6s \t R2 Score: %+6.4f '
            '(MSE = %8.2f, MAE = %6.2f)' %
            (data_source, val_resp_r2[best_epoch, index],
             val_resp_mse[best_epoch, index], val_resp_mae[best_epoch, index]))
# --- Example 4 ---
    def __init__(
        self,
        data_root: str,
        training: bool,
        rand_state: int = 0,
        summary: bool = True,

        # Data type settings (for storage and data loading)
        int_dtype: type = np.int8,
        float_dtype: type = np.float16,
        output_dtype: type = np.float32,

        # Pre-processing settings
        rnaseq_scaling: str = 'std',
        predict_target: str = 'class',

        # Partitioning (train/validation) and data usage settings
        rnaseq_feature_usage: str = 'source_scale',
        validation_ratio: float = 0.2,
    ):
        """dataset = CLClassDataset('./data/', True)

        Construct a RNA sequence dataset based on the parameters provided.
        The process includes:
            * Downloading source data files;
            * Pre-processing (scaling);
            * Public attributes and other preparations.

        Args:
            data_root (str): path to data root folder.
            training (bool): indicator for training.
            rand_state (int): random seed used for training/validation split
                and other processes that requires randomness.
            summary (bool): set True for printing dataset summary.

            int_dtype (type): integer dtype for data storage in RAM.
            float_dtype (type): float dtype for data storage in RAM.
            output_dtype (type): output dtype for neural network.

            rnaseq_scaling (str): scaling method for RNA sequence. Choose
                between 'none', 'std', and 'minmax'.
            predict_target (str): prediction target for RNA sequence. Note
                that any labels except for target will be in one-hot
                encoding, while the target will be encoded as integers.
                Choose between 'none', 'class', and 'source'.

            rnaseq_feature_usage (str): RNA sequence data usage. Choose
                between 'source_scale' and 'combat'.
            validation_ratio (float): portion of validation data out of all
                data samples.

        Raises:
            AssertionError: if predict_target (lower-cased) is not one of
                'none', 'class', or 'source'.
        """

        # Initialization ######################################################
        self.__data_root = data_root

        # Class-wise variables
        self.training = training
        self.__rand_state = rand_state
        self.__output_dtype = output_dtype

        # Feature scaling: normalize None/empty settings to the literal
        # 'none' and lower-case the strings before storing them
        if rnaseq_scaling is None or rnaseq_scaling == '':
            rnaseq_scaling = 'none'
        self.__rnaseq_scaling = rnaseq_scaling.lower()
        if predict_target is None or predict_target == '':
            predict_target = 'none'
        assert predict_target.lower() in ['none', 'class', 'source']
        self.__predict_target = predict_target.lower()

        self.__rnaseq_feature_usage = rnaseq_feature_usage
        self.__validation_ratio = validation_ratio

        # Load all dataframes #################################################
        # RNA sequence features, indexed by cell line name
        self.__rnaseq_df = get_rna_seq_df(
            data_root=data_root,
            rnaseq_feature_usage=rnaseq_feature_usage,
            rnaseq_scaling=rnaseq_scaling,
            float_dtype=float_dtype)

        # Cell line meta data (data source, site, type, category labels)
        self.__cl_meta_df = get_cl_meta_df(data_root=data_root,
                                           int_dtype=int_dtype)

        # Put all the sequence in one column as list and specify dtype
        # (collapses every expression column of a row into a single
        # list-valued 'seq' cell, cast to float_dtype)
        self.__rnaseq_df['seq'] = \
            list(map(float_dtype, self.__rnaseq_df.values.tolist()))

        # Join the RNA sequence data with meta data. cl_df will have columns:
        # ['data_src', 'site', 'type', 'category', 'seq']
        # join='inner' keeps only cell lines present in BOTH dataframes
        self.__cl_df = pd.concat(
            [self.__cl_meta_df, self.__rnaseq_df[['seq']]],
            axis=1,
            join='inner')

        # Exclude 'GDC' and 'NCI60' during data source prediction
        # GDC has too many samples while NCI60 has not enough
        if self.__predict_target == 'source':
            logger.warning('Taking out GDC and NCI60 samples to make dataset '
                           'balanced among all data sources ...')
            # NOTE(review): 2 and 5 are assumed to be the encoded ids of GDC
            # and NCI60 — confirm against data_src_dict.txt
            self.__cl_df = self.__cl_df[~self.__cl_df['data_src'].isin([2, 5])]

        # Encode labels (except for prediction targets) into one-hot encoding;
        # the target itself stays integer-encoded for the classifier loss
        if self.__predict_target != 'source':
            enc_data_src = encode_int_to_onehot(
                self.__cl_df['data_src'].tolist(),
                len(get_label_dict(data_root, 'data_src_dict.txt')))
            self.__cl_df['data_src'] = list(map(int_dtype, enc_data_src))

        if self.__predict_target != 'class':
            for label in ['site', 'type', 'category']:
                enc_label = encode_int_to_onehot(
                    self.__cl_df[label].tolist(),
                    len(get_label_dict(data_root, '%s_dict.txt' % label)))
                self.__cl_df[label] = list(map(int_dtype, enc_label))

        # Train/validation split ##############################################
        # NOTE(review): despite the name, this presumably partitions the cell
        # line dataframe into training/validation — confirm in its definition
        self.__split_drug_resp()

        # Converting dataframes to arrays for rapid access ####################
        self.__cl_array = self.__cl_df.values

        # Public attributes ###################################################
        self.cells = self.__cl_df.index.tolist()
        self.num_cells = self.__cl_df.shape[0]
        self.rnaseq_dim = len(self.__cl_df.iloc[0]['seq'])

        # Clear the dataframes ################################################
        # Free RAM: only the numpy array and public attributes are kept
        self.__rnaseq_df = None
        self.__cl_meta_df = None
        self.__cl_df = None

        # Dataset summary #####################################################
        if summary:
            print('=' * 80)
            print(('Training' if self.training else 'Validation') +
                  ' RNA Sequence Dataset Summary:')
            print('\t%i Unique Cell Lines (feature dim: %4i).' %
                  (self.num_cells, self.rnaseq_dim))
            print('=' * 80)
# --- Example 5 ---
def _parse_args():
    """Parse and return the command line arguments for training."""
    parser = argparse.ArgumentParser(
        description='Data Source (Batch) Prediction for Cell Lines')

    # Dataset parameters ######################################################
    # Pre-processing for dataframes
    parser.add_argument('--rnaseq_scaling',
                        type=str,
                        default='std',
                        help='scaling method for RNA sequence',
                        choices=SCALING_METHODS)

    # Feature usage and partitioning settings
    parser.add_argument('--rnaseq_feature_usage',
                        type=str,
                        default='combat',
                        help='RNA sequence data used',
                        choices=[
                            'source_scale',
                            'combat',
                        ])
    parser.add_argument('--validation_ratio',
                        type=float,
                        default=0.2,
                        help='ratio for validation dataset')

    # Network configuration ###################################################
    parser.add_argument('--layer_dim',
                        type=int,
                        default=256,
                        help='dimension of layers for RNA sequence')
    parser.add_argument('--num_layers',
                        type=int,
                        default=4,
                        help='number of layers for RNA sequence')

    # Training and validation parameters ######################################
    parser.add_argument('--opt',
                        type=str,
                        default='SGD',
                        help='optimizer for data source prediction',
                        choices=['SGD', 'RMSprop', 'Adam'])
    parser.add_argument('--lr',
                        type=float,
                        default=1e-2,
                        help='learning rate for data source prediction')

    # Starting epoch for validation
    parser.add_argument('--val_start_epoch',
                        type=int,
                        default=0,
                        help='starting epoch for data source prediction')

    # Early stopping based on data source prediction accuracy
    parser.add_argument('--early_stop_patience',
                        type=int,
                        default=50,
                        help='patience for early stopping based on data '
                        'source prediction accuracy')

    # Global/shared training parameters
    parser.add_argument('--l2_regularization',
                        type=float,
                        default=0.,
                        help='L2 regularization for nn weights')
    parser.add_argument('--lr_decay_factor',
                        type=float,
                        default=0.98,
                        help='decay factor for learning rate')
    parser.add_argument('--trn_batch_size',
                        type=int,
                        default=32,
                        help='input batch size for training')
    parser.add_argument('--val_batch_size',
                        type=int,
                        default=256,
                        help='input batch size for validation')
    parser.add_argument('--max_num_batches',
                        type=int,
                        default=10000,
                        help='maximum number of batches per epoch')
    parser.add_argument('--max_num_epochs',
                        type=int,
                        default=1000,
                        help='maximum number of epochs')

    # Miscellaneous settings ##################################################
    parser.add_argument('--no_cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--rand_state',
                        type=int,
                        default=0,
                        help='random state of numpy/sklearn/pytorch')

    return parser.parse_args()


def _build_net(args, rnaseq_dim, device):
    """Construct, initialize, and move to *device* the feed-forward
    classifier mapping (RNA-seq + one-hot site/type/category) to a
    log-softmax over data sources."""
    net = nn.Sequential()

    # Input dim = RNA-seq features + one-hot widths of the three meta labels
    prev_dim = rnaseq_dim
    for label in ['site', 'type', 'category']:
        prev_dim += len(get_label_dict(DATA_ROOT, '%s_dict.txt' % label))

    for i in range(args.num_layers):
        net.add_module('dense_%d' % i, nn.Linear(prev_dim, args.layer_dim))
        net.add_module('dropout_%d' % i, nn.Dropout(0.2))
        net.add_module('relu_%d' % i, nn.ReLU())
        prev_dim = args.layer_dim

    num_data_src = len(get_label_dict(DATA_ROOT, 'data_src_dict.txt'))
    net.add_module('dense', nn.Linear(args.layer_dim, num_data_src))
    # LogSoftmax output pairs with the NLL loss used during training
    net.add_module('logsoftmax', nn.LogSoftmax(dim=1))
    net.apply(basic_weight_init)
    return net.to(device)


def _train_one_epoch(net, device, data_loader, opt, max_num_batches):
    """Run one training epoch of the data source classifier."""
    net.train()

    for batch_idx, (rnaseq, data_src, cl_site, cl_type, cl_category) \
            in enumerate(data_loader):

        if batch_idx >= max_num_batches:
            break

        rnaseq, data_src, cl_site, cl_type, cl_category = \
            rnaseq.to(device), data_src.to(device), cl_site.to(device), \
            cl_type.to(device), cl_category.to(device)

        net.zero_grad()
        out_data_src = net(
            torch.cat((rnaseq, cl_site, cl_type, cl_category), dim=1))
        F.nll_loss(input=out_data_src, target=data_src).backward()
        opt.step()


def _validate(net, device, data_loader):
    """Return the data source prediction accuracy (in %) over *data_loader*."""
    net.eval()

    correct_data_src = 0
    with torch.no_grad():
        for rnaseq, data_src, cl_site, cl_type, cl_category in data_loader:

            rnaseq, data_src, cl_site, cl_type, cl_category = \
                rnaseq.to(device), data_src.to(device), \
                cl_site.to(device), cl_type.to(device), \
                cl_category.to(device)

            out_data_src = net(
                torch.cat((rnaseq, cl_site, cl_type, cl_category), dim=1))
            pred_data_src = out_data_src.max(1, keepdim=True)[1]
            correct_data_src += pred_data_src.eq(
                data_src.view_as(pred_data_src)).sum().item()

    return 100. * correct_data_src / len(data_loader.dataset)


def main():
    """Train and validate a classifier that predicts the data source
    (batch) of cell lines from RNA sequence and meta-label features.

    Parses command-line arguments, builds the train/validation
    dataloaders and the network, then runs the epoch loop with
    per-epoch learning-rate decay and accuracy-based early stopping.
    """
    args = _parse_args()
    print('Training Arguments:\n' + json.dumps(vars(args), indent=4))

    # Setting up random seed for reproducible and deterministic results
    seed_random_state(args.rand_state)

    # Computation device config (cuda or cpu)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Data loaders for training/validation ####################################
    dataloader_kwargs = {
        'timeout': 1,
        # BUGFIX: was the string 'True', which only shuffled because any
        # non-empty string is truthy ('False' would shuffle too)
        'shuffle': True,
        'num_workers': NUM_WORKER if use_cuda else 0,
        'pin_memory': use_cuda,
    }

    # Cell line classification datasets, configured to predict 'source'
    cl_clf_dataset_kwargs = {
        'data_root': DATA_ROOT,
        'rand_state': args.rand_state,
        'summary': False,
        'int_dtype': np.int8,
        'float_dtype': np.float16,
        'output_dtype': np.float32,
        'rnaseq_scaling': args.rnaseq_scaling,
        'predict_target': 'source',
        'rnaseq_feature_usage': args.rnaseq_feature_usage,
        'validation_ratio': args.validation_ratio,
    }

    cl_clf_trn_loader = torch.utils.data.DataLoader(
        CLClassDataset(training=True, **cl_clf_dataset_kwargs),
        batch_size=args.trn_batch_size,
        **dataloader_kwargs)

    cl_clf_val_loader = torch.utils.data.DataLoader(
        CLClassDataset(training=False, **cl_clf_dataset_kwargs),
        batch_size=args.val_batch_size,
        **dataloader_kwargs)

    # Constructing and initializing neural networks ###########################
    net = _build_net(args, cl_clf_trn_loader.dataset.rnaseq_dim, device)
    print(net)

    # Optimizers, learning rate decay, and miscellaneous ######################
    opt = get_optimizer(opt_type=args.opt,
                        networks=net,
                        learning_rate=args.lr,
                        l2_regularization=args.l2_regularization)
    lr_decay = LambdaLR(optimizer=opt,
                        lr_lambda=lambda e: args.lr_decay_factor**e)

    # Training/validation loops ###############################################
    val_acc = []
    best_acc = 0.
    patience = 0
    start_time = time.time()

    for epoch in range(args.max_num_epochs):

        print('=' * 80 + '\nTraining Epoch %3i:' % (epoch + 1))
        epoch_start_time = time.time()

        _train_one_epoch(net=net,
                         device=device,
                         data_loader=cl_clf_trn_loader,
                         opt=opt,
                         max_num_batches=args.max_num_batches)

        # FIX: step() without an epoch argument is the non-deprecated
        # scheduler API; stepping once after each epoch yields the same
        # lr * factor**epoch schedule as the old step(epoch) call
        lr_decay.step()

        data_src_acc = _validate(net=net,
                                 device=device,
                                 data_loader=cl_clf_val_loader)

        print(
            '\tCell Line Data Source (Batch) Prediction Accuracy: %5.2f%%; ' %
            data_src_acc)

        # Results recording and early stopping
        val_acc.append(data_src_acc)

        if data_src_acc > best_acc:
            patience = 0
            best_acc = data_src_acc
        else:
            patience += 1

        # FIX: print the epoch time before a potential early-stop break so
        # the final epoch's timing is not silently dropped
        print('Epoch Running Time: %.1f Seconds.' %
              (time.time() - epoch_start_time))

        if patience >= args.early_stop_patience:
            print('Validation accuracy does not improve for %d epochs ... '
                  'invoking early stopping.' % patience)
            break

    print('Program Running Time: %.1f Seconds.' % (time.time() - start_time))
    print('Best Cell Line Data Source (Batch) Prediction Accuracy: %5.2f%%; ' %
          np.amax(val_acc))
# --- Example 6 ---
    def build_nn(self):

        args = self.args
        device = self.device

        # Constructing and initializing neural networks ###########################
        # Autoencoder training hyper-parameters
        self.ae_training_kwarg = {
            'ae_loss_func': 'mse',
            'ae_opt': 'sgd',
            'ae_lr': 2e-1,
            'ae_reg': 1e-5,
            'lr_decay_factor': 1.0,
            'max_num_epochs': 5,
            'early_stop_patience': 50,
        }

        self.encoder_kwarg = {
            'model_folder': './models/',
            'data_root': DATA_ROOT,
            'autoencoder_init': args.autoencoder_init,
            'training_kwarg': self.ae_training_kwarg,
            'device': device,
            'verbose': True,
            'rand_state': args.rng_seed,
        }

        # Get RNA sequence encoder
        self.gene_encoder = get_gene_encoder(
            rnaseq_feature_usage=args.rnaseq_feature_usage,
            rnaseq_scaling=args.rnaseq_scaling,
            layer_dim=args.gene_layer_dim,
            num_layers=args.gene_num_layers,
            latent_dim=args.gene_latent_dim,
            **(self.encoder_kwarg))

        # Get drug feature encoder
        self.drug_encoder = get_drug_encoder(
            drug_feature_usage=args.drug_feature_usage,
            dscptr_scaling=args.dscptr_scaling,
            dscptr_nan_threshold=args.dscptr_nan_threshold,
            layer_dim=args.drug_layer_dim,
            num_layers=args.drug_num_layers,
            latent_dim=args.drug_latent_dim,
            **(self.encoder_kwarg))

        # Regressor for drug response
        self.resp_net = RespNet(
            gene_latent_dim=args.gene_latent_dim,
            drug_latent_dim=args.drug_latent_dim,
            gene_encoder=self.gene_encoder,
            drug_encoder=self.drug_encoder,
            resp_layer_dim=args.resp_layer_dim,
            resp_num_layers_per_block=args.resp_num_layers_per_block,
            resp_num_blocks=args.resp_num_blocks,
            resp_num_layers=args.resp_num_layers,
            resp_dropout=args.dropout,
            resp_activation=args.resp_activation).to(device)

        print(self.resp_net)

        # Sequence classifier for category, site, and type
        self.cl_clf_net_kwargs = {
            'encoder': self.gene_encoder,
            'input_dim': args.gene_latent_dim,
            'condition_dim': len(get_label_dict(DATA_ROOT,
                                                'data_src_dict.txt')),
            'layer_dim': args.cl_clf_layer_dim,
            'num_layers': args.cl_clf_num_layers,
        }

        self.category_clf_net = ClfNet(num_classes=len(
            get_label_dict(DATA_ROOT, 'category_dict.txt')),
                                       **(self.cl_clf_net_kwargs)).to(device)
        self.site_clf_net = ClfNet(num_classes=len(
            get_label_dict(DATA_ROOT, 'site_dict.txt')),
                                   **(self.cl_clf_net_kwargs)).to(device)
        self.type_clf_net = ClfNet(num_classes=len(
            get_label_dict(DATA_ROOT, 'type_dict.txt')),
                                   **(self.cl_clf_net_kwargs)).to(device)

        # Classifier for drug target family prediction
        self.drug_target_net = ClfNet(
                                    encoder=self.drug_encoder,
                                    input_dim=args.drug_latent_dim,
                                    condition_dim=0,
                                    layer_dim=args.drug_target_layer_dim,
                                    num_layers=args.drug_target_num_layers,
                                    num_classes=len(get_label_dict(DATA_ROOT, 'drug_target_dict.txt'))).\
                                    to(device)

        # Regressor for drug weighted QED prediction
        self.drug_qed_net = RgsNet(
            encoder=self.drug_encoder,
            input_dim=args.drug_latent_dim,
            condition_dim=0,
            layer_dim=args.drug_qed_layer_dim,
            num_layers=args.drug_qed_num_layers,
            activation=args.drug_qed_activation).to(device)

        # Multi-GPU settings
        if args.multi_gpu:
            resp_net = nn.DataParallel(self.resp_net)
            category_clf_net = nn.DataParallel(self.category_clf_net)
            site_clf_net = nn.DataParallel(self.site_clf_net)
            type_clf_net = nn.DataParallel(self.type_clf_net)
            drug_target_net = nn.DataParallel(self.drug_target_net)
            drug_qed_net = nn.DataParallel(self.drug_qed_net)