def __trim_dataframes(self):
    """self.__trim_dataframes()

    Trim the three dataframes so that the drug response dataframe, the
    RNA sequence dataframe, and the drug feature dataframe all share
    the same set of cell lines and drugs.

    Steps:
        * keep only drug response records from the target data source
          (skipped when ``self.data_source`` is ``'all'``);
        * intersect cell lines between drug response and RNA sequence;
        * intersect drugs between drug response and drug features;
        * restrict all three dataframes to the common cells/drugs.

    Timing of the two expensive phases (set construction and the
    isin/loc filtering) is reported through the module logger.

    Returns:
        None
    """
    t_total_time = time.time()

    # Encode the data source and take the data from target source only.
    # Note that source could be 'NCI60', 'GDSC', etc. and 'all'.
    if self.data_source.lower() != 'all':
        logger.debug('Specifying data source %s ... ' % self.data_source)
        data_src_dict = get_label_dict(data_root=self.__data_root,
                                       dict_name='data_src_dict.txt')
        encoded_data_src = data_src_dict[self.data_source]

        # Reduce/trim the drug response dataframe
        self.__drug_resp_df = self.__drug_resp_df.loc[
            self.__drug_resp_df['SOURCE'] == encoded_data_src]

    # Make sure that all three dataframes share the same drugs/cells
    logger.debug('Trimming dataframes on common cell lines and drugs ... ')

    # Build the intersection sets once; set membership is O(1) in the
    # isin() filtering below.
    t_set_gen_time = time.time()
    cell_set = set(self.__drug_resp_df['CELLNAME'].unique()) \
        & set(self.__rnaseq_df.index.values)
    drug_set = set(self.__drug_resp_df['DRUG_ID'].unique()) \
        & set(self.__drug_feature_df.index.values)
    t_set_gen_time = time.time() - t_set_gen_time
    # FIX: timing used to go through bare print(); use the module logger
    # like the rest of this method for consistent, filterable output.
    logger.debug('set gen time : %f s', t_set_gen_time)

    t_isin_loc = time.time()
    self.__drug_resp_df = self.__drug_resp_df.loc[
        (self.__drug_resp_df['CELLNAME'].isin(cell_set))
        & (self.__drug_resp_df['DRUG_ID'].isin(drug_set))]
    self.__rnaseq_df = self.__rnaseq_df[
        self.__rnaseq_df.index.isin(cell_set)]
    self.__drug_feature_df = self.__drug_feature_df[
        self.__drug_feature_df.index.isin(drug_set)]
    t_isin_loc = time.time() - t_isin_loc
    logger.debug('t_isin_loc op time : %f s', t_isin_loc)

    logger.debug('There are %i drugs and %i cell lines, with %i response '
                 'records after trimming.'
                 % (len(drug_set), len(cell_set),
                    len(self.__drug_resp_df)))

    t_total_time = time.time() - t_total_time
    logger.debug('Time Taken for Trim Operation : %f s', t_total_time)
    return
    # NOTE(review): this `except` closes a `try` that starts above this
    # chunk — presumably guarding a makedirs() of the processed-data
    # folder; confirm against the full file.
    except FileExistsError:
        pass
    # Cache the processed dataframe on disk for future runs
    df.to_pickle(df_path)

    # Downcast to a compact integer dtype before returning
    df = df.astype(int_dtype)
    return df


if __name__ == '__main__':
    # Smoke test / demo: print the heads of the processed dataframes and
    # a per-data-source sample breakdown of the cell line metadata.
    logging.basicConfig(level=logging.DEBUG)

    print('=' * 80 + '\nRNA sequence dataframe head:')
    print(
        get_rna_seq_df(data_root='../../data/',
                       rnaseq_feature_usage='source_scale',
                       rnaseq_scaling='std').head())

    print('=' * 80 + '\nCell line metadata dataframe head:')
    print(get_cl_meta_df(data_root='../../data/').head())

    cl_meta_df = get_cl_meta_df(data_root='../../data/')
    data_src_dict = get_label_dict('../../data/', 'data_src_dict.txt')

    # Report how many (and what fraction of) samples come from each source
    for data_src, enc_data_src in data_src_dict.items():
        num_samples = len(
            cl_meta_df.loc[cl_meta_df['data_src'] == enc_data_src])
        print('%6i (%6.3f%%) samples from source %s'
              % (num_samples, 100 * num_samples / len(cl_meta_df),
                 data_src))
