Example #1
def dataspec_stats(dataspec, regions=None, sample=None, peak_width=1000):
    """Compute the stats about the tracks
    """
    import random
    from pybedtools import BedTool
    from bpnet.preproc import resize_interval
    from genomelake.extractors import FastaExtractor

    ds = DataSpec.load(dataspec)

    if regions is not None:
        regions = list(BedTool(regions))
    else:
        regions = ds.get_all_regions()

    if sample is not None and sample < len(regions):
        logger.info(
            f"Using {sample} randomly sampled regions instead of {len(regions)}"
        )
        regions = random.sample(regions, k=sample)

    # resize the regions
    regions = [
        resize_interval(interval, peak_width, ignore_strand=True)
        for interval in regions
    ]

    base_freq = FastaExtractor(ds.fasta_file)(regions).mean(axis=(0, 1))

    count_stats = _track_stats(ds.load_counts(regions, progbar=True))
    bias_count_stats = _track_stats(ds.load_bias_counts(regions, progbar=True))

    print("")
    print("Base frequency")
    for i, base in enumerate(['A', 'C', 'G', 'T']):
        print(f"- {base}: {base_freq[i]}")
    print("")
    print("Count stats")
    for task, stats in count_stats.items():
        print(f"- {task}")
        for stat_key, stat_value in stats.items():
            print(f"  {stat_key}: {stat_value}")
    print("")
    print("Bias stats")
    for task, stats in bias_count_stats.items():
        print(f"- {task}")
        for stat_key, stat_value in stats.items():
            print(f"  {stat_key}: {stat_value}")

    lamb = np.mean([v["total median"] for v in count_stats.values()]) / 10
    print("")
    print(
        f"We recommend to set lambda=total_count_median / 10 = {lamb:.2f} (default=10) in `bpnet train --override=` "
        "to put 5x more weight on profile prediction than on total count prediction."
    )
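A worked example of the lambda recommendation printed above (the task names and medians are illustrative, not real statistics):

import numpy as np

# two hypothetical tasks with their "total median" counts
count_stats = {'Oct4': {'total median': 400.0}, 'Sox2': {'total median': 600.0}}
lamb = np.mean([v["total median"] for v in count_stats.values()]) / 10
print(lamb)  # 50.0 -> the value suggested for `bpnet train --override=`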
Example #2
    @classmethod
    def from_mdir(cls, model_dir):
        from bpnet.seqmodel import SeqModel
        # also figure out the fasta_file if present (from the dataspec)
        from bpnet.dataspecs import DataSpec
        ds_path = os.path.join(model_dir, "dataspec.yml")
        if os.path.exists(ds_path):
            ds = DataSpec.load(ds_path)
            fasta_file = ds.fasta_file
        else:
            fasta_file = None
        return cls(SeqModel.from_mdir(model_dir), fasta_file=fasta_file)
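Example #3 below uses this constructor via `BPNetSeqModel.from_mdir(model_dir)`; a minimal sketch (the import path and directory name are assumptions):

from bpnet.BPNet import BPNetSeqModel  # assumed import path for the class defining from_mdir

bp = BPNetSeqModel.from_mdir("output/my-bpnet-run")  # placeholder model directory; dataspec.yml is optional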
Example #3
def bpnet_export_bw(
        model_dir,
        output_prefix,
        fasta_file=None,
        regions=None,
        contrib_method='grad',
        contrib_wildcard='*/profile/wn,*/counts/pre-act',  # specifies which contrib. scores to compute
        batch_size=256,
        scale_contribution=False,
        flip_negative_strand=False,
        gpu=0,
        memfrac_gpu=0.45):
    """Export model predictions and contribution scores to big-wig files
    """
    from pybedtools import BedTool
    from bpnet.modisco.core import Seqlet
    output_dir = os.path.dirname(output_prefix)
    add_file_logging(output_dir, logger, 'bpnet-export-bw')
    os.makedirs(output_dir, exist_ok=True)
    if gpu is not None:
        create_tf_session(gpu, per_process_gpu_memory_fraction=memfrac_gpu)

    logger.info("Load model")

    bp = BPNetSeqModel.from_mdir(model_dir)

    if regions is not None:
        logger.info(
            f"Computing predictions and contribution scores for provided regions: {regions}"
        )
        regions = list(BedTool(regions))
    else:
        logger.info("--regions not provided. Using regions from dataspec.yml")
        ds = DataSpec.load(os.path.join(model_dir, 'dataspec.yml'))
        regions = ds.get_all_regions()

    seqlen = bp.input_seqlen()
    logger.info(
        f"Resizing regions (fix=center) to model's input width of: {seqlen}")
    regions = [resize_interval(interval, seqlen) for interval in regions]
    logger.info("Sort the bed file")
    regions = list(BedTool(regions).sort())

    bp.export_bw(regions=regions,
                 output_prefix=output_prefix,
                 contrib_method=contrib_method,
                 fasta_file=fasta_file,
                 pred_summaries=contrib_wildcard.replace("*/", "").split(","),
                 batch_size=batch_size,
                 scale_contribution=scale_contribution,
                 flip_negative_strand=flip_negative_strand,
                 chromosomes=None)  # infer chromosomes from the fasta file
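The `pred_summaries` argument above is derived from `contrib_wildcard` by plain string manipulation; the same transformation in isolation:

contrib_wildcard = '*/profile/wn,*/counts/pre-act'
pred_summaries = contrib_wildcard.replace("*/", "").split(",")
print(pred_summaries)  # ['profile/wn', 'counts/pre-act']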
Example #4
def bpnet_contrib(
        model_dir,
        output_file,
        method="grad",
        dataspec=None,
        regions=None,
        fasta_file=None,  # alternative to dataspec
        shuffle_seq=False,
        shuffle_regions=False,
        max_regions=None,
        # reference='zeroes', # Currently the only option
        # peak_width=1000,  # automatically inferred from 'config.gin.json'
        # seq_width=None,
        contrib_wildcard='*/profile/wn,*/counts/pre-act',  # specifies which contrib. scores to compute
        batch_size=512,
        gpu=0,
        memfrac_gpu=0.45,
        num_workers=10,
        storage_chunk_size=512,
        exclude_chr='',
        include_chr='',
        overwrite=False,
        skip_bias=False):
    """Run contribution scores for a BPNet model
    """
    from bpnet.extractors import _chrom_sizes
    add_file_logging(os.path.dirname(output_file), logger, 'bpnet-contrib')
    if gpu is not None:
        create_tf_session(gpu, per_process_gpu_memory_fraction=memfrac_gpu)
    else:
        # Don't use any GPUs
        os.environ['CUDA_VISIBLE_DEVICES'] = ''

    if os.path.exists(output_file):
        if overwrite:
            os.remove(output_file)
        else:
            raise ValueError(
                f"File exists {output_file}. Use overwrite=True to overwrite it"
            )

    config = read_json(os.path.join(model_dir, 'config.gin.json'))
    seq_width = config['seq_width']
    peak_width = config['seq_width']

    # NOTE - seq_width has to be the same for the input and the target
    #
    # infer from the command line
    # if seq_width is None:
    #     logger.info("Using seq_width = peak_width")
    #     seq_width = peak_width

    # # make sure these are int's
    # seq_width = int(seq_width)
    # peak_width = int(peak_width)

    # Split
    contrib_wildcards = contrib_wildcard.split(",")

    # Allow chr inclusion / exclusion
    if exclude_chr:
        exclude_chr = exclude_chr.split(",")
    else:
        exclude_chr = None
    if include_chr:
        include_chr = include_chr.split(",")
    else:
        include_chr = None

    logger.info("Loading the config files")
    model_dir = Path(model_dir)

    logger.info("Creating the dataset")
    from bpnet.datasets import StrandedProfile, SeqClassification
    if fasta_file is not None:
        if regions is None:
            raise ValueError(
                "fasta_file specified. Expecting regions to be specified as well"
            )
        dl_valid = SeqClassification(
            fasta_file=fasta_file,
            intervals_file=regions,
            incl_chromosomes=include_chr,
            excl_chromosomes=exclude_chr,
            auto_resize_len=seq_width,
        )
        chrom_sizes = _chrom_sizes(fasta_file)
    else:
        if dataspec is None:
            logger.info("Using dataspec used to train the model")
            # Specify dataspec
            dataspec = model_dir / "dataspec.yml"

        ds = DataSpec.load(dataspec)
        dl_valid = StrandedProfile(ds,
                                   incl_chromosomes=include_chr,
                                   excl_chromosomes=exclude_chr,
                                   intervals_file=regions,
                                   peak_width=peak_width,
                                   shuffle=False,
                                   seq_width=seq_width)
        chrom_sizes = _chrom_sizes(ds.fasta_file)

    # Setup contribution score trimming (not required currently)
    if seq_width > peak_width:
        # Trim
        # make sure we can nicely trim the peak
        logger.info("Trimming the output")
        assert (seq_width - peak_width) % 2 == 0
        trim_start = (seq_width - peak_width) // 2
        trim_end = seq_width - trim_start
        assert trim_end - trim_start == peak_width
    elif seq_width == peak_width:
        trim_start = 0
        trim_end = peak_width
    else:
        raise ValueError("seq_width < peak_width")

    seqmodel = SeqModel.from_mdir(model_dir)

    # get all possible interpretation names
    # make sure they match the specified glob
    intp_names = [
        name for name, _ in seqmodel.get_intp_tensors(preact_only=False)
        if fnmatch_any(name, contrib_wildcards)
    ]
    logger.info(f"Using the following interpretation targets:")
    for n in intp_names:
        print(n)

    if max_regions is not None:
        if len(dl_valid) > max_regions:
            logger.info(
                f"Using {max_regions} regions instead of the original {len(dl_valid)}"
            )
        else:
            logger.info(
                f"--max-regions={max_regions} is larger than the dataset size: {len(dl_valid)}. "
                "Using the dataset size for max-regions")
            max_regions = len(dl_valid)
    else:
        max_regions = len(dl_valid)

    max_batches = np.ceil(max_regions / batch_size)

    writer = HDF5BatchWriter(output_file, chunk_size=storage_chunk_size)
    for i, batch in enumerate(
            tqdm(dl_valid.batch_iter(batch_size=batch_size,
                                     shuffle=shuffle_regions,
                                     num_workers=num_workers),
                 total=max_batches)):
        # store the original batch containing 'inputs' and 'targets'
        if skip_bias:
            batch['inputs'] = {
                'seq': batch['inputs']['seq']
            }  # ignore all other inputs

        if max_batches > 0:
            if i >= max_batches:  # stop after max_batches batches (i is 0-based)
                break

        if shuffle_seq:
            # Di-nucleotide shuffle the sequences
            batch['inputs']['seq'] = onehot_dinucl_shuffle(
                batch['inputs']['seq'])

        for name in intp_names:
            hyp_contrib = seqmodel.contrib_score(
                batch['inputs']['seq'],
                name=name,
                method=method,
                batch_size=None)  # don't second-batch

            # put contribution scores to the dictionary
            # also trim the contribution scores appropriately so that
            # the output will always be w.r.t. the peak center
            batch[f"/hyp_contrib/{name}"] = hyp_contrib[:, trim_start:trim_end]

        # trim the sequence as well
        batch['inputs']['seq'] = batch['inputs']['seq'][:, trim_start:trim_end]

        # TODO: maybe it would be better to have an explicit ContribFileWriter;
        # that way the written schema would be fixed
        writer.batch_write(batch)

    # add chromosome sizes
    writer.f.attrs['chrom_sizes'] = json.dumps(chrom_sizes)
    writer.close()
    logger.info(f"Done. Contribution score file was saved to: {output_file}")
Example #5
def bpnet_data_gw(dataspec,
                  intervals_file=None,
                  peak_width=200,
                  seq_width=None,
                  shuffle=True,
                  track_transform=None,
                  total_count_transform=lambda x: np.log(1 + x),
                  include_metadata=False,
                  include_classes=False,
                  tasks=None,
                  valid_chr=['chr2', 'chr3', 'chr4'],
                  test_chr=['chr1', 'chr8', 'chr9'],
                  exclude_chr=[]):
    """Genome-wide bpnet data
    """
    # NOTE - only chromosomes chr1-22, chrX and chrY are considered here
    # (all other chromosomes, e.g. ChrUn..., are omitted)
    from bpnet.metrics import BPNetMetric, PeakPredictionProfileMetric, pearson_spearman
    # valid and test chromosomes must not be in exclude_chr
    for vc in valid_chr:
        assert vc not in exclude_chr
    for vc in test_chr:
        assert vc not in exclude_chr

    dataspec = DataSpec.load(dataspec)

    # get the list of all chromosomes from the fasta file
    all_chr = _chrom_names(dataspec.fasta_file)

    if tasks is None:
        tasks = list(dataspec.task_specs)

    train = StrandedProfile(dataspec, peak_width,
                            seq_width=seq_width,
                            intervals_file=intervals_file,
                            intervals_format='bed3+labels',
                            include_metadata=include_metadata,
                            include_classes=include_classes,
                            tasks=tasks,
                            incl_chromosomes=[c for c in all_chr
                                              if c not in valid_chr + test_chr + exclude_chr],
                            excl_chromosomes=valid_chr + test_chr + exclude_chr,
                            shuffle=shuffle,
                            track_transform=track_transform,
                            total_count_transform=total_count_transform)

    valid = [('train-valid-genome-wide',
              StrandedProfile(dataspec, peak_width,
                              seq_width=seq_width,
                              intervals_file=intervals_file,
                              intervals_format='bed3+labels',
                              include_metadata=include_metadata,
                              include_classes=include_classes,
                              tasks=tasks,
                              incl_chromosomes=valid_chr,
                              shuffle=shuffle,
                              track_transform=track_transform,
                              total_count_transform=total_count_transform))]
    if include_classes:
        # Only use binary classification for genome-wide evaluation
        valid = valid + [('valid-genome-wide',
                          StrandedProfile(dataspec, peak_width,
                                          seq_width=seq_width,
                                          intervals_file=intervals_file,
                                          intervals_format='bed3+labels',
                                          include_metadata=include_metadata,
                                          include_classes=True,
                                          tasks=tasks,
                                          incl_chromosomes=valid_chr,
                                          shuffle=shuffle,
                                          track_transform=track_transform,
                                          total_count_transform=total_count_transform))]

    # Also add the peak regions
    valid = valid + [
        ('valid-peaks', StrandedProfile(dataspec, peak_width,
                                        seq_width=seq_width,
                                        intervals_file=None,
                                        intervals_format='bed3+labels',
                                        include_metadata=include_metadata,
                                        tasks=tasks,
                                        include_classes=False,  # dataspec doesn't contain labels
                                        incl_chromosomes=valid_chr,
                                        shuffle=shuffle,
                                        track_transform=track_transform,
                                        total_count_transform=total_count_transform)),
        ('train-peaks', StrandedProfile(dataspec, peak_width,
                                        seq_width=seq_width,
                                        intervals_file=None,
                                        intervals_format='bed3+labels',
                                        include_metadata=include_metadata,
                                        tasks=tasks,
                                        include_classes=False,  # dataspec doesn't contain labels
                                        incl_chromosomes=[c for c in all_chr
                                                          if c not in valid_chr + test_chr + exclude_chr],
                                        excl_chromosomes=valid_chr + test_chr + exclude_chr,
                                        shuffle=shuffle,
                                        track_transform=track_transform,
                                        total_count_transform=total_count_transform)),
        # use the default metric for the peak sets
    ]
    return train, valid
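A minimal sketch of consuming the return value (file names are placeholders and the function above is assumed to be in scope):

train, valid = bpnet_data_gw("dataspec.yml",                # placeholder DataSpec YAML
                             intervals_file="labels.bed3",  # placeholder bed3+labels file
                             include_classes=True)
for name, dataset in valid:  # named validation datasets
    print(name)  # e.g. 'train-valid-genome-wide', 'valid-genome-wide', 'valid-peaks', 'train-peaks'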
Example #6
def bpnet_data(dataspec,
               peak_width=1000,
               intervals_file=None,
               intervals_format='bed',
               seq_width=None,
               shuffle=True,
               total_count_transform=lambda x: np.log(1 + x),
               track_transform=None,
               include_metadata=False,
               valid_chr=['chr2', 'chr3', 'chr4'],
               test_chr=['chr1', 'chr8', 'chr9'],
               exclude_chr=[],
               augment_interval=True,
               interval_augmentation_shift=200,
               tasks=None):
    """BPNet default data-loader

    Args:
      tasks: a subset of the tasks from dataspec.yml to use. If None, all tasks will be used.
    """
    from bpnet.metrics import BPNetMetric, PeakPredictionProfileMetric, pearson_spearman
    # valid and test chromosomes must not be in exclude_chr
    for vc in valid_chr:
        assert vc not in exclude_chr
    for vc in test_chr:
        assert vc not in exclude_chr

    dataspec = DataSpec.load(dataspec)

    if tasks is None:
        tasks = list(dataspec.task_specs)

    if augment_interval:
        interval_transformer = IntervalAugmentor(max_shift=interval_augmentation_shift,
                                                 flip_strand=True)
    else:
        interval_transformer = None

    # get the list of all chromosomes from the fasta file
    all_chr = _chrom_names(dataspec.fasta_file)

    return (StrandedProfile(dataspec, peak_width,
                            intervals_file=intervals_file,
                            intervals_format=intervals_format,
                            seq_width=seq_width,
                            include_metadata=include_metadata,
                            incl_chromosomes=[c for c in all_chr
                                              if c not in valid_chr + test_chr + exclude_chr],
                            excl_chromosomes=valid_chr + test_chr + exclude_chr,
                            tasks=tasks,
                            shuffle=shuffle,
                            track_transform=track_transform,
                            total_count_transform=total_count_transform,
                            interval_transformer=interval_transformer),
            [('valid-peaks', StrandedProfile(dataspec,
                                             peak_width,
                                             intervals_file=intervals_file,
                                             intervals_format=intervals_format,
                                             seq_width=seq_width,
                                             include_metadata=include_metadata,
                                             incl_chromosomes=valid_chr,
                                             tasks=tasks,
                                             interval_transformer=interval_transformer,
                                             shuffle=shuffle,
                                             track_transform=track_transform,
                                             total_count_transform=total_count_transform)),
             ('train-peaks', StrandedProfile(dataspec, peak_width,
                                             intervals_file=intervals_file,
                                             intervals_format=intervals_format,
                                             seq_width=seq_width,
                                             include_metadata=include_metadata,
                                             incl_chromosomes=[c for c in all_chr
                                                               if c not in valid_chr + test_chr + exclude_chr],
                                             excl_chromosomes=valid_chr + test_chr + exclude_chr,
                                             tasks=tasks,
                                             interval_transformer=interval_transformer,
                                             shuffle=shuffle,
                                             track_transform=track_transform,
                                             total_count_transform=total_count_transform)),
             ])
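A corresponding sketch for the peak-centric loader above (placeholder paths; the function is assumed to be in scope):

train, valid = bpnet_data("dataspec.yml",         # placeholder DataSpec YAML
                          peak_width=1000,
                          augment_interval=True,  # random +/-200 bp shifts and strand flips
                          exclude_chr=['chrM'])   # placeholder list of chromosomes to drop everywhere
print([name for name, _ in valid])  # ['valid-peaks', 'train-peaks']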
Example #7
    def __init__(self, ds,
                 peak_width=200,
                 seq_width=None,
                 incl_chromosomes=None,
                 excl_chromosomes=None,
                 intervals_file=None,
                 intervals_format='bed',
                 include_metadata=True,
                 tasks=None,
                 include_classes=False,
                 shuffle=True,
                 interval_transformer=None,
                 track_transform=None,
                 total_count_transform=lambda x: np.log(1 + x)):
        """Dataset for loading the bigwigs and fastas

        Args:
          ds (bpnet.dataspecs.DataSpec): data specification containing the
            fasta file, bed files and bigWig file paths
          chromosomes (list of str): a list of chor
          peak_width: resize the bed file to a certain width
          intervals_file: if specified, use these regions to train the model.
            If not specified, the regions are inferred from the dataspec.
          intervals_format: interval_file format. Available: bed, bed3, bed3+labels
          shuffle: True
          track_transform: function to be applied to transform the tracks (shape=(batch, seqlen, channels))
          total_count_transform: transform to apply to the total counts
            TODO - shall we standardize this to have also the inverse operation?
        """
        if isinstance(ds, str):
            self.ds = DataSpec.load(ds)
        else:
            self.ds = ds
        self.peak_width = peak_width
        if seq_width is None:
            self.seq_width = peak_width
        else:
            self.seq_width = seq_width

        assert intervals_format in ['bed3', 'bed3+labels', 'bed']

        self.shuffle = shuffle
        self.intervals_file = intervals_file
        self.intervals_format = intervals_format
        self.incl_chromosomes = incl_chromosomes
        self.excl_chromosomes = excl_chromosomes
        self.total_count_transform = total_count_transform
        self.track_transform = track_transform
        self.include_classes = include_classes
        # not specified yet
        self.fasta_extractor = None
        self.bw_extractors = None
        self.bias_bw_extractors = None
        self.include_metadata = include_metadata
        self.interval_transformer = interval_transformer

        # Load chromosome lengths
        self.chrom_lens = _chrom_sizes(self.ds.fasta_file)

        if self.intervals_file is None:
            # concatenate the bed files
            self.dfm = pd.concat([TsvReader(task_spec.peaks,
                                            num_chr=False,
                                            incl_chromosomes=incl_chromosomes,
                                            excl_chromosomes=excl_chromosomes,
                                            chromosome_lens=self.chrom_lens,
                                            resize_width=max(self.peak_width, self.seq_width)
                                            ).df.iloc[:, :3].assign(task=task)
                                  for task, task_spec in self.ds.task_specs.items()
                                  if task_spec.peaks is not None])
            assert list(self.dfm.columns)[:4] == [0, 1, 2, "task"]
            if self.shuffle:
                self.dfm = self.dfm.sample(frac=1)
            self.tsv = None
            self.dfm_tasks = None
        else:
            self.tsv = TsvReader(self.intervals_file,
                                 num_chr=False,
                                 # optional
                                 label_dtype=int if self.intervals_format == 'bed3+labels' else None,
                                 mask_ambigous=-1 if self.intervals_format == 'bed3+labels' else None,
                                 # --------------------------------------------
                                 incl_chromosomes=incl_chromosomes,
                                 excl_chromosomes=excl_chromosomes,
                                 chromosome_lens=self.chrom_lens,
                                 resize_width=max(self.peak_width, self.seq_width)
                                 )
            if self.shuffle:
                self.tsv.shuffle_inplace()
            self.dfm = self.tsv.df  # use the data-frame from tsv
            self.dfm_tasks = self.tsv.get_target_names()

        # remember the tasks
        if tasks is None:
            self.tasks = list(self.ds.task_specs)
        else:
            self.tasks = tasks

        if self.include_classes:
            assert self.dfm_tasks is not None

        if self.dfm_tasks is not None:
            assert set(self.tasks).issubset(self.dfm_tasks)

        # setup bias maps per task
        self.task_bias_tracks = {task: [bias for bias, spec in self.ds.bias_specs.items()
                                        if task in spec.tasks]
                                 for task in self.tasks}
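A minimal instantiation sketch; the imports appear in the snippets above, while the file path is a placeholder:

from bpnet.dataspecs import DataSpec
from bpnet.datasets import StrandedProfile

ds = DataSpec.load("dataspec.yml")         # placeholder path
dl = StrandedProfile(ds,
                     peak_width=1000,      # resize every region to 1 kb
                     intervals_file=None,  # fall back to the peak files listed in the dataspec
                     include_metadata=False,
                     shuffle=True)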
Example #8
def bpnet_train(dataspec,
                output_dir,
                premade='bpnet9',
                config=None,
                override='',
                gpu=0,
                memfrac_gpu=0.45,
                num_workers=8,
                vmtouch=False,
                in_memory=False,
                wandb_project="",
                cometml_project="",
                run_id=None,
                note_params="",
                overwrite=False):
    """Train a model using gin-config

    Output files:
      train.log - log file
      model.h5 - Keras model HDF5 file
      seqmodel.pkl - Serialized SeqModel. This is the main trained model.
      eval-report.ipynb/.html - evaluation report containing training loss curves and some example model predictions.
        You can specify your own ipynb using `--override='report_template.name="my-template.ipynb"'`.
      model.gin -> copied from the input
      dataspec.yml -> copied from the input
    """
    cometml_experiment, wandb_run, output_dir = start_experiment(
        output_dir=output_dir,
        cometml_project=cometml_project,
        wandb_project=wandb_project,
        run_id=run_id,
        note_params=note_params,
        overwrite=overwrite)
    # remember the executed command
    write_json(
        {
            "dataspec": dataspec,
            "output_dir": output_dir,
            "premade": premade,
            "config": config,
            "override": override,
            "gpu": gpu,
            "memfrac_gpu": memfrac_gpu,
            "num_workers": num_workers,
            "vmtouch": vmtouch,
            "in_memory": in_memory,
            "wandb_project": wandb_project,
            "cometml_project": cometml_project,
            "run_id": run_id,
            "note_params": note_params,
            "overwrite": overwrite
        },
        os.path.join(output_dir, 'bpnet-train.kwargs.json'),
        indent=2)

    # copy dataspec.yml and input config file over
    if config is not None:
        shutil.copyfile(config, os.path.join(output_dir, 'input-config.gin'))

    # parse and validate the dataspec
    ds = DataSpec.load(dataspec)
    related_dump_yaml(ds.abspath(), os.path.join(output_dir, 'dataspec.yml'))
    if vmtouch:
        if shutil.which('vmtouch') is None:
            logger.warning(
                "vmtouch is currently not installed. "
                "--vmtouch disabled. Please install vmtouch to enable it")
        else:
            # use vmtouch to load all files into memory
            ds.touch_all_files()

    # --------------------------------------------
    # Parse the config file
    # import gin.tf
    if gpu is not None:
        logger.info(f"Using gpu: {gpu}, memory fraction: {memfrac_gpu}")
        create_tf_session(gpu, per_process_gpu_memory_fraction=memfrac_gpu)

    gin_files = _get_gin_files(premade, config)

    # infer different hyper-parameters from the dataspec file
    if len(ds.bias_specs) > 0:
        use_bias = True
        if len(ds.bias_specs) > 1:
            # TODO - allow multiple bias tracks
            # - split the heads separately
            raise ValueError("Only a single bias track is currently supported")

        bias = [v for k, v in ds.bias_specs.items()][0]
        n_bias_tracks = len(bias.tracks)
    else:
        use_bias = False
        n_bias_tracks = 0
    tasks = list(ds.task_specs)
    # TODO - handle multiple track widths?
    tracks_per_task = [len(v.tracks) for k, v in ds.task_specs.items()][0]
    # figure out the right hyper-parameters
    dataspec_bindings = [
        f'dataspec="{dataspec}"', f'use_bias={use_bias}',
        f'n_bias_tracks={n_bias_tracks}', f'tracks_per_task={tracks_per_task}',
        f'tasks={tasks}'
    ]

    gin.parse_config_files_and_bindings(
        gin_files,
        bindings=dataspec_bindings + override.split(";"),
        # NOTE: custom files were inserted right after
        # the user's config file and before the `override`
        # parameters specified at the command-line
        # This allows the user to disable the bias correction
        # despite being specified in the config file
        skip_unknown=False)

    # --------------------------------------------
    # Remember the parsed configs

    # comet - log environment
    if cometml_experiment is not None:
        # log other parameters
        cometml_experiment.log_parameters(dict(premade=premade,
                                               config=config,
                                               override=override,
                                               gin_files=gin_files,
                                               gpu=gpu),
                                          prefix='cli/')

    # wandb - log environment
    if wandb_run is not None:

        # store general configs
        wandb_run.config.update(
            dict_prefix_key(dict(premade=premade,
                                 config=config,
                                 override=override,
                                 gin_files=gin_files,
                                 gpu=gpu),
                            prefix='cli/'))

    return train(
        output_dir=output_dir,
        cometml_experiment=cometml_experiment,
        wandb_run=wandb_run,
        num_workers=num_workers,
        in_memory=in_memory,
        # to execute the sub-notebook
        memfrac_gpu=memfrac_gpu,
        gpu=gpu)
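A minimal training call sketch; the paths are placeholders and the override string reuses the example from the docstring above (the function is assumed to be in scope):

bpnet_train(dataspec="dataspec.yml",   # placeholder DataSpec YAML
            output_dir="output/run1",  # placeholder output directory
            premade='bpnet9',
            override='report_template.name="my-template.ipynb"',  # ';'-separated gin bindings
            gpu=0,
            overwrite=True)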