def main():
    """Train and validate the multitasking network for genes and drugs.

    Parses the command-line arguments, builds the training/validation data
    loaders, constructs the encoders and the task networks (drug response
    regression, cell line classification, drug target family classification
    and drug weighted QED regression), then runs the training/validation
    loop and prints an overall summary of the validation results.

    Early stopping is driven by the drug response R2 score on the validation
    set whose data source equals ``--trn_src`` (falling back to the first
    validation set when none matches).

    Raises:
        OSError: if the validation results directory cannot be created for
            a reason other than it already existing.
    """
    # Training settings and hyper-parameters
    parser = argparse.ArgumentParser(
        description='Multitasking Neural Network for Genes and Drugs')

    # Dataset parameters ######################################################
    # Training and validation data sources
    parser.add_argument('--trn_src', type=str, required=True,
                        help='training source for drug response')
    parser.add_argument('--val_srcs', type=str, required=True, nargs='+',
                        help='validation list of sources for drug response')

    # Pre-processing for dataframes
    parser.add_argument('--grth_scaling', type=str, default='std',
                        help='scaling method for drug response (growth)',
                        choices=SCALING_METHODS)
    parser.add_argument('--dscptr_scaling', type=str, default='std',
                        help='scaling method for drug feature (descriptor)',
                        choices=SCALING_METHODS)
    parser.add_argument('--rnaseq_scaling', type=str, default='std',
                        help='scaling method for RNA sequence',
                        choices=SCALING_METHODS)
    parser.add_argument('--dscptr_nan_threshold', type=float, default=0.0,
                        help='ratio of NaN values allowed for drug descriptor')
    parser.add_argument('--qed_scaling', type=str, default='none',
                        help='scaling method for drug weighted QED',
                        choices=SCALING_METHODS)

    # Feature usage and partitioning settings
    parser.add_argument('--rnaseq_feature_usage', type=str, default='combat',
                        help='RNA sequence data used',
                        choices=[
                            'source_scale',
                            'combat',
                        ])
    parser.add_argument('--drug_feature_usage', type=str, default='both',
                        help='drug features (fp and/or desc) used',
                        choices=[
                            'fingerprint',
                            'descriptor',
                            'both',
                        ])
    parser.add_argument('--validation_ratio', type=float, default=0.2,
                        help='ratio for validation dataset')
    parser.add_argument('--disjoint_drugs', action='store_true',
                        help='disjoint drugs between train/validation')
    parser.add_argument('--disjoint_cells', action='store_true',
                        help='disjoint cells between train/validation')

    # Network configuration ###################################################
    # Encoders for drug features and RNA sequence (LINCS 1000)
    parser.add_argument('--gene_layer_dim', type=int, default=1024,
                        help='dimension of layers for RNA sequence')
    parser.add_argument('--gene_latent_dim', type=int, default=256,
                        help='dimension of latent variable for RNA sequence')
    parser.add_argument('--gene_num_layers', type=int, default=2,
                        help='number of layers for RNA sequence')

    parser.add_argument('--drug_layer_dim', type=int, default=4096,
                        help='dimension of layers for drug feature')
    parser.add_argument('--drug_latent_dim', type=int, default=1024,
                        help='dimension of latent variable for drug feature')
    parser.add_argument('--drug_num_layers', type=int, default=2,
                        help='number of layers for drug feature')

    # Using autoencoder for drug/sequence encoder initialization
    parser.add_argument('--autoencoder_init', action='store_true',
                        help='indicator of autoencoder initialization for '
                             'drug/RNA sequence feature encoder')

    # Drug response regression network
    parser.add_argument('--resp_layer_dim', type=int, default=1024,
                        help='dimension of layers for drug response block')
    parser.add_argument('--resp_num_layers_per_block', type=int, default=2,
                        help='number of layers for drug response res block')
    parser.add_argument('--resp_num_blocks', type=int, default=2,
                        help='number of residual blocks for drug response')
    parser.add_argument('--resp_num_layers', type=int, default=2,
                        help='number of layers for drug response')
    parser.add_argument('--resp_dropout', type=float, default=0.0,
                        help='dropout of residual blocks for drug response')
    parser.add_argument('--resp_activation', type=str, default='none',
                        help='activation for response prediction output',
                        choices=['sigmoid', 'tanh', 'none'])

    # Cell line classification network(s)
    parser.add_argument('--cl_clf_layer_dim', type=int, default=256,
                        help='layer dimension for cell line classification')
    parser.add_argument('--cl_clf_num_layers', type=int, default=1,
                        help='number of layers for cell line classification')

    # Drug target family classification network
    parser.add_argument('--drug_target_layer_dim', type=int, default=512,
                        help='dimension of layers for drug target prediction')
    parser.add_argument('--drug_target_num_layers', type=int, default=2,
                        help='number of layers for drug target prediction')

    # Drug weighted QED regression network
    parser.add_argument('--drug_qed_layer_dim', type=int, default=512,
                        help='dimension of layers for drug QED prediction')
    parser.add_argument('--drug_qed_num_layers', type=int, default=2,
                        help='number of layers for drug QED prediction')
    parser.add_argument('--drug_qed_activation', type=str, default='none',
                        help='activation for drug QED prediction output',
                        choices=['sigmoid', 'tanh', 'none'])

    # Training and validation parameters ######################################
    # Drug response regression training parameters
    parser.add_argument('--resp_loss_func', type=str, default='mse',
                        help='loss function for drug response regression',
                        choices=['mse', 'l1'])
    parser.add_argument('--resp_opt', type=str, default='SGD',
                        help='optimizer for drug response regression',
                        choices=['SGD', 'RMSprop', 'Adam'])
    parser.add_argument('--resp_lr', type=float, default=1e-5,
                        help='learning rate for drug response regression')

    # Drug response uncertainty quantification parameters
    parser.add_argument('--resp_uq', action='store_true',
                        help='indicator of drug response uncertainty '
                             'quantification using dropouts')
    parser.add_argument('--resp_uq_dropout', type=float, default=0.5,
                        help='dropout rate for uncertainty quantification')
    parser.add_argument('--resp_uq_length_scale', type=float, default=1.0,
                        help='Prior length-scale that captures our belief '
                             'over the function frequency')
    parser.add_argument('--resp_uq_num_runs', type=int, default=100,
                        help='number of predictions (runs) for uncertainty '
                             'quantification')

    # Cell line classification training parameters
    parser.add_argument('--cl_clf_opt', type=str, default='SGD',
                        help='optimizer for cell line classification',
                        choices=['SGD', 'RMSprop', 'Adam'])
    parser.add_argument('--cl_clf_lr', type=float, default=1e-3,
                        help='learning rate for cell line classification')

    # Drug target family classification training parameters
    parser.add_argument('--drug_target_opt', type=str, default='SGD',
                        help='optimizer for drug target classification '
                             'training',
                        choices=['SGD', 'RMSprop', 'Adam'])
    parser.add_argument('--drug_target_lr', type=float, default=1e-3,
                        help='learning rate for drug target classification')

    # Drug weighted QED regression training parameters
    parser.add_argument('--drug_qed_loss_func', type=str, default='mse',
                        help='loss function for drug QED regression',
                        choices=['mse', 'l1'])
    parser.add_argument('--drug_qed_opt', type=str, default='SGD',
                        help='optimizer for drug rQED regression',
                        choices=['SGD', 'RMSprop', 'Adam'])
    parser.add_argument('--drug_qed_lr', type=float, default=1e-3,
                        help='learning rate for drug QED regression')

    # Starting epoch for drug response validation
    parser.add_argument('--resp_val_start_epoch', type=int, default=0,
                        help='starting epoch for drug response validation')

    # Early stopping based on R2 score of drug response prediction
    parser.add_argument('--early_stop_patience', type=int, default=5,
                        help='patience for early stopping based on drug '
                             'response validation R2 scores ')

    # Global/shared training parameters
    parser.add_argument('--l2_regularization', type=float, default=1e-5,
                        help='L2 regularization for nn weights')
    parser.add_argument('--lr_decay_factor', type=float, default=0.95,
                        help='decay factor for learning rate')
    parser.add_argument('--trn_batch_size', type=int, default=32,
                        help='input batch size for training')
    parser.add_argument('--val_batch_size', type=int, default=256,
                        help='input batch size for validation')
    parser.add_argument('--max_num_batches', type=int, default=1000,
                        help='maximum number of batches per epoch')
    parser.add_argument('--max_num_epochs', type=int, default=100,
                        help='maximum number of epochs')

    # Validation results directory
    parser.add_argument('--val_results_dir', type=str, default=None,
                        help='directory for saved validation results. '
                             'Set to None to skip results saving')

    # Miscellaneous settings ##################################################
    parser.add_argument('--multi_gpu', action='store_true', default=False,
                        help='enables multiple GPU process')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--rand_state', type=int, default=0,
                        help='random state of numpy/sklearn/pytorch')

    args = parser.parse_args()
    print('Training Arguments:\n' + json.dumps(vars(args), indent=4))

    # Setting up random seed for reproducible and deterministic results
    seed_random_state(args.rand_state)

    # Computation device config (cuda or cpu)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Data loaders for training/validation ####################################
    dataloader_kwargs = {
        'timeout': 1,
        # Must be a real boolean; the string 'True' only worked by being
        # truthy
        'shuffle': True,
        'num_workers': NUM_WORKER if use_cuda else 0,
        'pin_memory': use_cuda,
    }

    # Drug response dataloaders for training/validation
    drug_resp_dataset_kwargs = {
        'data_root': DATA_ROOT,
        'rand_state': args.rand_state,
        'summary': False,

        'int_dtype': np.int8,
        'float_dtype': np.float16,
        'output_dtype': np.float32,

        'grth_scaling': args.grth_scaling,
        'dscptr_scaling': args.dscptr_scaling,
        'rnaseq_scaling': args.rnaseq_scaling,
        'dscptr_nan_threshold': args.dscptr_nan_threshold,

        'rnaseq_feature_usage': args.rnaseq_feature_usage,
        'drug_feature_usage': args.drug_feature_usage,
        'validation_ratio': args.validation_ratio,
        'disjoint_drugs': args.disjoint_drugs,
        'disjoint_cells': args.disjoint_cells,
    }

    drug_resp_trn_loader = torch.utils.data.DataLoader(
        DrugRespDataset(data_src=args.trn_src,
                        training=True,
                        **drug_resp_dataset_kwargs),
        batch_size=args.trn_batch_size,
        **dataloader_kwargs)

    # List of data loaders for different validation sets
    drug_resp_val_loaders = [
        torch.utils.data.DataLoader(
            DrugRespDataset(data_src=src,
                            training=False,
                            **drug_resp_dataset_kwargs),
            batch_size=args.val_batch_size,
            **dataloader_kwargs)
        for src in args.val_srcs
    ]

    # Cell line classification dataloaders for training/validation
    cl_clf_dataset_kwargs = {
        'data_root': DATA_ROOT,
        'rand_state': args.rand_state,
        'summary': False,

        'int_dtype': np.int8,
        'float_dtype': np.float16,
        'output_dtype': np.float32,

        'rnaseq_scaling': args.rnaseq_scaling,

        'rnaseq_feature_usage': args.rnaseq_feature_usage,
        'validation_ratio': args.validation_ratio,
    }

    cl_clf_trn_loader = torch.utils.data.DataLoader(
        CLClassDataset(training=True, **cl_clf_dataset_kwargs),
        batch_size=args.trn_batch_size,
        **dataloader_kwargs)

    cl_clf_val_loader = torch.utils.data.DataLoader(
        CLClassDataset(training=False, **cl_clf_dataset_kwargs),
        batch_size=args.val_batch_size,
        **dataloader_kwargs)

    # Drug target family classification dataloaders for training/validation
    drug_target_dataset_kwargs = {
        'data_root': DATA_ROOT,
        'rand_state': args.rand_state,
        'summary': False,

        'int_dtype': np.int8,
        'float_dtype': np.float16,
        'output_dtype': np.float32,

        'dscptr_scaling': args.dscptr_scaling,
        'dscptr_nan_threshold': args.dscptr_nan_threshold,

        'drug_feature_usage': args.drug_feature_usage,
        'validation_ratio': args.validation_ratio,
    }

    drug_target_trn_loader = torch.utils.data.DataLoader(
        DrugTargetDataset(training=True, **drug_target_dataset_kwargs),
        batch_size=args.trn_batch_size,
        **dataloader_kwargs)

    drug_target_val_loader = torch.utils.data.DataLoader(
        DrugTargetDataset(training=False, **drug_target_dataset_kwargs),
        batch_size=args.val_batch_size,
        **dataloader_kwargs)

    # Drug weighted QED regression dataloaders for training/validation
    drug_qed_dataset_kwargs = {
        'data_root': DATA_ROOT,
        'rand_state': args.rand_state,
        'summary': False,

        'int_dtype': np.int8,
        'float_dtype': np.float16,
        'output_dtype': np.float32,

        'qed_scaling': args.qed_scaling,
        'dscptr_scaling': args.dscptr_scaling,
        'dscptr_nan_threshold': args.dscptr_nan_threshold,

        'drug_feature_usage': args.drug_feature_usage,
        'validation_ratio': args.validation_ratio,
    }

    drug_qed_trn_loader = torch.utils.data.DataLoader(
        DrugQEDDataset(training=True, **drug_qed_dataset_kwargs),
        batch_size=args.trn_batch_size,
        **dataloader_kwargs)

    drug_qed_val_loader = torch.utils.data.DataLoader(
        DrugQEDDataset(training=False, **drug_qed_dataset_kwargs),
        batch_size=args.val_batch_size,
        **dataloader_kwargs)

    # Constructing and initializing neural networks ###########################
    # Autoencoder training hyper-parameters
    ae_training_kwarg = {
        'ae_loss_func': 'mse',
        'ae_opt': 'sgd',
        'ae_lr': 2e-1,
        'lr_decay_factor': 1.0,
        'max_num_epochs': 1000,
        'early_stop_patience': 50,
    }

    encoder_kwarg = {
        'model_folder': './models/',
        'data_root': DATA_ROOT,

        'autoencoder_init': args.autoencoder_init,
        'training_kwarg': ae_training_kwarg,

        'device': device,
        'verbose': True,
        'rand_state': args.rand_state,
    }

    # Get RNA sequence encoder
    gene_encoder = get_gene_encoder(
        rnaseq_feature_usage=args.rnaseq_feature_usage,
        rnaseq_scaling=args.rnaseq_scaling,

        layer_dim=args.gene_layer_dim,
        num_layers=args.gene_num_layers,
        latent_dim=args.gene_latent_dim,
        **encoder_kwarg)

    # Get drug feature encoder
    drug_encoder = get_drug_encoder(
        drug_feature_usage=args.drug_feature_usage,
        dscptr_scaling=args.dscptr_scaling,
        dscptr_nan_threshold=args.dscptr_nan_threshold,

        layer_dim=args.drug_layer_dim,
        num_layers=args.drug_num_layers,
        latent_dim=args.drug_latent_dim,
        **encoder_kwarg)

    # Regressor for drug response
    resp_net = RespNet(
        gene_latent_dim=args.gene_latent_dim,
        drug_latent_dim=args.drug_latent_dim,

        gene_encoder=gene_encoder,
        drug_encoder=drug_encoder,

        resp_layer_dim=args.resp_layer_dim,
        resp_num_layers_per_block=args.resp_num_layers_per_block,
        resp_num_blocks=args.resp_num_blocks,
        resp_num_layers=args.resp_num_layers,
        resp_dropout=args.resp_dropout,

        resp_activation=args.resp_activation).to(device)

    print(resp_net)

    # Sequence classifier for category, site, and type
    cl_clf_net_kwargs = {
        'encoder': gene_encoder,
        'input_dim': args.gene_latent_dim,
        'condition_dim': len(get_label_dict(DATA_ROOT, 'data_src_dict.txt')),
        'layer_dim': args.cl_clf_layer_dim,
        'num_layers': args.cl_clf_num_layers,
    }

    category_clf_net = ClfNet(
        num_classes=len(get_label_dict(DATA_ROOT, 'category_dict.txt')),
        **cl_clf_net_kwargs).to(device)
    site_clf_net = ClfNet(
        num_classes=len(get_label_dict(DATA_ROOT, 'site_dict.txt')),
        **cl_clf_net_kwargs).to(device)
    type_clf_net = ClfNet(
        num_classes=len(get_label_dict(DATA_ROOT, 'type_dict.txt')),
        **cl_clf_net_kwargs).to(device)

    # Classifier for drug target family prediction
    drug_target_net = ClfNet(
        encoder=drug_encoder,
        input_dim=args.drug_latent_dim,
        condition_dim=0,
        layer_dim=args.drug_target_layer_dim,
        num_layers=args.drug_target_num_layers,
        num_classes=len(get_label_dict(DATA_ROOT, 'drug_target_dict.txt'))
    ).to(device)

    # Regressor for drug weighted QED prediction
    drug_qed_net = RgsNet(
        encoder=drug_encoder,
        input_dim=args.drug_latent_dim,
        condition_dim=0,
        layer_dim=args.drug_qed_layer_dim,
        num_layers=args.drug_qed_num_layers,
        activation=args.drug_qed_activation).to(device)

    # Multi-GPU settings
    if args.multi_gpu:
        resp_net = nn.DataParallel(resp_net)
        category_clf_net = nn.DataParallel(category_clf_net)
        site_clf_net = nn.DataParallel(site_clf_net)
        type_clf_net = nn.DataParallel(type_clf_net)
        drug_target_net = nn.DataParallel(drug_target_net)
        drug_qed_net = nn.DataParallel(drug_qed_net)

    # Optimizers, learning rate decay, and miscellaneous ######################
    resp_opt = get_optimizer(
        opt_type=args.resp_opt,
        networks=resp_net,
        learning_rate=args.resp_lr,
        l2_regularization=args.l2_regularization)
    cl_clf_opt = get_optimizer(
        opt_type=args.cl_clf_opt,
        networks=[category_clf_net, site_clf_net, type_clf_net],
        learning_rate=args.cl_clf_lr,
        l2_regularization=args.l2_regularization)
    drug_target_opt = get_optimizer(
        opt_type=args.drug_target_opt,
        networks=drug_target_net,
        learning_rate=args.drug_target_lr,
        l2_regularization=args.l2_regularization)
    drug_qed_opt = get_optimizer(
        opt_type=args.drug_qed_opt,
        networks=drug_qed_net,
        learning_rate=args.drug_qed_lr,
        l2_regularization=args.l2_regularization)

    # All four schedulers share the same exponential decay rule
    def lr_decay(epoch_index):
        return args.lr_decay_factor ** epoch_index

    resp_lr_decay = LambdaLR(optimizer=resp_opt, lr_lambda=lr_decay)
    cl_clf_lr_decay = LambdaLR(optimizer=cl_clf_opt, lr_lambda=lr_decay)
    drug_target_lr_decay = LambdaLR(optimizer=drug_target_opt,
                                    lr_lambda=lr_decay)
    drug_qed_lr_decay = LambdaLR(optimizer=drug_qed_opt, lr_lambda=lr_decay)

    resp_loss_func = F.l1_loss if args.resp_loss_func == 'l1' \
        else F.mse_loss
    drug_qed_loss_func = F.l1_loss if args.drug_qed_loss_func == 'l1' \
        else F.mse_loss

    # Training/validation loops ###############################################
    val_cl_clf_acc = []
    val_drug_target_acc = []
    val_drug_qed_mse, val_drug_qed_mae, val_drug_qed_r2 = [], [], []
    val_resp_mse, val_resp_mae, val_resp_r2 = [], [], []
    best_r2 = -np.inf
    patience = 0
    start_time = time.time()

    # Create folder for validation results if not exist.
    # The argument defaults to None, so it must be checked before calling
    # .lower(); the literal string 'none' is also treated as "skip saving".
    if args.val_results_dir is not None \
            and args.val_results_dir.lower() != 'none':
        try:
            os.makedirs(args.val_results_dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
    else:
        args.val_results_dir = None

    # Early stopping is decided on the validation set with the same
    # data source as the training dataloader
    val_index = 0
    for idx, loader in enumerate(drug_resp_val_loaders):
        if loader.dataset.data_source == args.trn_src:
            val_index = idx

    for epoch in range(args.max_num_epochs):

        print('=' * 80 + '\nTraining Epoch %3i:' % (epoch + 1))
        epoch_start_time = time.time()

        resp_lr_decay.step(epoch)
        cl_clf_lr_decay.step(epoch)
        drug_target_lr_decay.step(epoch)
        drug_qed_lr_decay.step(epoch)

        # Training cell line classifier
        train_cl_clf(device=device,
                     category_clf_net=category_clf_net,
                     site_clf_net=site_clf_net,
                     type_clf_net=type_clf_net,
                     data_loader=cl_clf_trn_loader,
                     max_num_batches=args.max_num_batches,
                     optimizer=cl_clf_opt)

        # Training drug target classifier
        train_drug_target(device=device,
                          drug_target_net=drug_target_net,
                          data_loader=drug_target_trn_loader,
                          max_num_batches=args.max_num_batches,
                          optimizer=drug_target_opt)

        # Training drug weighted QED regressor
        train_drug_qed(device=device,
                       drug_qed_net=drug_qed_net,
                       data_loader=drug_qed_trn_loader,
                       max_num_batches=args.max_num_batches,
                       loss_func=drug_qed_loss_func,
                       optimizer=drug_qed_opt)

        # Training drug response regressor
        train_resp(device=device,
                   resp_net=resp_net,
                   data_loader=drug_resp_trn_loader,
                   max_num_batches=args.max_num_batches,
                   loss_func=resp_loss_func,
                   optimizer=resp_opt)

        print('\nValidation Results:')

        if epoch >= args.resp_val_start_epoch:

            # Validating cell line classifier
            cl_category_acc, cl_site_acc, cl_type_acc = \
                valid_cl_clf(device=device,
                             category_clf_net=category_clf_net,
                             site_clf_net=site_clf_net,
                             type_clf_net=type_clf_net,
                             data_loader=cl_clf_val_loader)
            val_cl_clf_acc.append([cl_category_acc, cl_site_acc, cl_type_acc])

            # Validating drug target classifier
            drug_target_acc = \
                valid_drug_target(device=device,
                                  drug_target_net=drug_target_net,
                                  data_loader=drug_target_val_loader)
            val_drug_target_acc.append(drug_target_acc)

            # Validating drug weighted QED regressor
            drug_qed_mse, drug_qed_mae, drug_qed_r2 = \
                valid_drug_qed(device=device,
                               drug_qed_net=drug_qed_net,
                               data_loader=drug_qed_val_loader)
            val_drug_qed_mse.append(drug_qed_mse)
            val_drug_qed_mae.append(drug_qed_mae)
            val_drug_qed_r2.append(drug_qed_r2)

            # Validating drug response regressor
            resp_mse, resp_mae, resp_r2 = \
                valid_resp(epoch=epoch,
                           trn_src=args.trn_src,
                           device=device,
                           resp_net=resp_net,
                           data_loaders=drug_resp_val_loaders,
                           resp_uq=args.resp_uq,
                           resp_uq_dropout=args.resp_uq_dropout,
                           resp_uq_num_runs=args.resp_uq_num_runs,
                           val_results_dir=args.val_results_dir)

            # Save the validation results in nested list
            val_resp_mse.append(resp_mse)
            val_resp_mae.append(resp_mae)
            val_resp_r2.append(resp_r2)

            # Record the best R2 score (same data source)
            # and check for early stopping if no improvement for epochs
            if resp_r2[val_index] > best_r2:
                patience = 0
                best_r2 = resp_r2[val_index]
            else:
                patience += 1
            if patience >= args.early_stop_patience:
                print('Validation results does not improve for %d epochs ... '
                      'invoking early stopping.' % patience)
                break

        print('Epoch Running Time: %.1f Seconds.'
              % (time.time() - epoch_start_time))

    # One row per validated epoch; one column per classifier / data source
    num_val_srcs = len(args.val_srcs)
    val_cl_clf_acc = np.array(val_cl_clf_acc).reshape(-1, 3)
    val_resp_mse = np.array(val_resp_mse).reshape(-1, num_val_srcs)
    val_resp_mae = np.array(val_resp_mae).reshape(-1, num_val_srcs)
    val_resp_r2 = np.array(val_resp_r2).reshape(-1, num_val_srcs)

    print('Program Running Time: %.1f Seconds.' % (time.time() - start_time))

    # Without any validated epoch the max/argmax reductions below would
    # raise on empty arrays, so there is nothing to summarize
    if val_resp_r2.shape[0] == 0:
        print('No validation results were recorded; '
              'skipping the overall validation summary.')
        return

    # Print overall validation results
    print('=' * 80)
    print('Overall Validation Results:\n')

    print('\tBest Results from Different Models (Epochs):')
    # Print best accuracy for cell line classifiers
    clf_targets = [
        'Cell Line Categories',
        'Cell Line Sites',
        'Cell Line Types',
    ]
    best_acc = np.amax(val_cl_clf_acc, axis=0)
    best_acc_epochs = np.argmax(val_cl_clf_acc, axis=0)

    for index, clf_target in enumerate(clf_targets):
        print('\t\t%-24s Best Accuracy: %.3f%% (Epoch = %3d)'
              % (clf_target, best_acc[index],
                 best_acc_epochs[index] + 1 + args.resp_val_start_epoch))

    # Print best predictions for drug classifiers and regressor
    print('\t\tDrug Target Family \t Best Accuracy: %.3f%% (Epoch = %3d)'
          % (np.max(val_drug_target_acc),
             (np.argmax(val_drug_target_acc) + 1
              + args.resp_val_start_epoch)))

    best_qed_index = np.argmax(val_drug_qed_r2)
    print('\t\tDrug Weighted QED \t Best R2 Score: %+6.4f '
          '(Epoch = %3d, MSE = %8.6f, MAE = %8.6f)'
          % (np.max(val_drug_qed_r2),
             (best_qed_index + 1 + args.resp_val_start_epoch),
             val_drug_qed_mse[best_qed_index],
             val_drug_qed_mae[best_qed_index]))

    # Print best R2 scores for drug response regressor
    val_data_sources = \
        [loader.dataset.data_source for loader in drug_resp_val_loaders]
    best_r2 = np.amax(val_resp_r2, axis=0)
    best_r2_epochs = np.argmax(val_resp_r2, axis=0)

    for index, data_source in enumerate(val_data_sources):
        print('\t\t%-6s \t Best R2 Score: %+6.4f '
              '(Epoch = %3d, MSE = %8.2f, MAE = %6.2f)'
              % (data_source, best_r2[index],
                 best_r2_epochs[index] + args.resp_val_start_epoch + 1,
                 val_resp_mse[best_r2_epochs[index], index],
                 val_resp_mae[best_r2_epochs[index], index]))

    # Print best epoch and all the corresponding validation results
    # Picking the best epoch using R2 score from same data source
    best_epoch = val_resp_r2[:, val_index].argmax()
    print('\n\tBest Results from the Same Model (Epoch = %3d):'
          % (best_epoch + 1 + args.resp_val_start_epoch))
    for index, clf_target in enumerate(clf_targets):
        print('\t\t%-24s Accuracy: %.3f%%'
              % (clf_target, val_cl_clf_acc[best_epoch, index]))

    # Print best predictions for drug classifiers and regressor
    print('\t\tDrug Target Family \t Accuracy: %.3f%% '
          % (val_drug_target_acc[best_epoch]))

    print('\t\tDrug Weighted QED \t R2 Score: %+6.4f '
          '(MSE = %8.6f, MAE = %6.6f)'
          % (val_drug_qed_r2[best_epoch],
             val_drug_qed_mse[best_epoch],
             val_drug_qed_mae[best_epoch]))

    for index, data_source in enumerate(val_data_sources):
        print('\t\t%-6s \t R2 Score: %+6.4f '
              '(MSE = %8.2f, MAE = %6.2f)'
              % (data_source,
                 val_resp_r2[best_epoch, index],
                 val_resp_mse[best_epoch, index],
                 val_resp_mae[best_epoch, index]))
def train(self):
    """Run the multitask training loop with early stopping.

    Resets the validation-metric lists, the best R2 score, the patience
    counter and the start time on ``self``, then for each of
    ``self.args.epochs`` epochs steps the four learning rate schedulers and
    trains the cell line classifiers, the drug target classifier, the drug
    weighted QED regressor and the drug response regressor in that order.

    From epoch ``args.resp_val_start_epoch`` onwards, ``self.validation`` is
    called and the entry at ``self.val_index`` of its result (presumably the
    drug response R2 score of the matching validation set) drives early
    stopping: training ends once it has failed to improve for
    ``args.early_stop_patience`` validated epochs.
    """
    args, device = self.args, self.device

    # Training/validation loops ###############################################
    self.val_cl_clf_acc, self.val_drug_target_acc = [], []
    self.val_drug_qed_mse, self.val_drug_qed_mae, self.val_drug_qed_r2 = \
        [], [], []
    self.val_resp_mse, self.val_resp_mae, self.val_resp_r2 = [], [], []
    self.best_r2 = -np.inf
    self.patience = 0
    self.start_time = time.time()

    for epoch in range(args.epochs):

        header = '=' * 80 + '\nTraining Epoch %3i/%3i:' % (epoch + 1,
                                                           args.epochs)
        print(header)
        epoch_start_time = time.time()

        # Advance every learning rate schedule to the current epoch
        for lr_schedule in (self.resp_lr_decay,
                            self.cl_clf_lr_decay,
                            self.drug_target_lr_decay,
                            self.drug_qed_lr_decay):
            lr_schedule.step(epoch)

        # Keyword arguments common to all four training routines
        shared_kwargs = {
            'device': device,
            'max_num_batches': args.max_num_batches,
        }

        # Training cell line classifier
        train_cl_clf(category_clf_net=self.category_clf_net,
                     site_clf_net=self.site_clf_net,
                     type_clf_net=self.type_clf_net,
                     data_loader=self.cl_clf_trn_loader,
                     optimizer=self.cl_clf_opt,
                     **shared_kwargs)

        # Training drug target classifier
        train_drug_target(drug_target_net=self.drug_target_net,
                          data_loader=self.drug_target_trn_loader,
                          optimizer=self.drug_target_opt,
                          **shared_kwargs)

        # Training drug weighted QED regressor
        train_drug_qed(drug_qed_net=self.drug_qed_net,
                       data_loader=self.drug_qed_trn_loader,
                       loss_func=self.drug_qed_loss_func,
                       optimizer=self.drug_qed_opt,
                       **shared_kwargs)

        # Training drug response regressor
        train_resp(resp_net=self.resp_net,
                   data_loader=self.drug_resp_trn_loader,
                   loss_func=self.resp_loss_func,
                   optimizer=self.resp_opt,
                   **shared_kwargs)

        if epoch >= args.resp_val_start_epoch:

            # Score on the validation set used for early stopping
            current_r2 = self.validation(epoch)[self.val_index]

            # Track the best score; count epochs without improvement
            if current_r2 > self.best_r2:
                self.best_r2 = current_r2
                self.patience = 0
            else:
                self.patience += 1

            if self.patience >= args.early_stop_patience:
                print('Validation results does not improve for %d epochs ... '
                      'invoking early stopping.' % self.patience)
                break

        elapsed = time.time() - epoch_start_time
        print('Epoch Running Time: %.1f Seconds.' % elapsed)