def main():
    """Parse CLI arguments, construct datasets and networks, and run the
    multitasking training/validation loop.

    Four tasks share the gene/drug encoders:
        * drug response regression (main task; drives early stopping);
        * cell line classification (category, site, type);
        * drug target family classification;
        * drug weighted QED regression.

    Fixes vs. the previous revision:
        * ``dataloader_kwargs['shuffle']`` was the string ``'True'``
          (only worked because any non-empty string is truthy); it is
          now the boolean ``True`` that DataLoader expects.
        * ``args.val_results_dir`` defaults to ``None``; calling
          ``.lower()`` on it crashed with AttributeError. A ``None``
          guard is added.
    """
    # Training settings and hyper-parameters
    parser = argparse.ArgumentParser(
        description='Multitasking Neural Network for Genes and Drugs')

    # Dataset parameters ######################################################
    # Training and validation data sources
    parser.add_argument('--trn_src', type=str, required=True,
                        help='training source for drug response')
    parser.add_argument('--val_srcs', type=str, required=True, nargs='+',
                        help='validation list of sources for drug response')

    # Pre-processing for dataframes
    parser.add_argument('--grth_scaling', type=str, default='std',
                        help='scaling method for drug response (growth)',
                        choices=SCALING_METHODS)
    parser.add_argument('--dscptr_scaling', type=str, default='std',
                        help='scaling method for drug feature (descriptor)',
                        choices=SCALING_METHODS)
    parser.add_argument('--rnaseq_scaling', type=str, default='std',
                        help='scaling method for RNA sequence',
                        choices=SCALING_METHODS)
    parser.add_argument('--dscptr_nan_threshold', type=float, default=0.0,
                        help='ratio of NaN values allowed for drug descriptor')
    parser.add_argument('--qed_scaling', type=str, default='none',
                        help='scaling method for drug weighted QED',
                        choices=SCALING_METHODS)

    # Feature usage and partitioning settings
    parser.add_argument('--rnaseq_feature_usage', type=str, default='combat',
                        help='RNA sequence data used',
                        choices=['source_scale', 'combat', ])
    parser.add_argument('--drug_feature_usage', type=str, default='both',
                        help='drug features (fp and/or desc) used',
                        choices=['fingerprint', 'descriptor', 'both', ])
    parser.add_argument('--validation_ratio', type=float, default=0.2,
                        help='ratio for validation dataset')
    parser.add_argument('--disjoint_drugs', action='store_true',
                        help='disjoint drugs between train/validation')
    parser.add_argument('--disjoint_cells', action='store_true',
                        help='disjoint cells between train/validation')

    # Network configuration ###################################################
    # Encoders for drug features and RNA sequence (LINCS 1000)
    parser.add_argument('--gene_layer_dim', type=int, default=1024,
                        help='dimension of layers for RNA sequence')
    parser.add_argument('--gene_latent_dim', type=int, default=256,
                        help='dimension of latent variable for RNA sequence')
    parser.add_argument('--gene_num_layers', type=int, default=2,
                        help='number of layers for RNA sequence')
    parser.add_argument('--drug_layer_dim', type=int, default=4096,
                        help='dimension of layers for drug feature')
    parser.add_argument('--drug_latent_dim', type=int, default=1024,
                        help='dimension of latent variable for drug feature')
    parser.add_argument('--drug_num_layers', type=int, default=2,
                        help='number of layers for drug feature')

    # Using autoencoder for drug/sequence encoder initialization
    parser.add_argument('--autoencoder_init', action='store_true',
                        help='indicator of autoencoder initialization for '
                             'drug/RNA sequence feature encoder')

    # Drug response regression network
    parser.add_argument('--resp_layer_dim', type=int, default=1024,
                        help='dimension of layers for drug response block')
    parser.add_argument('--resp_num_layers_per_block', type=int, default=2,
                        help='number of layers for drug response res block')
    parser.add_argument('--resp_num_blocks', type=int, default=2,
                        help='number of residual blocks for drug response')
    parser.add_argument('--resp_num_layers', type=int, default=2,
                        help='number of layers for drug response')
    parser.add_argument('--resp_dropout', type=float, default=0.0,
                        help='dropout of residual blocks for drug response')
    parser.add_argument('--resp_activation', type=str, default='none',
                        help='activation for response prediction output',
                        choices=['sigmoid', 'tanh', 'none'])

    # Cell line classification network(s)
    parser.add_argument('--cl_clf_layer_dim', type=int, default=256,
                        help='layer dimension for cell line classification')
    parser.add_argument('--cl_clf_num_layers', type=int, default=1,
                        help='number of layers for cell line classification')

    # Drug target family classification network
    parser.add_argument('--drug_target_layer_dim', type=int, default=512,
                        help='dimension of layers for drug target prediction')
    parser.add_argument('--drug_target_num_layers', type=int, default=2,
                        help='number of layers for drug target prediction')

    # Drug weighted QED regression network
    parser.add_argument('--drug_qed_layer_dim', type=int, default=512,
                        help='dimension of layers for drug QED prediction')
    parser.add_argument('--drug_qed_num_layers', type=int, default=2,
                        help='number of layers for drug QED prediction')
    parser.add_argument('--drug_qed_activation', type=str, default='none',
                        help='activation for drug QED prediction output',
                        choices=['sigmoid', 'tanh', 'none'])

    # Training and validation parameters ######################################
    # Drug response regression training parameters
    parser.add_argument('--resp_loss_func', type=str, default='mse',
                        help='loss function for drug response regression',
                        choices=['mse', 'l1'])
    parser.add_argument('--resp_opt', type=str, default='SGD',
                        help='optimizer for drug response regression',
                        choices=['SGD', 'RMSprop', 'Adam'])
    parser.add_argument('--resp_lr', type=float, default=1e-5,
                        help='learning rate for drug response regression')

    # Drug response uncertainty quantification parameters
    parser.add_argument('--resp_uq', action='store_true',
                        help='indicator of drug response uncertainty '
                             'quantification using dropouts')
    parser.add_argument('--resp_uq_dropout', type=float, default=0.5,
                        help='dropout rate for uncertainty quantification')
    # NOTE(review): --resp_uq_length_scale is parsed but never consumed in
    # this function; confirm whether downstream code reads it.
    parser.add_argument('--resp_uq_length_scale', type=float, default=1.0,
                        help='Prior length-scale that captures our belief '
                             'over the function frequency')
    parser.add_argument('--resp_uq_num_runs', type=int, default=100,
                        help='number of predictions (runs) for uncertainty '
                             'quantification')

    # Cell line classification training parameters
    parser.add_argument('--cl_clf_opt', type=str, default='SGD',
                        help='optimizer for cell line classification',
                        choices=['SGD', 'RMSprop', 'Adam'])
    parser.add_argument('--cl_clf_lr', type=float, default=1e-3,
                        help='learning rate for cell line classification')

    # Drug target family classification training parameters
    parser.add_argument('--drug_target_opt', type=str, default='SGD',
                        help='optimizer for drug target classification '
                             'training',
                        choices=['SGD', 'RMSprop', 'Adam'])
    parser.add_argument('--drug_target_lr', type=float, default=1e-3,
                        help='learning rate for drug target classification')

    # Drug weighted QED regression training parameters
    parser.add_argument('--drug_qed_loss_func', type=str, default='mse',
                        help='loss function for drug QED regression',
                        choices=['mse', 'l1'])
    parser.add_argument('--drug_qed_opt', type=str, default='SGD',
                        help='optimizer for drug rQED regression',
                        choices=['SGD', 'RMSprop', 'Adam'])
    parser.add_argument('--drug_qed_lr', type=float, default=1e-3,
                        help='learning rate for drug QED regression')

    # Starting epoch for drug response validation
    parser.add_argument('--resp_val_start_epoch', type=int, default=0,
                        help='starting epoch for drug response validation')

    # Early stopping based on R2 score of drug response prediction
    parser.add_argument('--early_stop_patience', type=int, default=5,
                        help='patience for early stopping based on drug '
                             'response validation R2 scores ')

    # Global/shared training parameters
    parser.add_argument('--l2_regularization', type=float, default=1e-5,
                        help='L2 regularization for nn weights')
    parser.add_argument('--lr_decay_factor', type=float, default=0.95,
                        help='decay factor for learning rate')
    parser.add_argument('--trn_batch_size', type=int, default=32,
                        help='input batch size for training')
    parser.add_argument('--val_batch_size', type=int, default=256,
                        help='input batch size for validation')
    parser.add_argument('--max_num_batches', type=int, default=1000,
                        help='maximum number of batches per epoch')
    parser.add_argument('--max_num_epochs', type=int, default=100,
                        help='maximum number of epochs')

    # Validation results directory
    parser.add_argument('--val_results_dir', type=str, default=None,
                        help='directory for saved validation results. '
                             'Set to None to skip results saving')

    # Miscellaneous settings ##################################################
    parser.add_argument('--multi_gpu', action='store_true', default=False,
                        help='enables multiple GPU process')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--rand_state', type=int, default=0,
                        help='random state of numpy/sklearn/pytorch')

    args = parser.parse_args()
    print('Training Arguments:\n' + json.dumps(vars(args), indent=4))

    # Setting up random seed for reproducible and deterministic results
    seed_random_state(args.rand_state)

    # Computation device config (cuda or cpu)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Data loaders for training/validation ####################################
    dataloader_kwargs = {
        'timeout': 1,
        # FIX: was the string 'True', which only worked by truthiness;
        # DataLoader's shuffle parameter is a bool.
        'shuffle': True,
        # 'num_workers': multiprocessing.cpu_count() if use_cuda else 0,
        'num_workers': NUM_WORKER if use_cuda else 0,
        'pin_memory': True if use_cuda else False, }

    # Drug response dataloaders for training/validation
    drug_resp_dataset_kwargs = {
        'data_root': DATA_ROOT,
        'rand_state': args.rand_state,
        'summary': False,

        'int_dtype': np.int8,
        'float_dtype': np.float16,
        'output_dtype': np.float32,

        'grth_scaling': args.grth_scaling,
        'dscptr_scaling': args.dscptr_scaling,
        'rnaseq_scaling': args.rnaseq_scaling,
        'dscptr_nan_threshold': args.dscptr_nan_threshold,

        'rnaseq_feature_usage': args.rnaseq_feature_usage,
        'drug_feature_usage': args.drug_feature_usage,
        'validation_ratio': args.validation_ratio,
        'disjoint_drugs': args.disjoint_drugs,
        'disjoint_cells': args.disjoint_cells, }

    drug_resp_trn_loader = torch.utils.data.DataLoader(
        DrugRespDataset(data_src=args.trn_src, training=True,
                        **drug_resp_dataset_kwargs),
        batch_size=args.trn_batch_size,
        **dataloader_kwargs)

    # List of data loaders for different validation sets
    drug_resp_val_loaders = [
        torch.utils.data.DataLoader(
            DrugRespDataset(data_src=src, training=False,
                            **drug_resp_dataset_kwargs),
            batch_size=args.val_batch_size,
            **dataloader_kwargs) for src in args.val_srcs]

    # Cell line classification dataloaders for training/validation
    cl_clf_dataset_kwargs = {
        'data_root': DATA_ROOT,
        'rand_state': args.rand_state,
        'summary': False,

        'int_dtype': np.int8,
        'float_dtype': np.float16,
        'output_dtype': np.float32,

        'rnaseq_scaling': args.rnaseq_scaling,

        'rnaseq_feature_usage': args.rnaseq_feature_usage,
        'validation_ratio': args.validation_ratio, }

    cl_clf_trn_loader = torch.utils.data.DataLoader(
        CLClassDataset(training=True, **cl_clf_dataset_kwargs),
        batch_size=args.trn_batch_size, **dataloader_kwargs)
    cl_clf_val_loader = torch.utils.data.DataLoader(
        CLClassDataset(training=False, **cl_clf_dataset_kwargs),
        batch_size=args.val_batch_size, **dataloader_kwargs)

    # Drug target family classification dataloaders for training/validation
    drug_target_dataset_kwargs = {
        'data_root': DATA_ROOT,
        'rand_state': args.rand_state,
        'summary': False,

        'int_dtype': np.int8,
        'float_dtype': np.float16,
        'output_dtype': np.float32,

        'dscptr_scaling': args.dscptr_scaling,
        'dscptr_nan_threshold': args.dscptr_nan_threshold,

        'drug_feature_usage': args.drug_feature_usage,
        'validation_ratio': args.validation_ratio, }

    drug_target_trn_loader = torch.utils.data.DataLoader(
        DrugTargetDataset(training=True, **drug_target_dataset_kwargs),
        batch_size=args.trn_batch_size, **dataloader_kwargs)
    drug_target_val_loader = torch.utils.data.DataLoader(
        DrugTargetDataset(training=False, **drug_target_dataset_kwargs),
        batch_size=args.val_batch_size, **dataloader_kwargs)

    # Drug weighted QED regression dataloaders for training/validation
    drug_qed_dataset_kwargs = {
        'data_root': DATA_ROOT,
        'rand_state': args.rand_state,
        'summary': False,

        'int_dtype': np.int8,
        'float_dtype': np.float16,
        'output_dtype': np.float32,

        'qed_scaling': args.qed_scaling,
        'dscptr_scaling': args.dscptr_scaling,
        'dscptr_nan_threshold': args.dscptr_nan_threshold,

        'drug_feature_usage': args.drug_feature_usage,
        'validation_ratio': args.validation_ratio, }

    drug_qed_trn_loader = torch.utils.data.DataLoader(
        DrugQEDDataset(training=True, **drug_qed_dataset_kwargs),
        batch_size=args.trn_batch_size, **dataloader_kwargs)
    drug_qed_val_loader = torch.utils.data.DataLoader(
        DrugQEDDataset(training=False, **drug_qed_dataset_kwargs),
        batch_size=args.val_batch_size, **dataloader_kwargs)

    # Constructing and initializing neural networks ###########################
    # Autoencoder training hyper-parameters
    ae_training_kwarg = {
        'ae_loss_func': 'mse',
        'ae_opt': 'sgd',
        'ae_lr': 2e-1,
        'lr_decay_factor': 1.0,
        'max_num_epochs': 1000,
        'early_stop_patience': 50, }

    encoder_kwarg = {
        'model_folder': './models/',
        'data_root': DATA_ROOT,

        'autoencoder_init': args.autoencoder_init,
        'training_kwarg': ae_training_kwarg,

        'device': device,
        'verbose': True,
        'rand_state': args.rand_state, }

    # Get RNA sequence encoder
    gene_encoder = get_gene_encoder(
        rnaseq_feature_usage=args.rnaseq_feature_usage,
        rnaseq_scaling=args.rnaseq_scaling,

        layer_dim=args.gene_layer_dim,
        num_layers=args.gene_num_layers,
        latent_dim=args.gene_latent_dim,
        **encoder_kwarg)

    # Get drug feature encoder
    drug_encoder = get_drug_encoder(
        drug_feature_usage=args.drug_feature_usage,
        dscptr_scaling=args.dscptr_scaling,
        dscptr_nan_threshold=args.dscptr_nan_threshold,

        layer_dim=args.drug_layer_dim,
        num_layers=args.drug_num_layers,
        latent_dim=args.drug_latent_dim,
        **encoder_kwarg)

    # Regressor for drug response
    resp_net = RespNet(
        gene_latent_dim=args.gene_latent_dim,
        drug_latent_dim=args.drug_latent_dim,

        gene_encoder=gene_encoder,
        drug_encoder=drug_encoder,

        resp_layer_dim=args.resp_layer_dim,
        resp_num_layers_per_block=args.resp_num_layers_per_block,
        resp_num_blocks=args.resp_num_blocks,
        resp_num_layers=args.resp_num_layers,
        resp_dropout=args.resp_dropout,
        resp_activation=args.resp_activation).to(device)

    print(resp_net)

    # Sequence classifier for category, site, and type
    cl_clf_net_kwargs = {
        'encoder': gene_encoder,
        'input_dim': args.gene_latent_dim,
        'condition_dim': len(get_label_dict(DATA_ROOT, 'data_src_dict.txt')),
        'layer_dim': args.cl_clf_layer_dim,
        'num_layers': args.cl_clf_num_layers, }

    category_clf_net = ClfNet(
        num_classes=len(get_label_dict(DATA_ROOT, 'category_dict.txt')),
        **cl_clf_net_kwargs).to(device)
    site_clf_net = ClfNet(
        num_classes=len(get_label_dict(DATA_ROOT, 'site_dict.txt')),
        **cl_clf_net_kwargs).to(device)
    type_clf_net = ClfNet(
        num_classes=len(get_label_dict(DATA_ROOT, 'type_dict.txt')),
        **cl_clf_net_kwargs).to(device)

    # Classifier for drug target family prediction
    drug_target_net = ClfNet(
        encoder=drug_encoder,
        input_dim=args.drug_latent_dim,
        condition_dim=0,
        layer_dim=args.drug_target_layer_dim,
        num_layers=args.drug_target_num_layers,
        num_classes=len(get_label_dict(DATA_ROOT, 'drug_target_dict.txt'))).\
        to(device)

    # Regressor for drug weighted QED prediction
    drug_qed_net = RgsNet(
        encoder=drug_encoder,
        input_dim=args.drug_latent_dim,
        condition_dim=0,
        layer_dim=args.drug_qed_layer_dim,
        num_layers=args.drug_qed_num_layers,
        activation=args.drug_qed_activation).to(device)

    # Multi-GPU settings
    if args.multi_gpu:
        resp_net = nn.DataParallel(resp_net)
        category_clf_net = nn.DataParallel(category_clf_net)
        site_clf_net = nn.DataParallel(site_clf_net)
        type_clf_net = nn.DataParallel(type_clf_net)
        drug_target_net = nn.DataParallel(drug_target_net)
        drug_qed_net = nn.DataParallel(drug_qed_net)

    # Optimizers, learning rate decay, and miscellaneous ######################
    resp_opt = get_optimizer(opt_type=args.resp_opt,
                             networks=resp_net,
                             learning_rate=args.resp_lr,
                             l2_regularization=args.l2_regularization)
    cl_clf_opt = get_optimizer(
        opt_type=args.cl_clf_opt,
        networks=[category_clf_net, site_clf_net, type_clf_net],
        learning_rate=args.cl_clf_lr,
        l2_regularization=args.l2_regularization)
    drug_target_opt = get_optimizer(opt_type=args.drug_target_opt,
                                    networks=drug_target_net,
                                    learning_rate=args.drug_target_lr,
                                    l2_regularization=args.l2_regularization)
    drug_qed_opt = get_optimizer(opt_type=args.drug_qed_opt,
                                 networks=drug_qed_net,
                                 learning_rate=args.drug_qed_lr,
                                 l2_regularization=args.l2_regularization)

    # Exponential decay per epoch: lr * factor ** epoch
    resp_lr_decay = LambdaLR(optimizer=resp_opt,
                             lr_lambda=lambda e: args.lr_decay_factor ** e)
    cl_clf_lr_decay = LambdaLR(optimizer=cl_clf_opt,
                               lr_lambda=lambda e: args.lr_decay_factor ** e)
    drug_target_lr_decay = LambdaLR(
        optimizer=drug_target_opt,
        lr_lambda=lambda e: args.lr_decay_factor ** e)
    drug_qed_lr_decay = LambdaLR(
        optimizer=drug_qed_opt,
        lr_lambda=lambda e: args.lr_decay_factor ** e)

    resp_loss_func = F.l1_loss if args.resp_loss_func == 'l1' \
        else F.mse_loss
    drug_qed_loss_func = F.l1_loss if args.drug_qed_loss_func == 'l1' \
        else F.mse_loss

    # Training/validation loops ###############################################
    val_cl_clf_acc = []
    val_drug_target_acc = []
    val_drug_qed_mse, val_drug_qed_mae, val_drug_qed_r2 = [], [], []
    val_resp_mse, val_resp_mae, val_resp_r2 = [], [], []

    best_r2 = -np.inf
    patience = 0
    start_time = time.time()

    # Create folder for validation results if not exist.
    # FIX: the argument defaults to None, so calling .lower() directly
    # raised AttributeError; guard against the None case first.
    if args.val_results_dir is not None \
            and args.val_results_dir.lower() != 'none':
        try:
            os.makedirs(args.val_results_dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
    else:
        args.val_results_dir = None

    # Early stopping is decided on the validation set with the same
    # data source as the training dataloader
    val_index = 0
    for idx, loader in enumerate(drug_resp_val_loaders):
        if loader.dataset.data_source == args.trn_src:
            val_index = idx

    for epoch in range(args.max_num_epochs):
        print('=' * 80 + '\nTraining Epoch %3i:' % (epoch + 1))
        epoch_start_time = time.time()

        # NOTE(review): scheduler.step(epoch) before the optimizer steps
        # is the legacy (pre-1.1) PyTorch idiom; kept as-is to preserve
        # the exact learning-rate schedule.
        resp_lr_decay.step(epoch)
        cl_clf_lr_decay.step(epoch)
        drug_target_lr_decay.step(epoch)
        drug_qed_lr_decay.step(epoch)

        # Training cell line classifier
        train_cl_clf(device=device,
                     category_clf_net=category_clf_net,
                     site_clf_net=site_clf_net,
                     type_clf_net=type_clf_net,
                     data_loader=cl_clf_trn_loader,
                     max_num_batches=args.max_num_batches,
                     optimizer=cl_clf_opt)

        # Training drug target classifier
        train_drug_target(device=device,
                          drug_target_net=drug_target_net,
                          data_loader=drug_target_trn_loader,
                          max_num_batches=args.max_num_batches,
                          optimizer=drug_target_opt)

        # Training drug weighted QED regressor
        train_drug_qed(device=device,
                       drug_qed_net=drug_qed_net,
                       data_loader=drug_qed_trn_loader,
                       max_num_batches=args.max_num_batches,
                       loss_func=drug_qed_loss_func,
                       optimizer=drug_qed_opt)

        # Training drug response regressor
        train_resp(device=device,
                   resp_net=resp_net,
                   data_loader=drug_resp_trn_loader,
                   max_num_batches=args.max_num_batches,
                   loss_func=resp_loss_func,
                   optimizer=resp_opt)

        print('\nValidation Results:')

        if epoch >= args.resp_val_start_epoch:

            # Validating cell line classifier
            cl_category_acc, cl_site_acc, cl_type_acc = \
                valid_cl_clf(device=device,
                             category_clf_net=category_clf_net,
                             site_clf_net=site_clf_net,
                             type_clf_net=type_clf_net,
                             data_loader=cl_clf_val_loader, )
            val_cl_clf_acc.append([cl_category_acc, cl_site_acc, cl_type_acc])

            # Validating drug target classifier
            drug_target_acc = \
                valid_drug_target(device=device,
                                  drug_target_net=drug_target_net,
                                  data_loader=drug_target_val_loader)
            val_drug_target_acc.append(drug_target_acc)

            # Validating drug weighted QED regressor
            drug_qed_mse, drug_qed_mae, drug_qed_r2 = \
                valid_drug_qed(device=device,
                               drug_qed_net=drug_qed_net,
                               data_loader=drug_qed_val_loader)
            val_drug_qed_mse.append(drug_qed_mse)
            val_drug_qed_mae.append(drug_qed_mae)
            val_drug_qed_r2.append(drug_qed_r2)

            # Validating drug response regressor
            resp_mse, resp_mae, resp_r2 = \
                valid_resp(epoch=epoch,
                           trn_src=args.trn_src,
                           device=device,
                           resp_net=resp_net,
                           data_loaders=drug_resp_val_loaders,
                           resp_uq=args.resp_uq,
                           resp_uq_dropout=args.resp_uq_dropout,
                           resp_uq_num_runs=args.resp_uq_num_runs,
                           val_results_dir=args.val_results_dir)

            # Save the validation results in nested list
            val_resp_mse.append(resp_mse)
            val_resp_mae.append(resp_mae)
            val_resp_r2.append(resp_r2)

            # Record the best R2 score (same data source)
            # and check for early stopping if no improvement for epochs
            if resp_r2[val_index] > best_r2:
                patience = 0
                best_r2 = resp_r2[val_index]
            else:
                patience += 1
                if patience >= args.early_stop_patience:
                    print('Validation results does not improve for %d '
                          'epochs ... invoking early stopping.' % patience)
                    break

        print('Epoch Running Time: %.1f Seconds.'
              % (time.time() - epoch_start_time))

    val_cl_clf_acc = np.array(val_cl_clf_acc).reshape(-1, 3)
    # val_drug_target_acc = np.array(val_drug_target_acc)
    # val_drug_qed_mse = np.array(val_drug_qed_mse)
    # val_resp_mae = np.array(val_resp_mae)
    # val_resp_r2 = np.array(val_resp_r2)
    val_resp_mse, val_resp_mae, val_resp_r2 = \
        np.array(val_resp_mse).reshape(-1, len(args.val_srcs)), \
        np.array(val_resp_mae).reshape(-1, len(args.val_srcs)), \
        np.array(val_resp_r2).reshape(-1, len(args.val_srcs))

    print('Program Running Time: %.1f Seconds.' % (time.time() - start_time))

    # Print overall validation results
    print('=' * 80)
    print('Overall Validation Results:\n')
    print('\tBest Results from Different Models (Epochs):')

    # Print best accuracy for cell line classifiers
    clf_targets = ['Cell Line Categories',
                   'Cell Line Sites',
                   'Cell Line Types', ]
    best_acc = np.amax(val_cl_clf_acc, axis=0)
    best_acc_epochs = np.argmax(val_cl_clf_acc, axis=0)

    for index, clf_target in enumerate(clf_targets):
        print('\t\t%-24s Best Accuracy: %.3f%% (Epoch = %3d)'
              % (clf_target, best_acc[index],
                 best_acc_epochs[index] + 1 + args.resp_val_start_epoch))

    # Print best predictions for drug classifiers and regressor
    print('\t\tDrug Target Family \t Best Accuracy: %.3f%% (Epoch = %3d)'
          % (np.max(val_drug_target_acc),
             (np.argmax(val_drug_target_acc) + 1
              + args.resp_val_start_epoch)))
    print('\t\tDrug Weighted QED \t Best R2 Score: %+6.4f '
          '(Epoch = %3d, MSE = %8.6f, MAE = %8.6f)'
          % (np.max(val_drug_qed_r2),
             (np.argmax(val_drug_qed_r2) + 1 + args.resp_val_start_epoch),
             val_drug_qed_mse[np.argmax(val_drug_qed_r2)],
             val_drug_qed_mae[np.argmax(val_drug_qed_r2)]))

    # Print best R2 scores for drug response regressor
    val_data_sources = \
        [loader.dataset.data_source for loader in drug_resp_val_loaders]
    best_r2 = np.amax(val_resp_r2, axis=0)
    best_r2_epochs = np.argmax(val_resp_r2, axis=0)

    for index, data_source in enumerate(val_data_sources):
        print('\t\t%-6s \t Best R2 Score: %+6.4f '
              '(Epoch = %3d, MSE = %8.2f, MAE = %6.2f)'
              % (data_source, best_r2[index],
                 best_r2_epochs[index] + args.resp_val_start_epoch + 1,
                 val_resp_mse[best_r2_epochs[index], index],
                 val_resp_mae[best_r2_epochs[index], index]))

    # Print best epoch and all the corresponding validation results
    # Picking the best epoch using R2 score from same data source
    best_epoch = val_resp_r2[:, val_index].argmax()

    print('\n\tBest Results from the Same Model (Epoch = %3d):'
          % (best_epoch + 1 + args.resp_val_start_epoch))
    for index, clf_target in enumerate(clf_targets):
        print('\t\t%-24s Accuracy: %.3f%%'
              % (clf_target, val_cl_clf_acc[best_epoch, index]))

    # Print best predictions for drug classifiers and regressor
    print('\t\tDrug Target Family \t Accuracy: %.3f%% '
          % (val_drug_target_acc[best_epoch]))
    print('\t\tDrug Weighted QED \t R2 Score: %+6.4f '
          '(MSE = %8.6f, MAE = %6.6f)'
          % (val_drug_qed_r2[best_epoch],
             val_drug_qed_mse[best_epoch],
             val_drug_qed_mae[best_epoch]))
    for index, data_source in enumerate(val_data_sources):
        print('\t\t%-6s \t R2 Score: %+6.4f '
              '(MSE = %8.2f, MAE = %6.2f)'
              % (data_source, val_resp_r2[best_epoch, index],
                 val_resp_mse[best_epoch, index],
                 val_resp_mae[best_epoch, index]))
def __init__(
        self,
        data_root: str,
        training: bool,
        rand_state: int = 0,
        summary: bool = True,

        # Data type settings (for storage and data loading)
        int_dtype: type = np.int8,
        float_dtype: type = np.float16,
        output_dtype: type = np.float32,

        # Pre-processing settings
        rnaseq_scaling: str = 'std',
        predict_target: str = 'class',

        # Partitioning (train/validation) and data usage settings
        rnaseq_feature_usage: str = 'source_scale',
        validation_ratio: float = 0.2, ) -> None:
    """dataset = CLClassDataset('./data/', True)

    Construct a RNA sequence dataset based on the parameters provided.
    The process includes:
        * Downloading source data files;
        * Pre-processing (scaling);
        * Public attributes and other preparations.

    Args:
        data_root (str): path to data root folder.
        training (bool): indicator for training.
        rand_state (int): random seed used for training/validation split
            and other processes that requires randomness.
        summary (bool): set True for printing dataset summary.

        int_dtype (type): integer dtype for data storage in RAM.
        float_dtype (type): float dtype for data storage in RAM.
        output_dtype (type): output dtype for neural network.

        rnaseq_scaling (str): scaling method for RNA sequence. Choose
            between 'none', 'std', and 'minmax'.
        predict_target (str): prediction target for RNA sequence. Note
            that any labels except for target will be in one-hot
            encoding, while the target will be encoded as integers.
            Choose between 'none', 'class', and 'source'.

        rnaseq_feature_usage: RNA sequence data usage. Choose between
            'source_scale' and 'combat'.
        validation_ratio (float): portion of validation data out of all
            data samples.
    """
    # Initialization ######################################################
    self.__data_root = data_root

    # Class-wise variables
    self.training = training
    self.__rand_state = rand_state
    self.__output_dtype = output_dtype

    # Feature scaling; normalize None/'' to the sentinel 'none'
    if rnaseq_scaling is None or rnaseq_scaling == '':
        rnaseq_scaling = 'none'
    self.__rnaseq_scaling = rnaseq_scaling.lower()

    # Normalize the prediction target the same way, then validate it
    if predict_target is None or predict_target == '':
        predict_target = 'none'
    assert predict_target.lower() in ['none', 'class', 'source']
    self.__predict_target = predict_target.lower()

    self.__rnaseq_feature_usage = rnaseq_feature_usage
    self.__validation_ratio = validation_ratio

    # Load all dataframes #################################################
    self.__rnaseq_df = get_rna_seq_df(
        data_root=data_root,
        rnaseq_feature_usage=rnaseq_feature_usage,
        rnaseq_scaling=rnaseq_scaling,
        float_dtype=float_dtype)

    self.__cl_meta_df = get_cl_meta_df(data_root=data_root,
                                       int_dtype=int_dtype)

    # Put all the sequence in one column as list and specify dtype
    self.__rnaseq_df['seq'] = \
        list(map(float_dtype, self.__rnaseq_df.values.tolist()))

    # Join the RNA sequence data with meta data. cl_df will have columns:
    # ['data_src', 'site', 'type', 'category', 'seq']
    self.__cl_df = pd.concat(
        [self.__cl_meta_df, self.__rnaseq_df[['seq']]],
        axis=1, join='inner')

    # Exclude 'GDC' and 'NCI60' during data source prediction
    # GDC has too many samples while NCI60 has not enough
    # NOTE(review): 2 and 5 are presumably the encoded labels of GDC and
    # NCI60 from data_src_dict.txt — confirm against that dictionary.
    if self.__predict_target == 'source':
        logger.warning('Taking out GDC and NCI60 samples to make dataset '
                       'balanced among all data sources ...')
        self.__cl_df = self.__cl_df[~self.__cl_df['data_src'].isin([2, 5])]

    # Encode labels (except for prediction targets) into one-hot encoding
    if self.__predict_target != 'source':
        enc_data_src = encode_int_to_onehot(
            self.__cl_df['data_src'].tolist(),
            len(get_label_dict(data_root, 'data_src_dict.txt')))
        self.__cl_df['data_src'] = list(map(int_dtype, enc_data_src))

    if self.__predict_target != 'class':
        for label in ['site', 'type', 'category']:
            enc_label = encode_int_to_onehot(
                self.__cl_df[label].tolist(),
                len(get_label_dict(data_root, '%s_dict.txt' % label)))
            self.__cl_df[label] = list(map(int_dtype, enc_label))

    # Train/validation split ##############################################
    self.__split_drug_resp()

    # Converting dataframes to arrays for rapid access ####################
    self.__cl_array = self.__cl_df.values

    # Public attributes ###################################################
    self.cells = self.__cl_df.index.tolist()
    self.num_cells = self.__cl_df.shape[0]
    self.rnaseq_dim = len(self.__cl_df.iloc[0]['seq'])

    # Clear the dataframes to free RAM; only __cl_array is kept ###########
    self.__rnaseq_df = None
    self.__cl_meta_df = None
    self.__cl_df = None

    # Dataset summary #####################################################
    if summary:
        print('=' * 80)
        print(('Training' if self.training else 'Validation')
              + ' RNA Sequence Dataset Summary:')
        print('\t%i Unique Cell Lines (feature dim: %4i).'
              % (self.num_cells, self.rnaseq_dim))
        print('=' * 80)
def main():
    """Train and validate a cell-line data source (batch) classifier.

    Parses command-line hyper-parameters, builds training/validation
    dataloaders over RNA sequence data, constructs a dense feed-forward
    network (RNA-seq features concatenated with one-hot site/type/category
    labels -> data source log-probabilities), and runs the training loop
    with per-epoch validation, exponential learning rate decay, and early
    stopping on validation accuracy.
    """
    # Training settings and hyper-parameters
    parser = argparse.ArgumentParser(
        description='Data Source (Batch) Prediction for Cell Lines')

    # Dataset parameters ######################################################
    # Pre-processing for dataframes
    parser.add_argument('--rnaseq_scaling', type=str, default='std',
                        help='scaling method for RNA sequence',
                        choices=SCALING_METHODS)

    # Feature usage and partitioning settings
    parser.add_argument('--rnaseq_feature_usage', type=str, default='combat',
                        help='RNA sequence data used',
                        choices=['source_scale', 'combat', ])
    parser.add_argument('--validation_ratio', type=float, default=0.2,
                        help='ratio for validation dataset')

    # Network configuration ###################################################
    parser.add_argument('--layer_dim', type=int, default=256,
                        help='dimension of layers for RNA sequence')
    parser.add_argument('--num_layers', type=int, default=4,
                        help='number of layers for RNA sequence')

    # Training and validation parameters ######################################
    parser.add_argument('--opt', type=str, default='SGD',
                        help='optimizer for data source prediction',
                        choices=['SGD', 'RMSprop', 'Adam'])
    parser.add_argument('--lr', type=float, default=1e-2,
                        help='learning rate for data source prediction')

    # Starting epoch for validation
    # NOTE(review): this argument is parsed but never read below — confirm
    # whether validation was meant to be skipped until this epoch.
    parser.add_argument('--val_start_epoch', type=int, default=0,
                        help='starting epoch for data source prediction')

    # Early stopping based on data source prediction accuracy
    parser.add_argument('--early_stop_patience', type=int, default=50,
                        help='patience for early stopping based on data '
                             'source prediction accuracy')

    # Global/shared training parameters
    parser.add_argument('--l2_regularization', type=float, default=0.,
                        help='L2 regularization for nn weights')
    parser.add_argument('--lr_decay_factor', type=float, default=0.98,
                        help='decay factor for learning rate')
    parser.add_argument('--trn_batch_size', type=int, default=32,
                        help='input batch size for training')
    parser.add_argument('--val_batch_size', type=int, default=256,
                        help='input batch size for validation')
    parser.add_argument('--max_num_batches', type=int, default=10000,
                        help='maximum number of batches per epoch')
    parser.add_argument('--max_num_epochs', type=int, default=1000,
                        help='maximum number of epochs')

    # Miscellaneous settings ##################################################
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--rand_state', type=int, default=0,
                        help='random state of numpy/sklearn/pytorch')

    args = parser.parse_args()
    print('Training Arguments:\n' + json.dumps(vars(args), indent=4))

    # Setting up random seed for reproducible and deterministic results
    seed_random_state(args.rand_state)

    # Computation device config (cuda or cpu)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Data loaders for training/validation ####################################
    dataloader_kwargs = {
        'timeout': 1,
        # Fixed: was the string 'True', which only worked because any
        # non-empty string is truthy; DataLoader expects a real boolean.
        'shuffle': True,
        # 'num_workers': multiprocessing.cpu_count() if use_cuda else 0,
        'num_workers': NUM_WORKER if use_cuda else 0,
        'pin_memory': use_cuda,
    }

    # Drug response dataloaders for training/validation
    cl_clf_dataset_kwargs = {
        'data_root': DATA_ROOT,
        'rand_state': args.rand_state,
        'summary': False,

        'int_dtype': np.int8,
        'float_dtype': np.float16,
        'output_dtype': np.float32,

        'rnaseq_scaling': args.rnaseq_scaling,

        'predict_target': 'source',
        'rnaseq_feature_usage': args.rnaseq_feature_usage,
        'validation_ratio': args.validation_ratio,
    }

    cl_clf_trn_loader = torch.utils.data.DataLoader(
        CLClassDataset(training=True, **cl_clf_dataset_kwargs),
        batch_size=args.trn_batch_size, **dataloader_kwargs)
    cl_clf_val_loader = torch.utils.data.DataLoader(
        CLClassDataset(training=False, **cl_clf_dataset_kwargs),
        batch_size=args.val_batch_size, **dataloader_kwargs)

    # Constructing and initializing neural networks ###########################
    # Input dimension = RNA-seq feature dim + one-hot dims of the
    # site/type/category condition labels that get concatenated below.
    net = nn.Sequential()
    prev_dim = cl_clf_trn_loader.dataset.rnaseq_dim
    for label in ['site', 'type', 'category']:
        prev_dim += len(get_label_dict(DATA_ROOT, '%s_dict.txt' % label))

    for i in range(args.num_layers):
        net.add_module('dense_%d' % i, nn.Linear(prev_dim, args.layer_dim))
        net.add_module('dropout_%d' % i, nn.Dropout(0.2))
        prev_dim = args.layer_dim
        net.add_module('relu_%d' % i, nn.ReLU())

    # Classification head over all data sources; LogSoftmax pairs with the
    # F.nll_loss used in the training loop.
    num_data_src = len(get_label_dict(DATA_ROOT, 'data_src_dict.txt'))
    net.add_module('dense', nn.Linear(args.layer_dim, num_data_src))
    net.add_module('logsoftmax', nn.LogSoftmax(dim=1))

    net.apply(basic_weight_init)
    net.to(device)
    print(net)

    # Optimizers, learning rate decay, and miscellaneous ######################
    opt = get_optimizer(opt_type=args.opt,
                        networks=net,
                        learning_rate=args.lr,
                        l2_regularization=args.l2_regularization)
    # lr(epoch e) = lr * lr_decay_factor ** e
    lr_decay = LambdaLR(optimizer=opt,
                        lr_lambda=lambda e: args.lr_decay_factor ** e)

    # Training/validation loops ###############################################
    val_acc = []
    best_acc = 0.
    patience = 0
    start_time = time.time()

    for epoch in range(args.max_num_epochs):
        print('=' * 80 + '\nTraining Epoch %3i:' % (epoch + 1))
        epoch_start_time = time.time()

        # Training loop #######################################################
        net.train()
        for batch_idx, (rnaseq, data_src, cl_site, cl_type, cl_category) \
                in enumerate(cl_clf_trn_loader):

            if batch_idx >= args.max_num_batches:
                break

            rnaseq, data_src, cl_site, cl_type, cl_category = \
                rnaseq.to(device), data_src.to(device), cl_site.to(device), \
                cl_type.to(device), cl_category.to(device)

            net.zero_grad()
            out_data_src = net(
                torch.cat((rnaseq, cl_site, cl_type, cl_category), dim=1))
            F.nll_loss(input=out_data_src, target=data_src).backward()
            opt.step()

        # Fixed: decay the learning rate once per epoch AFTER the optimizer
        # updates, without the deprecated epoch argument. The original
        # lr_decay.step(epoch) at the top of the loop stepped the scheduler
        # before any opt.step() (PyTorch warns about this ordering); the
        # resulting schedule base_lr * factor**epoch is unchanged.
        lr_decay.step()

        # Validation loop #####################################################
        net.eval()
        correct_data_src = 0
        with torch.no_grad():
            for rnaseq, data_src, cl_site, cl_type, cl_category \
                    in cl_clf_val_loader:

                rnaseq, data_src, cl_site, cl_type, cl_category = \
                    rnaseq.to(device), data_src.to(device), \
                    cl_site.to(device), cl_type.to(device), \
                    cl_category.to(device)

                out_data_src = net(
                    torch.cat((rnaseq, cl_site, cl_type, cl_category),
                              dim=1))
                pred_data_src = out_data_src.max(1, keepdim=True)[1]
                correct_data_src += pred_data_src.eq(
                    data_src.view_as(pred_data_src)).sum().item()

        data_src_acc = \
            100. * correct_data_src / len(cl_clf_val_loader.dataset)
        print(
            '\tCell Line Data Source (Batch) Prediction Accuracy: %5.2f%%; '
            % data_src_acc)

        # Results recording and early stopping ################################
        val_acc.append(data_src_acc)
        if data_src_acc > best_acc:
            patience = 0
            best_acc = data_src_acc
        else:
            patience += 1
            if patience >= args.early_stop_patience:
                print('Validation accuracy does not improve for %d epochs '
                      '... invoking early stopping.' % patience)
                break

        print('Epoch Running Time: %.1f Seconds.'
              % (time.time() - epoch_start_time))

    print('Program Running Time: %.1f Seconds.'
          % (time.time() - start_time))
    print('Best Cell Line Data Source (Batch) Prediction Accuracy: %5.2f%%; '
          % np.amax(val_acc))
