コード例 #1
0
def train(model: torch.nn.Module,
          task_definition: TaskDefinition,
          early_stopping_target_id: str,
          trainingset_dataloader: torch.utils.data.DataLoader,
          trainingset_eval_dataloader: torch.utils.data.DataLoader,
          validationset_eval_dataloader: torch.utils.data.DataLoader,
          results_directory: str = "results",
          n_updates: int = int(1e5),
          show_progress: bool = True,
          load_file: str = None,
          device: torch.device = torch.device('cuda:0'),
          num_torch_threads: int = 3,
          learning_rate: float = 1e-4,
          l1_weight_decay: float = 0,
          l2_weight_decay: float = 0,
          log_training_stats_at: int = int(1e2),
          evaluate_at: int = int(5e3),
          ignore_missing_target_values: bool = True):
    """Train a DeepRC model on a given dataset on tasks specified in `task_definition`

    The model with the lowest validation set loss on target `early_stopping_target_id` is kept as the final model
    (=early stopping). Model performance on the training and validation set is evaluated after the first update,
    every `evaluate_at` updates, and after the last update.
    A timestamped sub-directory of `results_directory` is created; it receives the logfile (`log.txt`), the
    tensorboard files (`tensorboard/`) and the checkpoints of the last and of the best model (`checkpoint/`).

    See `deeprc/examples/` for examples.

    Parameters
    ----------
    model: torch.nn.Module
        deeprc.architectures.DeepRC or similar model as PyTorch module
    task_definition: TaskDefinition
        TaskDefinition object containing the tasks to train the DeepRC model on. See `deeprc/examples/` for examples.
    early_stopping_target_id: str
        ID of task in TaskDefinition object to use for early stopping.
    trainingset_dataloader: torch.utils.data.DataLoader
        Data loader for training
    trainingset_eval_dataloader: torch.utils.data.DataLoader
        Data loader for evaluation on training set (=no random subsampling)
    validationset_eval_dataloader: torch.utils.data.DataLoader
        Data loader for evaluation on validation set (=no random subsampling).
        Will be used for early-stopping.
    results_directory: str
        Base directory; a timestamped sub-directory holding the checkpoints, logfile, and tensorboard files is
        created inside it.
    n_updates: int
        Number of updates to train for
    show_progress: bool
        Show progressbar?
    load_file: str
        Path to load checkpoint of previously saved model from
    device: torch.device
        Device to use for computations. E.g. `torch.device('cuda:0')` or `torch.device('cpu')`.
        Currently, only devices which support 16 bit float are supported.
    num_torch_threads: int
        Number of parallel threads to allow PyTorch
    learning_rate: float
        Learning rate for adam optimizer
    l1_weight_decay: float
        l1 weight decay factor. The mean absolute parameter value (averaged within each parameter tensor, then
        across tensors) is added to the loss, scaled by `l1_weight_decay`.
    l2_weight_decay: float
        l2 weight decay factor. Passed to the Adam optimizer as its `weight_decay` argument.
    log_training_stats_at: int
        Write current training statistics to tensorboard every `log_training_stats_at` updates
    evaluate_at: int
        Evaluate model on training and validation set every `evaluate_at` updates.
        This will also check for a new best model for early stopping.
    ignore_missing_target_values: bool
        If True, missing target values will be ignored for training. This can be useful if auxiliary tasks are not
        available for all samples but might increase the computation time per update.

    Raises
    ------
    ValueError
        If `trainingset_dataloader` yields no minibatches while updates are still outstanding (training could
        otherwise never make progress).
    """
    # Append current timestamp to results directory
    results_directory = os.path.join(
        results_directory,
        datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))
    os.makedirs(results_directory, exist_ok=True)

    # Set up results folder (logfile, checkpoint directory, tensorboard directory)
    logfile = os.path.join(results_directory, 'log.txt')
    checkpointdir = os.path.join(results_directory, 'checkpoint')
    os.makedirs(checkpointdir, exist_ok=True)
    tensorboarddir = os.path.join(results_directory, 'tensorboard')
    os.makedirs(tensorboarddir, exist_ok=True)

    # Prepare tensorboard writer
    writer = SummaryWriter(log_dir=tensorboarddir)

    # Print all outputs to logfile and terminal
    tee_print = TeePrint(logfile)
    tprint = tee_print.tee_print

    try:
        # Limit the number of CPU threads PyTorch may use
        torch.set_num_threads(num_torch_threads)

        # Send model to device
        model.to(device)

        # Get optimizer (eps needs to be at about 1e-4 to be numerically stable with 16 bit float)
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=learning_rate,
                                     weight_decay=l2_weight_decay,
                                     eps=1e-4)

        # Create a checkpoint dictionary with objects we want to have saved and loaded if needed
        state = dict(model=model,
                     optimizer=optimizer,
                     update=0,
                     best_validation_loss=np.inf)

        # Setup the SaverLoader class to save/load our checkpoint dictionary to files or to RAM objects
        saver_loader = SaverLoader(
            save_dict=state,
            device=device,
            save_dir=checkpointdir,
            n_savefiles=1,  # keep only the latest checkpoint
            n_inmem=1  # save checkpoint only in RAM
        )

        # Load previous checkpoint dictionary, if load_file is specified
        if load_file is not None:
            state.update(
                saver_loader.load_from_file(loadname=load_file, verbose=True))
            tprint(f"Loaded checkpoint from file {load_file}")
        update, best_validation_loss = state['update'], state['best_validation_loss']

        # Save checkpoint dictionary to RAM object (initial "best" model)
        saver_loader.save_to_ram(savename=str(update))

        #
        # Start training
        #
        progress_bar = None
        try:
            tprint("Training model...")
            progress_bar = tqdm(total=n_updates,
                                disable=not show_progress,
                                position=0,
                                desc=f"loss={np.nan:6.4f}")
            while update < n_updates:
                # Count minibatches of this pass so an empty dataloader cannot cause an endless loop
                n_minibatches_this_pass = 0
                for data in trainingset_dataloader:
                    n_minibatches_this_pass += 1
                    # Get samples as lists
                    labels, inputs, sequence_lengths, counts_per_sequence, sample_ids = data

                    # Apply attention-based sequence reduction and create minibatch
                    with torch.no_grad():
                        labels, inputs, sequence_lengths, n_sequences = model.reduce_and_stack_minibatch(
                            labels, inputs, sequence_lengths, counts_per_sequence)

                    # Reset gradients
                    optimizer.zero_grad()

                    # Calculate predictions from reduced sequences
                    logit_outputs = model(
                        inputs_flat=inputs,
                        sequence_lengths_flat=sequence_lengths,
                        n_sequences_per_bag=n_sequences)

                    # Calculate losses: task loss plus l1 penalty on the mean absolute parameter values
                    pred_loss = task_definition.get_loss(
                        raw_outputs=logit_outputs,
                        targets=labels,
                        ignore_missing_target_values=ignore_missing_target_values)
                    l1reg_loss = torch.mean(
                        torch.stack([p.abs().float().mean() for p in model.parameters()]))
                    loss = pred_loss + l1reg_loss * l1_weight_decay

                    # Perform update
                    loss.backward()
                    optimizer.step()

                    update += 1
                    progress_bar.update()
                    progress_bar.set_description(
                        desc=f"loss={loss.item():6.4f}", refresh=True)

                    # Add to tensorboard
                    if update % log_training_stats_at == 0:
                        tb_group = 'training/'
                        # Loop through tasks and add losses to tensorboard
                        pred_losses = task_definition.get_losses(
                            raw_outputs=logit_outputs, targets=labels)
                        # shape: (n_tasks, n_samples, 1) -> (n_tasks, 1)
                        pred_losses = pred_losses.mean(dim=1)
                        for task_id, task_loss in zip(task_definition.get_task_ids(), pred_losses):
                            writer.add_scalar(tag=tb_group + f'{task_id}_loss',
                                              scalar_value=task_loss,
                                              global_step=update)
                        writer.add_scalar(tag=tb_group + 'total_task_loss',
                                          scalar_value=pred_loss,
                                          global_step=update)
                        writer.add_scalar(tag=tb_group + 'l1reg_loss',
                                          scalar_value=l1reg_loss,
                                          global_step=update)
                        writer.add_scalar(tag=tb_group + 'total_loss',
                                          scalar_value=loss,
                                          global_step=update)
                        writer.add_histogram(tag=tb_group + 'logit_outputs',
                                             values=logit_outputs,
                                             global_step=update)

                    # Calculate scores and loss on training set and validation set
                    if update % evaluate_at == 0 or update == n_updates or update == 1:
                        print("  Calculating training score...")
                        scores = evaluate(
                            model=model,
                            dataloader=trainingset_eval_dataloader,
                            task_definition=task_definition,
                            device=device)
                        print(" ...done!")
                        tprint(f"[training_inference] u: {update:07d}; scores: {scores};")

                        tb_group = 'training_inference/'
                        for task_id, task_scores in scores.items():
                            for score_name, score in task_scores.items():
                                writer.add_scalar(tag=tb_group + f'{task_id}/{score_name}',
                                                  scalar_value=score,
                                                  global_step=update)

                        print("  Calculating validation score...")
                        scores = evaluate(
                            model=model,
                            dataloader=validationset_eval_dataloader,
                            task_definition=task_definition,
                            device=device)
                        scoring_loss = scores[early_stopping_target_id]['loss']

                        print(" ...done!")
                        tprint(f"[validation] u: {update:07d}; scores: {scores};")

                        tb_group = 'validation/'
                        for task_id, task_scores in scores.items():
                            for score_name, score in task_scores.items():
                                writer.add_scalar(tag=tb_group + f'{task_id}/{score_name}',
                                                  scalar_value=score,
                                                  global_step=update)

                        # If we have a new best loss on the validation set, we save the model as new best model
                        if best_validation_loss > scoring_loss:
                            best_validation_loss = scoring_loss
                            tprint(f"  New best validation loss for {early_stopping_target_id}: {scoring_loss}")
                            # Save current state as RAM object
                            state['update'] = update
                            state['best_validation_loss'] = scoring_loss
                            # Save checkpoint dictionary with currently best model to RAM
                            saver_loader.save_to_ram(savename=str(update))
                            # Saving to disk on every new best model would be slow, so only RAM is used here:
                            # saver_loader.save_to_file(filename=f'best_so_far_u{update}.tar.gzip')

                    if update >= n_updates:
                        break

                if n_minibatches_this_pass == 0:
                    raise ValueError("trainingset_dataloader yielded no minibatches; training cannot make progress")
        finally:
            # Close the progress bar also when training is interrupted by an exception
            if progress_bar is not None:
                progress_bar.close()
            # In any case, save the current model and best model to a file
            saver_loader.save_to_file(filename=f'lastsave_u{update}.tar.gzip')
            state.update(saver_loader.load_from_ram())  # load best model so far
            saver_loader.save_to_file(filename=f'best_u{update}.tar.gzip')
            print('Finished Training!')
    except Exception as e:
        # Record the exception in the logfile before propagating it
        with open(logfile, 'a') as lf:
            print(f"Exception: {e}", file=lf)
        raise
    finally:
        close_all()  # Clean up
