def train(model: torch.nn.Module, task_definition: TaskDefinition, early_stopping_target_id: str,
          trainingset_dataloader: torch.utils.data.DataLoader,
          trainingset_eval_dataloader: torch.utils.data.DataLoader,
          validationset_eval_dataloader: torch.utils.data.DataLoader,
          results_directory: str = "results", n_updates: int = int(1e5), show_progress: bool = True,
          load_file: str = None, device: torch.device = torch.device('cuda:0'), num_torch_threads: int = 3,
          learning_rate: float = 1e-4, l1_weight_decay: float = 0, l2_weight_decay: float = 0,
          log_training_stats_at: int = int(1e2), evaluate_at: int = int(5e3),
          ignore_missing_target_values: bool = True):
    """Train a DeepRC model on a given dataset on tasks specified in `task_definition`
    
    Model with lowest validation set loss on target `early_stopping_target_id` will be taken as final model
    (=early stopping). Model performance on validation set will be evaluated every `evaluate_at` updates.
    Trained model, logfile, and tensorboard files will be stored in `results_directory`.
    
    See `deeprc/examples/` for examples.
    
    Parameters
    ----------
    model: torch.nn.Module
        deeprc.architectures.DeepRC or similar model as PyTorch module
    task_definition: TaskDefinition
        TaskDefinition object containing the tasks to train the DeepRC model on. See `deeprc/examples/` for examples.
    early_stopping_target_id: str
        ID of task in TaskDefinition object to use for early stopping.
    trainingset_dataloader: torch.utils.data.DataLoader
        Data loader for training
    trainingset_eval_dataloader: torch.utils.data.DataLoader
        Data loader for evaluation on training set (=no random subsampling)
    validationset_eval_dataloader: torch.utils.data.DataLoader
        Data loader for evaluation on validation set (=no random subsampling). Will be used for early-stopping.
    results_directory: str
        Directory to save checkpoint of best trained model, logfile, and tensorboard files in
    n_updates: int
        Number of updates to train for
    show_progress: bool
        Show progressbar?
    load_file: str
        Path to load checkpoint of previously saved model from
    device: torch.device
        Device to use for computations. E.g. `torch.device('cuda:0')` or `torch.device('cpu')`.
        Currently, only devices which support 16 bit float are supported.
    num_torch_threads: int
        Number of parallel threads to allow PyTorch
    learning_rate: float
        Learning rate for adam optimizer
    l1_weight_decay: float
        l1 weight decay factor. l1 weight penalty will be added to loss, scaled by `l1_weight_decay`
    l2_weight_decay: float
        l2 weight decay factor. l2 weight penalty will be added to loss, scaled by `l2_weight_decay`
    log_training_stats_at: int
        Write current training statistics to tensorboard every `log_training_stats_at` updates
    evaluate_at: int
        Evaluate model on training and validation set every `evaluate_at` updates. This will also check for a new
        best model for early stopping.
    ignore_missing_target_values: bool
        If True, missing target values will be ignored for training. This can be useful if auxiliary tasks are not
        available for all samples but might increase the computation time per update.
""" # Append current timestamp to results directory results_directory = os.path.join( results_directory, datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')) os.makedirs(results_directory, exist_ok=True) # Read config file and set up results folder logfile = os.path.join(results_directory, 'log.txt') checkpointdir = os.path.join(results_directory, 'checkpoint') os.makedirs(checkpointdir, exist_ok=True) tensorboarddir = os.path.join(results_directory, 'tensorboard') os.makedirs(tensorboarddir, exist_ok=True) # Prepare tensorboard writer writer = SummaryWriter(log_dir=tensorboarddir) # Print all outputs to logfile and terminal tee_print = TeePrint(logfile) tprint = tee_print.tee_print try: # Set up PyTorch and numpy random seeds torch.set_num_threads(num_torch_threads) # Send model to device model.to(device) # Get optimizer (eps needs to be at about 1e-4 to be numerically stable with 16 bit float) optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=l2_weight_decay, eps=1e-4) # Create a checkpoint dictionary with objects we want to have saved and loaded if needed state = dict(model=model, optimizer=optimizer, update=0, best_validation_loss=np.inf) # Setup the SaverLoader class to save/load our checkpoint dictionary to files or to RAM objects saver_loader = SaverLoader( save_dict=state, device=device, save_dir=checkpointdir, n_savefiles=1, # keep only the latest checkpoint n_inmem=1 # save checkpoint only in RAM ) # Load previous checkpoint dictionary, if load_file is specified if load_file is not None: state.update( saver_loader.load_from_file(loadname=load_file, verbose=True)) tprint(f"Loaded checkpoint from file {load_file}") update, best_validation_loss = state['update'], state[ 'best_validation_loss'] # Save checkpoint dictionary to RAM object saver_loader.save_to_ram(savename=str(update)) # # Start training # try: tprint("Training model...") update_progess_bar = tqdm(total=n_updates, disable=not show_progress, position=0, desc=f"loss={np.nan:6.4f}") while update < n_updates: for data in trainingset_dataloader: # Get samples as lists labels, inputs, sequence_lengths, counts_per_sequence, sample_ids = data # Apply attention-based sequence reduction and create minibatch with torch.no_grad(): labels, inputs, sequence_lengths, n_sequences = model.reduce_and_stack_minibatch( labels, inputs, sequence_lengths, counts_per_sequence) # Reset gradients optimizer.zero_grad() # Calculate predictions from reduced sequences, logit_outputs = model( inputs_flat=inputs, sequence_lengths_flat=sequence_lengths, n_sequences_per_bag=n_sequences) # Calculate losses pred_loss = task_definition.get_loss( raw_outputs=logit_outputs, targets=labels, ignore_missing_target_values= ignore_missing_target_values) l1reg_loss = (torch.mean( torch.stack([ p.abs().float().mean() for p in model.parameters() ]))) loss = pred_loss + l1reg_loss * l1_weight_decay # Perform update loss.backward() optimizer.step() update += 1 update_progess_bar.update() update_progess_bar.set_description( desc=f"loss={loss.item():6.4f}", refresh=True) # Add to tensorboard if update % log_training_stats_at == 0: tb_group = 'training/' # Loop through tasks and add losses to tensorboard pred_losses = task_definition.get_losses( raw_outputs=logit_outputs, targets=labels) pred_losses = pred_losses.mean( dim=1 ) # shape: (n_tasks, n_samples, 1) -> (n_tasks, 1) for task_id, task_loss in zip( task_definition.get_task_ids(), pred_losses): writer.add_scalar(tag=tb_group + f'{task_id}_loss', scalar_value=task_loss, 
global_step=update) writer.add_scalar(tag=tb_group + 'total_task_loss', scalar_value=pred_loss, global_step=update) writer.add_scalar(tag=tb_group + 'l1reg_loss', scalar_value=l1reg_loss, global_step=update) writer.add_scalar(tag=tb_group + 'total_loss', scalar_value=loss, global_step=update) writer.add_histogram(tag=tb_group + 'logit_outputs', values=logit_outputs, global_step=update) # Calculate scores and loss on training set and validation set if update % evaluate_at == 0 or update == n_updates or update == 1: print(" Calculating training score...") scores = evaluate( model=model, dataloader=trainingset_eval_dataloader, task_definition=task_definition, device=device) print(f" ...done!") tprint( f"[training_inference] u: {update:07d}; scores: {scores};" ) tb_group = 'training_inference/' for task_id, task_scores in scores.items(): [ writer.add_scalar(tag=tb_group + f'{task_id}/{score_name}', scalar_value=score, global_step=update) for score_name, score in task_scores.items() ] print(" Calculating validation score...") scores = evaluate( model=model, dataloader=validationset_eval_dataloader, task_definition=task_definition, device=device) scoring_loss = scores[early_stopping_target_id]['loss'] print(f" ...done!") tprint( f"[validation] u: {update:07d}; scores: {scores};") tb_group = 'validation/' for task_id, task_scores in scores.items(): [ writer.add_scalar(tag=tb_group + f'{task_id}/{score_name}', scalar_value=score, global_step=update) for score_name, score in task_scores.items() ] # If we have a new best loss on the validation set, we save the model as new best model if best_validation_loss > scoring_loss: best_validation_loss = scoring_loss tprint( f" New best validation loss for {early_stopping_target_id}: {scoring_loss}" ) # Save current state as RAM object state['update'] = update state['best_validation_loss'] = scoring_loss # Save checkpoint dictionary with currently best model to RAM saver_loader.save_to_ram(savename=str(update)) # This would save to disk every time a new best model is found, which can be slow # saver_loader.save_to_file(filename=f'best_so_far_u{update}.tar.gzip') if update >= n_updates: break update_progess_bar.close() finally: # In any case, save the current model and best model to a file saver_loader.save_to_file(filename=f'lastsave_u{update}.tar.gzip') state.update( saver_loader.load_from_ram()) # load best model so far saver_loader.save_to_file(filename=f'best_u{update}.tar.gzip') print('Finished Training!') except Exception as e: with open(logfile, 'a') as lf: print(f"Exception: {e}", file=lf) raise e finally: close_all() # Clean up
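
# ----------------------------------------------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the library API): how `train` is typically wired together.
# It assumes `model` is an already constructed `deeprc.architectures.DeepRC` instance and that `task_definition`
# and the three dataloaders were obtained via `make_dataloaders`, as in `deeprc/examples/`; the early-stopping
# target id 'binary_target_1' is a placeholder taken from those examples.
#
# train(model, task_definition=task_definition, early_stopping_target_id='binary_target_1',
#       trainingset_dataloader=trainingset, trainingset_eval_dataloader=trainingset_eval,
#       validationset_eval_dataloader=validationset_eval,
#       results_directory='results/example_run', n_updates=int(1e5), evaluate_at=int(5e3),
#       device=torch.device('cuda:0'))
# ----------------------------------------------------------------------------------------------------------------
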
def evaluate(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, task_definition: TaskDefinition,
             show_progress: bool = True, device: torch.device = torch.device('cuda:0')) -> dict:
    """Compute DeepRC model scores on given dataset for tasks specified in `task_definition`
    
    Parameters
    ----------
    model: torch.nn.Module
        deeprc.architectures.DeepRC or similar model as PyTorch module
    dataloader: torch.utils.data.DataLoader
        Data loader for dataset to calculate scores on
    task_definition: TaskDefinition
        TaskDefinition object containing the tasks to train the DeepRC model on. See `deeprc/examples/` for examples.
    show_progress: bool
        Show progressbar?
    device: torch.device
        Device to use for computations. E.g. `torch.device('cuda:0')` or `torch.device('cpu')`.
    
    Returns
    -------
    scores: dict
        Nested dictionary of format `{task_id: {score_id: score_value}}`, e.g.
        `{"binary_task_1": {"auc": 0.6, "bacc": 0.5, "f1": 0.2, "loss": 0.01}}`. The scores returned are computed
        using the .get_scores() methods of the individual target instances (e.g.
        `deeprc.task_definitions.BinaryTarget()`). See `deeprc/examples/` for examples.
    """
    with torch.no_grad():
        model.to(device=device)
        all_raw_outputs = []
        all_targets = []
        
        for scoring_data in tqdm(dataloader, total=len(dataloader), desc="Evaluating model",
                                 disable=not show_progress):
            # Get samples as lists
            targets, inputs, sequence_lengths, counts_per_sequence, sample_ids = scoring_data
            
            # Apply attention-based sequence reduction and create minibatch
            targets, inputs, sequence_lengths, n_sequences = model.reduce_and_stack_minibatch(
                targets, inputs, sequence_lengths, counts_per_sequence)
            
            # Compute predictions from reduced sequences
            raw_outputs = model(inputs_flat=inputs, sequence_lengths_flat=sequence_lengths,
                                n_sequences_per_bag=n_sequences)
            
            # Store predictions and labels
            all_raw_outputs.append(raw_outputs.detach())
            all_targets.append(targets.detach())
        
        # Compute scores
        all_raw_outputs = torch.cat(all_raw_outputs, dim=0)
        all_targets = torch.cat(all_targets, dim=0)
        scores = task_definition.get_scores(raw_outputs=all_raw_outputs, targets=all_targets)
    return scores
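
# Usage sketch (illustrative addition): unpacking the nested score dictionary returned by `evaluate`.
# Assumes `model`, `task_definition`, and `testset_eval_dataloader` already exist, as in the training sketch above.
#
# scores = evaluate(model=model, dataloader=testset_eval_dataloader, task_definition=task_definition,
#                   device=torch.device('cuda:0'))
# for task_id, task_scores in scores.items():
#     for score_name, score_value in task_scores.items():
#         print(f"{task_id}/{score_name}: {score_value}")
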
task_definition = TaskDefinition(targets=[  # Combines our sub-tasks
    BinaryTarget(  # Add binary classification task with sigmoid output function
        column_name='binary_target_1',  # Column name of task in metadata file
        true_class_value='+',  # Entries with value '+' will be positive class, others will be negative class
        pos_weight=1.,  # We can up- or down-weight the positive class if the classes are imbalanced
        task_weight=1.  # Weight of this task for the total training loss
    ),
    BinaryTarget(  # Add another binary classification task
        column_name='binary_target_2',
        true_class_value='True',  # Entries with value 'True' will be positive class, others will be negative class
        pos_weight=1./3,  # Down-weights the positive class samples (e.g. if the positive class is underrepresented)
        task_weight=aux_task_weight
    ),
    RegressionTarget(  # Add regression task with linear output function
        column_name='regression_target_1',  # Column name of task in metadata file
        normalization_mean=0., normalization_std=275.,  # Normalize targets by ((target_value - mean) / std)
        task_weight=aux_task_weight  # Weight of this task for the total training loss
    ),
    RegressionTarget(  # Add another regression task
        column_name='regression_target_2',
        normalization_mean=0.5, normalization_std=1.,
        task_weight=aux_task_weight
    ),
    MulticlassTarget(  # Add multiclass classification task with softmax output function (=classes mutually exclusive)
        column_name='multiclass_target_1',  # Column name of task in metadata file
        possible_target_values=['class_a', 'class_b', 'class_c'],  # Values in task column to expect
        class_weights=[1., 1., 0.5],  # Weight individual classes (e.g. if class 'class_c' is overrepresented)
        task_weight=aux_task_weight  # Weight of this task for the total training loss
    ),
    MulticlassTarget(  # Add another multiclass classification task
        column_name='multiclass_target_2',
        possible_target_values=['type_1', 'type_2', 'type_3', 'type_4', 'type_5'],
        class_weights=[1., 1., 1., 1., 1.],
        task_weight=aux_task_weight
    )
]).to(device=device)
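
# Quick check of the combined task definition above (illustrative addition). `get_task_ids()` is the same method
# `train()` uses for per-task tensorboard logging; `get_n_output_features()` is assumed here to report the total
# number of model output features over all sub-tasks, as used in the DeepRC examples to size the output network.
# Note that `aux_task_weight` and `device` must be defined earlier in the script.
print(f"Task IDs: {task_definition.get_task_ids()}")
print(f"Required model output features: {task_definition.get_n_output_features()}")
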
def cmv_dataset(dataset_path: str = None, task_definition: TaskDefinition = None, cross_validation_fold: int = 0,
                n_worker_processes: int = 4, batch_size: int = 4, inputformat: str = 'NCL',
                keep_dataset_in_ram: bool = True, sample_n_sequences: int = int(1e4), verbose: bool = True) \
        -> Tuple[TaskDefinition, DataLoader, DataLoader, DataLoader, DataLoader]:
    """Get data loaders for category "real-world immunosequencing data"
    
    Get data loaders for training set and training-, validation-, and test-set in evaluation mode
    (=no random subsampling) for datasets of category "real-world immunosequencing data".
    This is a pre-processed version of the CMV dataset [1]_.
    
    Parameters
    ----------
    dataset_path: str
        File path of dataset. If the dataset does not exist, the corresponding hdf5 container will be downloaded.
        Defaults to "deeprc_datasets/CMV.hdf5"
    task_definition: TaskDefinition
        TaskDefinition object containing the tasks to train the DeepRC model on. See `deeprc/examples/` for examples.
    cross_validation_fold : int
        Specify the fold of the cross-validation the dataloaders should be computed for.
    n_worker_processes : int
        Number of background processes to use for converting dataset to hdf5 container and trainingset dataloader.
    batch_size : int
        Number of repertoires per minibatch during training.
    inputformat : 'NCL' or 'NLC'
        Format of input feature array;
        'NCL' -> (batchsize, channels, seq.length);
        'LNC' -> (seq.length, batchsize, channels);
    keep_dataset_in_ram : bool
        It is faster to load the full hdf5 file into the RAM instead of keeping it on the disk.
        If False, the hdf5 file will be read from the disk and consume less RAM.
    sample_n_sequences : int
        Optional: Random sub-sampling of `sample_n_sequences` sequences per repertoire.
        Number of sequences per repertoire might be smaller than `sample_n_sequences` if repertoire is smaller or
        random indices have been drawn multiple times.
        If None, all sequences will be loaded for each repertoire.
    verbose : bool
        Activate verbose mode
    
    Returns
    -------
    task_definition: TaskDefinition
        TaskDefinition object containing the tasks to train the DeepRC model on. See `deeprc/examples/` for examples.
    trainingset_dataloader: torch.utils.data.DataLoader
        Dataloader for trainingset with active `sample_n_sequences` (=random subsampling/dropout of repertoire
        sequences)
    trainingset_eval_dataloader: torch.utils.data.DataLoader
        Dataloader for trainingset with deactivated `sample_n_sequences`
    validationset_eval_dataloader: torch.utils.data.DataLoader
        Dataloader for validationset with deactivated `sample_n_sequences`
    testset_eval_dataloader: torch.utils.data.DataLoader
        Dataloader for testset with deactivated `sample_n_sequences`
    
    References
    ----------
    .. [1] Emerson, R. O., DeWitt, W. S., Vignali, M., Gravley, J., Hu, J. K., Osborne, E. J., Desmarais, C.,
       Klinger, M., Carlson, C. S., Hansen, J. A., et al. Immunosequencing identifies signatures of cytomegalovirus
       exposure history and HLA-mediated effects on the T cell repertoire. Nature Genetics, 49(5):659, 2017
    """
    if dataset_path is None:
        dataset_path = os.path.join(os.path.dirname(deeprc.__file__), 'datasets', 'CMV')
    os.makedirs(dataset_path, exist_ok=True)
    metadata_file = os.path.join(dataset_path, 'CMV_metadata.tsv')
    repertoiresdata_file = os.path.join(dataset_path, 'CMV_repertoiresdata.hdf5')
    
    # Download metadata file
    if not os.path.exists(metadata_file):
        user_confirmation(f"File {metadata_file} not found. It will be downloaded now. Continue?", 'y', 'n')
        url_get("https://ml.jku.at/research/DeepRC/datasets/CMV_data/metadata/cmv_emerson_2017.tsv", metadata_file)
    
    # Download repertoire file
    if not os.path.exists(repertoiresdata_file):
        user_confirmation(f"File {repertoiresdata_file} not found. It will be downloaded now. Continue?", 'y', 'n')
        url_get("https://ml.jku.at/research/DeepRC/datasets/CMV_data/hdf5/cmv_emerson_2017.hdf5",
                repertoiresdata_file)
    
    # Get file for dataset splits
    split_file = os.path.join(os.path.dirname(deeprc.__file__), 'datasets', 'splits_used_in_paper', 'CMV_splits.pkl')
    with open(split_file, 'rb') as sfh:
        split_inds = pkl.load(sfh)
    
    # Get task_definition
    if task_definition is None:
        task_definition = TaskDefinition(targets=[BinaryTarget(column_name='Known CMV status', true_class_value='+')])
    
    # Create data loaders
    trainingset_dataloader, trainingset_eval_dataloader, validationset_eval_dataloader, testset_eval_dataloader = \
        make_dataloaders(task_definition=task_definition, metadata_file=metadata_file,
                         repertoiresdata_path=repertoiresdata_file, split_inds=split_inds,
                         cross_validation_fold=cross_validation_fold, n_worker_processes=n_worker_processes,
                         batch_size=batch_size, inputformat=inputformat, keep_dataset_in_ram=keep_dataset_in_ram,
                         sample_n_sequences=sample_n_sequences,
                         sequence_counts_scaling_fn=log_sequence_count_scaling,
                         metadata_file_id_column='Subject ID',
                         verbose=verbose)
    return (task_definition, trainingset_dataloader, trainingset_eval_dataloader, validationset_eval_dataloader,
            testset_eval_dataloader)
def cmv_implanted_dataset(dataset_path: str = None, dataset_id: int = 0, task_definition: TaskDefinition = None,
                          cross_validation_fold: int = 0, n_worker_processes: int = 4, batch_size: int = 4,
                          inputformat: str = 'NCL', keep_dataset_in_ram: bool = True,
                          sample_n_sequences: int = int(1e4), verbose: bool = True) \
        -> Tuple[TaskDefinition, DataLoader, DataLoader, DataLoader, DataLoader]:
    """Get data loaders for category "real-world immunosequencing data with implanted signals".
    
    Get data loaders for training set and training-, validation-, and test-set in evaluation mode
    (=no random subsampling) for datasets of category "real-world immunosequencing data with implanted signals".
    
    Parameters
    ----------
    dataset_path: str
        File path of dataset. If the dataset does not exist, the corresponding hdf5 container will be downloaded.
        Defaults to "deeprc_datasets/CMV_with_implanted_signals_{dataset_id}.hdf5"
    dataset_id: int
        ID of dataset.
        0 = "One Motif 1%", 1 = "One 0.1%", 2 = "Multi 1%", 3 = "Multi 0.1%"
    task_definition: TaskDefinition
        TaskDefinition object containing the tasks to train the DeepRC model on. See `deeprc/examples/` for examples.
    cross_validation_fold : int
        Specify the fold of the cross-validation the dataloaders should be computed for.
    n_worker_processes : int
        Number of background processes to use for converting dataset to hdf5 container and trainingset dataloader.
    batch_size : int
        Number of repertoires per minibatch during training.
    inputformat : 'NCL' or 'NLC'
        Format of input feature array;
        'NCL' -> (batchsize, channels, seq.length);
        'LNC' -> (seq.length, batchsize, channels);
    keep_dataset_in_ram : bool
        It is faster to load the full hdf5 file into the RAM instead of keeping it on the disk.
        If False, the hdf5 file will be read from the disk and consume less RAM.
    sample_n_sequences : int
        Optional: Random sub-sampling of `sample_n_sequences` sequences per repertoire.
        Number of sequences per repertoire might be smaller than `sample_n_sequences` if repertoire is smaller or
        random indices have been drawn multiple times.
        If None, all sequences will be loaded for each repertoire.
    verbose : bool
        Activate verbose mode
    
    Returns
    -------
    task_definition: TaskDefinition
        TaskDefinition object containing the tasks to train the DeepRC model on. See `deeprc/examples/` for examples.
    trainingset_dataloader: torch.utils.data.DataLoader
        Dataloader for trainingset with active `sample_n_sequences` (=random subsampling/dropout of repertoire
        sequences)
    trainingset_eval_dataloader: torch.utils.data.DataLoader
        Dataloader for trainingset with deactivated `sample_n_sequences`
    validationset_eval_dataloader: torch.utils.data.DataLoader
        Dataloader for validationset with deactivated `sample_n_sequences`
    testset_eval_dataloader: torch.utils.data.DataLoader
        Dataloader for testset with deactivated `sample_n_sequences`
    """
    if dataset_path is None:
        dataset_path = os.path.join(os.path.dirname(deeprc.__file__), 'datasets', 'CMV_with_implanted_signals')
    os.makedirs(dataset_path, exist_ok=True)
    metadata_file = os.path.join(dataset_path, f'CMV_with_implanted_signals_{dataset_id}_metadata.tsv')
    repertoiresdata_file = os.path.join(dataset_path, f'CMV_with_implanted_signals_{dataset_id}_repertoiresdata.hdf5')
    
    # Download metadata file
    if not os.path.exists(metadata_file):
        user_confirmation(f"File {metadata_file} not found. It will be downloaded now. Continue?", 'y', 'n')
        # url_get(f"https://ml.jku.at/research/DeepRC/datasets/CMV_data_with_implanted_signals/metadata/implanted_signals_{dataset_id}.csv",
        #         metadata_file)
        url_get(f"https://cloud.ml.jku.at/s/KQDAdHjHpdn3pzg/download?path=/datasets/CMV_data_with_implanted_signals/metadata&files=implanted_signals_{dataset_id}.tsv",
                metadata_file)
    
    # Download repertoire file
    if not os.path.exists(repertoiresdata_file):
        user_confirmation(f"File {repertoiresdata_file} not found. It will be downloaded now. Continue?", 'y', 'n')
        url_get(f"https://ml.jku.at/research/DeepRC/datasets/CMV_data_with_implanted_signals/hdf5/implanted_signals_{dataset_id}.hdf5",
                repertoiresdata_file)
    
    # Get file for dataset splits
    split_file = os.path.join(os.path.dirname(deeprc.__file__), 'datasets', 'splits_used_in_paper',
                              'CMV_with_implanted_signals.pkl')
    with open(split_file, 'rb') as sfh:
        split_inds = pkl.load(sfh)
    
    # Get task_definition
    if task_definition is None:
        task_definition = TaskDefinition(targets=[BinaryTarget(column_name='status', true_class_value='True')])
    
    # Create data loaders
    trainingset_dataloader, trainingset_eval_dataloader, validationset_eval_dataloader, testset_eval_dataloader = \
        make_dataloaders(task_definition=task_definition, metadata_file=metadata_file,
                         repertoiresdata_path=repertoiresdata_file, split_inds=split_inds,
                         cross_validation_fold=cross_validation_fold, n_worker_processes=n_worker_processes,
                         batch_size=batch_size, inputformat=inputformat, keep_dataset_in_ram=keep_dataset_in_ram,
                         sample_n_sequences=sample_n_sequences,
                         sequence_counts_scaling_fn=no_sequence_count_scaling,
                         verbose=verbose)
    return (task_definition, trainingset_dataloader, trainingset_eval_dataloader, validationset_eval_dataloader,
            testset_eval_dataloader)
args = parser.parse_args()
# Set computation device
device = torch.device(args.device)
# Set random seed (will still be non-deterministic due to multiprocessing but weight initialization will be the same)
torch.manual_seed(args.rnd_seed)
np.random.seed(args.rnd_seed)


#
# Create Task definitions
#

# Assume we want to train on 1 main task as binary task at column 'binary_target_1' of our metadata file.
task_definition = TaskDefinition(targets=[  # Combines our sub-tasks
    BinaryTarget(  # Add binary classification task with sigmoid output function
        column_name='binary_target_1',  # Column name of task in metadata file
        true_class_value='+',  # Entries with value '+' will be positive class, others will be negative class
        pos_weight=1.,  # We can up- or down-weight the positive class if the classes are imbalanced
    ),
]).to(device=device)


#
# Get dataset
#

# Get data loaders for training set and training-, validation-, and test-set in evaluation mode
# (=no random subsampling)
trainingset, trainingset_eval, validationset_eval, testset_eval = make_dataloaders(
    task_definition=task_definition,
    metadata_file="../datasets/example_dataset/metadata.tsv",
    repertoiresdata_path="../datasets/example_dataset/repertoires",
    metadata_file_id_column='ID',
    sequence_column='amino_acid',