def build_nn(self):
    """Construct every network used by this trainer and move them to
    ``self.device``: the gene/drug encoders, the drug response regressor,
    three cell-line classifiers (category/site/type), a drug target family
    classifier, and a drug weighted-QED regressor. All models are stored as
    attributes on ``self``.
    """
    args = self.args
    device = self.device

    # Constructing and initializing neural networks ###########################
    # Autoencoder training hyper-parameters (passed to both encoders via
    # encoder_kwarg below)
    self.ae_training_kwarg = {
        'ae_loss_func': 'mse',
        'ae_opt': 'sgd',
        'ae_lr': 2e-1,
        'ae_reg': 1e-5,
        'lr_decay_factor': 1.0,
        'max_num_epochs': 5,
        'early_stop_patience': 50,
    }

    # Keyword arguments shared by the gene and drug encoder constructors
    self.encoder_kwarg = {
        'model_folder': './models/',
        'data_root': DATA_ROOT,

        'autoencoder_init': args.autoencoder_init,
        'training_kwarg': self.ae_training_kwarg,

        'device': device,
        'verbose': True,
        'rand_state': args.rng_seed,
    }

    # Get RNA sequence encoder
    self.gene_encoder = get_gene_encoder(
        rnaseq_feature_usage=args.rnaseq_feature_usage,
        rnaseq_scaling=args.rnaseq_scaling,

        layer_dim=args.gene_layer_dim,
        num_layers=args.gene_num_layers,
        latent_dim=args.gene_latent_dim,
        **(self.encoder_kwarg))

    # Get drug feature encoder
    self.drug_encoder = get_drug_encoder(
        drug_feature_usage=args.drug_feature_usage,
        dscptr_scaling=args.dscptr_scaling,
        dscptr_nan_threshold=args.dscptr_nan_threshold,

        layer_dim=args.drug_layer_dim,
        num_layers=args.drug_num_layers,
        latent_dim=args.drug_latent_dim,
        **(self.encoder_kwarg))

    # Regressor for drug response, built on top of both encoders' latents
    self.resp_net = RespNet(
        gene_latent_dim=args.gene_latent_dim,
        drug_latent_dim=args.drug_latent_dim,

        gene_encoder=self.gene_encoder,
        drug_encoder=self.drug_encoder,

        resp_layer_dim=args.resp_layer_dim,
        resp_num_layers_per_block=args.resp_num_layers_per_block,
        resp_num_blocks=args.resp_num_blocks,
        resp_num_layers=args.resp_num_layers,
        resp_dropout=args.dropout,
        resp_activation=args.resp_activation).to(device)

    print(self.resp_net)

    # Sequence classifier for category, site, and type; all three share the
    # gene encoder and are conditioned on the (one-hot) data source
    self.cl_clf_net_kwargs = {
        'encoder': self.gene_encoder,
        'input_dim': args.gene_latent_dim,
        'condition_dim': len(get_label_dict(DATA_ROOT, 'data_src_dict.txt')),
        'layer_dim': args.cl_clf_layer_dim,
        'num_layers': args.cl_clf_num_layers,
    }

    self.category_clf_net = ClfNet(
        num_classes=len(get_label_dict(DATA_ROOT, 'category_dict.txt')),
        **(self.cl_clf_net_kwargs)).to(device)
    self.site_clf_net = ClfNet(
        num_classes=len(get_label_dict(DATA_ROOT, 'site_dict.txt')),
        **(self.cl_clf_net_kwargs)).to(device)
    self.type_clf_net = ClfNet(
        num_classes=len(get_label_dict(DATA_ROOT, 'type_dict.txt')),
        **(self.cl_clf_net_kwargs)).to(device)

    # Classifier for drug target family prediction (unconditioned:
    # condition_dim=0)
    self.drug_target_net = ClfNet(
        encoder=self.drug_encoder,
        input_dim=args.drug_latent_dim,
        condition_dim=0,
        layer_dim=args.drug_target_layer_dim,
        num_layers=args.drug_target_num_layers,
        num_classes=len(get_label_dict(DATA_ROOT,
                                       'drug_target_dict.txt'))).to(device)

    # Regressor for drug weighted QED prediction (unconditioned)
    self.drug_qed_net = RgsNet(
        encoder=self.drug_encoder,
        input_dim=args.drug_latent_dim,
        condition_dim=0,
        layer_dim=args.drug_qed_layer_dim,
        num_layers=args.drug_qed_num_layers,
        activation=args.drug_qed_activation).to(device)

    # Multi-GPU settings
    # NOTE(review): the DataParallel wrappers are bound to LOCAL names
    # (resp_net, category_clf_net, ...), not back to the self.* attributes;
    # if the method does not use these locals further on, the wrapped
    # models are discarded and training proceeds single-GPU — confirm
    # against the remainder of this method / caller.
    if args.multi_gpu:
        resp_net = nn.DataParallel(self.resp_net)
        category_clf_net = nn.DataParallel(self.category_clf_net)
        site_clf_net = nn.DataParallel(self.site_clf_net)
        type_clf_net = nn.DataParallel(self.type_clf_net)
        drug_target_net = nn.DataParallel(self.drug_target_net)
        drug_qed_net = nn.DataParallel(self.drug_qed_net)