コード例 #2
0
def evaluate(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    task_definition: TaskDefinition,
    show_progress: bool = True,
    device: torch.device = torch.device('cuda:0')
) -> dict:
    """Compute DeepRC model scores on given dataset for tasks specified in `task_definition`

    Predictions for all minibatches are collected (without gradient tracking), concatenated, and scored in one call
    to `task_definition.get_scores()`.
    NOTE(review): the model is not switched to eval mode here; if it contains dropout or batch-norm layers the
    scores may be stochastic - confirm this is intended.

    Parameters
    ----------
    model: torch.nn.Module
        deeprc.architectures.DeepRC or similar model as PyTorch module
    dataloader: torch.utils.data.DataLoader
        Data loader for dataset to calculate scores on. Loaders without a length (e.g. over iterable-style
        datasets) are supported; the progress bar then shows no total.
    task_definition: TaskDefinition
        TaskDefinition object containing the tasks to train the DeepRC model on. See `deeprc/examples/` for examples.
    show_progress: bool
        Show progressbar?
    device: torch.device
        Device to use for computations. E.g. `torch.device('cuda:0')` or `torch.device('cpu')`.

    Returns
    ---------
    scores: dict
        Nested dictionary of format `{task_id: {score_id: score_value}}`, e.g.
        `{"binary_task_1": {"auc": 0.6, "bacc": 0.5, "f1": 0.2, "loss": 0.01}}`. The scores returned are computed using
        the .get_scores() methods of the individual target instances (e.g. `deeprc.task_definitions.BinaryTarget()`).
        See `deeprc/examples/` for examples.

    Raises
    ------
    ValueError
        If `dataloader` yields no minibatches.
    """
    # DataLoaders over datasets without __len__ raise TypeError on len(); tqdm accepts total=None
    try:
        n_minibatches = len(dataloader)
    except TypeError:
        n_minibatches = None

    with torch.no_grad():
        model.to(device=device)
        all_raw_outputs = []
        all_targets = []

        for scoring_data in tqdm(dataloader,
                                 total=n_minibatches,
                                 desc="Evaluating model",
                                 disable=not show_progress):

            # Get samples as lists
            targets, inputs, sequence_lengths, counts_per_sequence, _sample_ids = scoring_data

            # Apply attention-based sequence reduction and create minibatch
            targets, inputs, sequence_lengths, n_sequences = model.reduce_and_stack_minibatch(
                targets, inputs, sequence_lengths, counts_per_sequence)

            # Compute predictions from reduced sequences
            raw_outputs = model(inputs_flat=inputs,
                                sequence_lengths_flat=sequence_lengths,
                                n_sequences_per_bag=n_sequences)

            # Store predictions and labels
            all_raw_outputs.append(raw_outputs.detach())
            all_targets.append(targets.detach())

        # torch.cat on an empty list fails with an unhelpful message; report the real cause instead
        if not all_raw_outputs:
            raise ValueError("Cannot compute scores: dataloader yielded no minibatches")

        # Compute scores over the whole dataset at once
        all_raw_outputs = torch.cat(all_raw_outputs, dim=0)
        all_targets = torch.cat(all_targets, dim=0)

        scores = task_definition.get_scores(raw_outputs=all_raw_outputs,
                                            targets=all_targets)
    return scores
コード例 #3
0
# Multi-task example: one main binary task plus auxiliary tasks whose losses are weighted by `aux_task_weight`.
# NOTE(review): `aux_task_weight` and `device` must be defined earlier in the script.
task_definition = TaskDefinition(targets=[  # Combines our sub-tasks
    BinaryTarget(  # Add binary classification task with sigmoid output function
        column_name='binary_target_1',  # Column name of task in metadata file
        true_class_value='+',  # Entries with value '+' will be positive class, others will be negative class
        pos_weight=1.,  # We can up- or down-weight the positive class if the classes are imbalanced
        task_weight=1.  # Weight of this task for the total training loss
    ),
    BinaryTarget(  # Add another binary classification task
        column_name='binary_target_2',
        true_class_value='True',  # Entries with value 'True' will be positive class, others will be negative class
        pos_weight=1./3,  # Down-weights the positive class samples (e.g. if the positive class is overrepresented)
        task_weight=aux_task_weight
    ),
    RegressionTarget(  # Add regression task with linear output function
        column_name='regression_target_1',  # Column name of task in metadata file
        normalization_mean=0., normalization_std=275.,  # Normalize targets by ((target_value - mean) / std)
        task_weight=aux_task_weight  # Weight of this task for the total training loss
    ),
    RegressionTarget(  # Add another regression task
        column_name='regression_target_2',
        normalization_mean=0.5, normalization_std=1.,
        task_weight=aux_task_weight
    ),
    MulticlassTarget(  # Add multiclass classification task with softmax output function (=classes mutually exclusive)
        column_name='multiclass_target_1',  # Column name of task in metadata file
        possible_target_values=['class_a', 'class_b', 'class_c'],  # Values in task column to expect
        class_weights=[1., 1., 0.5],  # Weight individual classes (e.g. if class 'class_c' is overrepresented)
        task_weight=aux_task_weight  # Weight of this task for the total training loss
    ),
    MulticlassTarget(  # Add another multiclass classification task
        column_name='multiclass_target_2',
        possible_target_values=['type_1', 'type_2', 'type_3', 'type_4', 'type_5'],
        class_weights=[1., 1., 1., 1., 1.],
        task_weight=aux_task_weight
    )
]).to(device=device)
コード例 #4
0
def cmv_dataset(dataset_path: str = None, task_definition: TaskDefinition = None,
                cross_validation_fold: int = 0, n_worker_processes: int = 4, batch_size: int = 4,
                inputformat: str = 'NCL', keep_dataset_in_ram: bool = True,
                sample_n_sequences: int = int(1e4), verbose: bool = True) \
        -> Tuple[TaskDefinition, DataLoader, DataLoader, DataLoader, DataLoader]:
    """Get data loaders for category "real-world immunosequencing data"

    Get data loaders for training set and training-, validation-, and test-set in evaluation mode
    (=no random subsampling) for datasets of category "real-world immunosequencing data".
    This is a pre-processed version of the CMV dataset [1]_.
    The dataset splits used in the paper are read from `datasets/splits_used_in_paper/CMV_splits.pkl` inside the
    installed `deeprc` package.

    Parameters
    ----------
    dataset_path: str
        Directory of the dataset. It is created if necessary and is expected to contain the metadata file
        `CMV_metadata.tsv` and the repertoire file `CMV_repertoiresdata.hdf5`; missing files are downloaded after
        asking the user for confirmation.
        Defaults to the directory `datasets/CMV` inside the installed `deeprc` package.
    task_definition: TaskDefinition
        TaskDefinition object containing the tasks to train the DeepRC model on. See `deeprc/examples/` for examples.
        Defaults to a single binary task on column 'Known CMV status' with '+' as positive class.
    cross_validation_fold : int
        Specify the fold of the cross-validation the dataloaders should be computed for.
    n_worker_processes : int
        Number of background processes to use for converting dataset to hdf5 container and trainingset dataloader.
    batch_size : int
        Number of repertoires per minibatch during training.
    inputformat : 'NCL' or 'LNC'
        Format of input feature array;
        'NCL' -> (batchsize, channels, seq.length);
        'LNC' -> (seq.length, batchsize, channels);
    keep_dataset_in_ram : bool
        It is faster to load the full hdf5 file into the RAM instead of keeping it on the disk.
        If False, the hdf5 file will be read from the disk and consume less RAM.
    sample_n_sequences : int
        Optional: Random sub-sampling of `sample_n_sequences` sequences per repertoire.
        Number of sequences per repertoire might be smaller than `sample_n_sequences` if repertoire is smaller or
        random indices have been drawn multiple times.
        If None, all sequences will be loaded for each repertoire.
    verbose : bool
        Activate verbose mode

    Returns
    ---------
    task_definition: TaskDefinition
        TaskDefinition object containing the tasks to train the DeepRC model on. See `deeprc/examples/` for examples.
    trainingset_dataloader: torch.utils.data.DataLoader
        Dataloader for trainingset with active `sample_n_sequences` (=random subsampling/dropout of repertoire
        sequences)
    trainingset_eval_dataloader: torch.utils.data.DataLoader
        Dataloader for trainingset with deactivated `sample_n_sequences`
    validationset_eval_dataloader: torch.utils.data.DataLoader
        Dataloader for validationset with deactivated `sample_n_sequences`
    testset_eval_dataloader: torch.utils.data.DataLoader
        Dataloader for testset with deactivated `sample_n_sequences`

    References
    -----
    .. [1] Emerson, R. O., DeWitt, W. S., Vignali, M., Gravley, J., Hu, J. K., Osborne, E. J., Desmarais, C.,
     Klinger, M., Carlson, C. S., Hansen, J. A., et al. Immunosequencing identifies signatures of cytomegalovirus
     exposure history and HLA-mediated effects on the T cell repertoire. Nature Genetics, 49(5):659, 2017
    """
    if dataset_path is None:
        dataset_path = os.path.join(os.path.dirname(deeprc.__file__), 'datasets', 'CMV')
    os.makedirs(dataset_path, exist_ok=True)
    metadata_file = os.path.join(dataset_path, 'CMV_metadata.tsv')
    repertoiresdata_file = os.path.join(dataset_path, 'CMV_repertoiresdata.hdf5')

    # Download metadata file and repertoire file (in this order) if they are missing
    downloads = (
        (metadata_file,
         "https://ml.jku.at/research/DeepRC/datasets/CMV_data/metadata/cmv_emerson_2017.tsv"),
        (repertoiresdata_file,
         "https://ml.jku.at/research/DeepRC/datasets/CMV_data/hdf5/cmv_emerson_2017.hdf5"),
    )
    for local_file, url in downloads:
        if not os.path.exists(local_file):
            user_confirmation(
                f"File {local_file} not found. It will be downloaded now. Continue?",
                'y', 'n')
            url_get(url, local_file)

    # Get file for dataset splits
    split_file = os.path.join(os.path.dirname(deeprc.__file__), 'datasets',
                              'splits_used_in_paper', 'CMV_splits.pkl')
    with open(split_file, 'rb') as sfh:
        split_inds = pkl.load(sfh)

    # Get task_definition
    if task_definition is None:
        task_definition = TaskDefinition(targets=[
            BinaryTarget(column_name='Known CMV status', true_class_value='+')
        ])

    # Create data loaders
    trainingset_dataloader, trainingset_eval_dataloader, validationset_eval_dataloader, testset_eval_dataloader = \
        make_dataloaders(task_definition=task_definition, metadata_file=metadata_file,
                         repertoiresdata_path=repertoiresdata_file, split_inds=split_inds,
                         cross_validation_fold=cross_validation_fold, n_worker_processes=n_worker_processes,
                         batch_size=batch_size, inputformat=inputformat, keep_dataset_in_ram=keep_dataset_in_ram,
                         sample_n_sequences=sample_n_sequences, sequence_counts_scaling_fn=log_sequence_count_scaling,
                         metadata_file_id_column='Subject ID', verbose=verbose)
    return (task_definition, trainingset_dataloader,
            trainingset_eval_dataloader, validationset_eval_dataloader,
            testset_eval_dataloader)
コード例 #5
0
def cmv_implanted_dataset(dataset_path: str = None, dataset_id: int = 0, task_definition: TaskDefinition = None,
                          cross_validation_fold: int = 0, n_worker_processes: int = 4, batch_size: int = 4,
                          inputformat: str = 'NCL', keep_dataset_in_ram: bool = True,
                          sample_n_sequences: int = int(1e4), verbose: bool = True) \
        -> Tuple[TaskDefinition, DataLoader, DataLoader, DataLoader, DataLoader]:
    """Get data loaders for category "real-world immunosequencing data with implanted signals".

    Get data loaders for training set and training-, validation-, and test-set in evaluation mode
    (=no random subsampling) for datasets of category "real-world immunosequencing data with implanted signals".
    The dataset splits are read from `datasets/splits_used_in_paper/CMV_with_implanted_signals.pkl` inside the
    installed `deeprc` package (the same split file is used for every `dataset_id`).

    Parameters
    ----------
    dataset_path: str
        Directory of the dataset. It is created if necessary and is expected to contain the metadata file
        `CMV_with_implanted_signals_{dataset_id}_metadata.tsv` and the repertoire file
        `CMV_with_implanted_signals_{dataset_id}_repertoiresdata.hdf5`; missing files are downloaded after asking the
        user for confirmation.
        Defaults to the directory `datasets/CMV_with_implanted_signals` inside the installed `deeprc` package.
    dataset_id: int
        ID of dataset.
        0 = "One Motif 1%", 1 = "One 0.1%", 2 = "Multi 1%", 3 = "Multi 0.1%"
    task_definition: TaskDefinition
        TaskDefinition object containing the tasks to train the DeepRC model on. See `deeprc/examples/` for examples.
        Defaults to a single binary task on column 'status' with 'True' as positive class.
    cross_validation_fold : int
        Specify the fold of the cross-validation the dataloaders should be computed for.
    n_worker_processes : int
        Number of background processes to use for converting dataset to hdf5 container and trainingset dataloader.
    batch_size : int
        Number of repertoires per minibatch during training.
    inputformat : 'NCL' or 'LNC'
        Format of input feature array;
        'NCL' -> (batchsize, channels, seq.length);
        'LNC' -> (seq.length, batchsize, channels);
    keep_dataset_in_ram : bool
        It is faster to load the full hdf5 file into the RAM instead of keeping it on the disk.
        If False, the hdf5 file will be read from the disk and consume less RAM.
    sample_n_sequences : int
        Optional: Random sub-sampling of `sample_n_sequences` sequences per repertoire.
        Number of sequences per repertoire might be smaller than `sample_n_sequences` if repertoire is smaller or
        random indices have been drawn multiple times.
        If None, all sequences will be loaded for each repertoire.
    verbose : bool
        Activate verbose mode

    Returns
    ---------
    task_definition: TaskDefinition
        TaskDefinition object containing the tasks to train the DeepRC model on. See `deeprc/examples/` for examples.
    trainingset_dataloader: torch.utils.data.DataLoader
        Dataloader for trainingset with active `sample_n_sequences` (=random subsampling/dropout of repertoire
        sequences)
    trainingset_eval_dataloader: torch.utils.data.DataLoader
        Dataloader for trainingset with deactivated `sample_n_sequences`
    validationset_eval_dataloader: torch.utils.data.DataLoader
        Dataloader for validationset with deactivated `sample_n_sequences`
    testset_eval_dataloader: torch.utils.data.DataLoader
        Dataloader for testset with deactivated `sample_n_sequences`
    """
    if dataset_path is None:
        dataset_path = os.path.join(os.path.dirname(deeprc.__file__), 'datasets', 'CMV_with_implanted_signals')
    os.makedirs(dataset_path, exist_ok=True)
    metadata_file = os.path.join(
        dataset_path, f'CMV_with_implanted_signals_{dataset_id}_metadata.tsv')
    repertoiresdata_file = os.path.join(
        dataset_path, f'CMV_with_implanted_signals_{dataset_id}_repertoiresdata.hdf5')

    # Download metadata file and repertoire file (in this order) if they are missing.
    # Unused alternative metadata URL:
    # https://ml.jku.at/research/DeepRC/datasets/CMV_data_with_implanted_signals/metadata/implanted_signals_{dataset_id}.csv
    downloads = (
        (metadata_file,
         f"https://cloud.ml.jku.at/s/KQDAdHjHpdn3pzg/download?path=/datasets/CMV_data_with_implanted_signals/metadata&files=implanted_signals_{dataset_id}.tsv"),
        (repertoiresdata_file,
         f"https://ml.jku.at/research/DeepRC/datasets/CMV_data_with_implanted_signals/hdf5/implanted_signals_{dataset_id}.hdf5"),
    )
    for local_file, url in downloads:
        if not os.path.exists(local_file):
            user_confirmation(
                f"File {local_file} not found. It will be downloaded now. Continue?",
                'y', 'n')
            url_get(url, local_file)

    # Get file for dataset splits
    split_file = os.path.join(os.path.dirname(deeprc.__file__), 'datasets',
                              'splits_used_in_paper',
                              'CMV_with_implanted_signals.pkl')
    with open(split_file, 'rb') as sfh:
        split_inds = pkl.load(sfh)

    # Get task_definition
    if task_definition is None:
        task_definition = TaskDefinition(targets=[
            BinaryTarget(column_name='status', true_class_value='True')
        ])

    # Create data loaders
    trainingset_dataloader, trainingset_eval_dataloader, validationset_eval_dataloader, testset_eval_dataloader = \
        make_dataloaders(task_definition=task_definition, metadata_file=metadata_file,
                         repertoiresdata_path=repertoiresdata_file, split_inds=split_inds,
                         cross_validation_fold=cross_validation_fold, n_worker_processes=n_worker_processes,
                         batch_size=batch_size, inputformat=inputformat, keep_dataset_in_ram=keep_dataset_in_ram,
                         sample_n_sequences=sample_n_sequences, sequence_counts_scaling_fn=no_sequence_count_scaling,
                         verbose=verbose)
    return (task_definition, trainingset_dataloader,
            trainingset_eval_dataloader, validationset_eval_dataloader,
            testset_eval_dataloader)
コード例 #6
0
args = parser.parse_args()

# Computation device as selected on the command line
device = torch.device(args.device)

# Seed PyTorch and NumPy. Multiprocessing keeps the overall run non-deterministic,
# but the weight initialization becomes reproducible.
torch.manual_seed(args.rnd_seed)
np.random.seed(args.rnd_seed)


#
# Task definition
#
# Single main task: binary classification on column 'binary_target_1' of the metadata file.
task_definition = TaskDefinition(
    targets=[
        BinaryTarget(  # binary classification task with sigmoid output function
            column_name='binary_target_1',  # metadata column holding the label
            true_class_value='+',  # '+' marks the positive class; every other value is negative
            pos_weight=1.0,  # raise or lower to re-balance positive vs. negative samples
        ),
    ]
)
task_definition = task_definition.to(device=device)


#
# Get dataset
#
# Get data loaders for training set and training-, validation-, and test-set in evaluation mode (=no random subsampling)
trainingset, trainingset_eval, validationset_eval, testset_eval = make_dataloaders(
        task_definition=task_definition,
        metadata_file="../datasets/example_dataset/metadata.tsv",
        repertoiresdata_path="../datasets/example_dataset/repertoires",
        metadata_file_id_column='ID',
        sequence_column='amino_acid',