Example #1
def config():

    model = {}
    VAE.get_config(
        dict(
            encoder={
                'cls': RecurrentEncoder,
                RecurrentEncoder: dict(
                    recurrent={'cls': LSTM}
                ),
            },
        ),
        model,
    )
    VAE.get_config(  # alternative dict update
        deflatten({
            ('encoder', 'cls'): RecurrentEncoder,
            ('encoder', RecurrentEncoder, 'recurrent', 'cls'): LSTM,
        }, sep=None),
        model,
    )
    VAE.get_config(  # second alternative update
        deflatten({
            'encoder/cls': 'RecurrentEncoder',
            'encoder/RecurrentEncoder/recurrent/cls': LSTM,
        }, sep='/'),
        model,
    )
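For reference, a minimal sketch (assuming paderbox's `deflatten(d, sep='.')` helper as used above) of why the three updates are interchangeable: tuple keys with `sep=None` and separator-joined string keys deflatten to the same nested dict.

from paderbox.utils.nested import deflatten

# Tuple keys with sep=None ...
nested_a = deflatten({('encoder', 'cls'): 'RecurrentEncoder'}, sep=None)
# ... and '/'-joined string keys produce the same nested structure.
nested_b = deflatten({'encoder/cls': 'RecurrentEncoder'}, sep='/')
assert nested_a == nested_b == {'encoder': {'cls': 'RecurrentEncoder'}}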
Example #2
def config():
    model_class = MaskEstimatorModel
    trainer_opts = deflatten({
        'model.factory': model_class,
        'optimizer.factory': Adam,
        'stop_trigger': (int(1e5), 'iteration'),
        'summary_trigger': (500, 'iteration'),
        'checkpoint_trigger': (500, 'iteration'),
        'storage_dir': None,
    })
    provider_opts = deflatten({
        'factory': SequenceProvider,
        'database.factory': Chime3,
        'audio_keys': [OBSERVATION, NOISE_IMAGE, SPEECH_IMAGE],
        'transform.factory': MaskTransformer,
        'transform.stft': dict(factory=STFT, shift=256, size=1024),
    })
    trainer_opts['model']['transformer'] = provider_opts['transform']

    storage_dir = None
    add_name = None
    if storage_dir is None:
        ex_name = get_experiment_name(trainer_opts['model'])
        if add_name is not None:
            ex_name += f'_{add_name}'
        observer = sacred.observers.FileStorageObserver.create(
            str(model_dir / ex_name))
        storage_dir = observer.basedir
    else:
        sacred.observers.FileStorageObserver.create(storage_dir)

    trainer_opts['storage_dir'] = storage_dir

    if (Path(storage_dir) / 'init.json').exists():
        trainer_opts, provider_opts = compare_configs(storage_dir,
                                                      trainer_opts,
                                                      provider_opts)

    Trainer.get_config(trainer_opts)
    Configurable.get_config(provider_opts)
    validate_checkpoint = 'ckpt_latest.pth'
    validation_kwargs = dict(
        metric='loss',
        maximize=False,
        max_checkpoints=1,
        # number of examples taken from the validation iterator
        validation_length=1000,
    )
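A small sketch (with hypothetical string values) of the structure the default `sep='.'` produces here: 'model.factory' becomes a nested entry, which is why `trainer_opts['model']` can be extended with a 'transformer' key above.

from paderbox.utils.nested import deflatten

opts = deflatten({'model.factory': 'MaskEstimatorModel', 'storage_dir': None})
opts['model']['transformer'] = {'factory': 'MaskTransformer'}
assert opts['model'] == {
    'factory': 'MaskEstimatorModel',
    'transformer': {'factory': 'MaskTransformer'},
}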
Example #3
def ArrayIntervalls_from_rttm(rttm_file, shape=None, sample_rate=16000):
    """
    >>> import tempfile
    >>> with tempfile.TemporaryDirectory() as tmpdir:
    ...     file = Path(tmpdir) / 'dummy.rttm'
    ...     file.write_text("SPEAKER S02 1 0 1 <NA> <NA> 1 <NA>\\nSPEAKER S02 1 2 1 <NA> <NA> 1 <NA>\\nSPEAKER S02 1 0 2 <NA> <NA> 2 <NA>")
    ...     print(file.read_text())
    ...     print(ArrayIntervalls_from_rttm(file))
    104
    SPEAKER S02 1 0 1 <NA> <NA> 1 <NA>
    SPEAKER S02 1 2 1 <NA> <NA> 1 <NA>
    SPEAKER S02 1 0 2 <NA> <NA> 2 <NA>
    {'S02': {'1': ArrayIntervall("0:16000, 32000:48000", shape=None), '2': ArrayIntervall("0:32000", shape=None)}}
    """

    # Description for rttm files copied from the kaldi CHiME-6 recipe
    #    `steps/segmentation/convert_utt2spk_and_segments_to_rttm.py`:
    # <type> <file-id> <channel-id> <begin-time> \
    #         <duration> <ortho> <stype> <name> <conf>
    # <type> = SPEAKER for each segment.
    # <file-id> - the File-ID of the recording
    # <channel-id> - the Channel-ID, usually 1
    # <begin-time> - start time of segment
    # <duration> - duration of segment
    # <ortho> - <NA> (this is ignored)
    # <stype> - <NA> (this is ignored)
    # <name> - speaker name or id
    # <conf> - <NA> (this is ignored)
    from paderbox.utils.nested import deflatten
    import decimal

    rttm_file = Path(rttm_file)
    lines = rttm_file.read_text().splitlines()

    # ai = ArrayIntervall(shape)
    # SPEAKER S02_U06.ENH 1   40.60    3.22 <NA> <NA> P05 <NA>

    data = collections.defaultdict(lambda: ArrayIntervall(shape))

    for line in lines:
        parts = line.split()
        assert parts[0] == 'SPEAKER'
        file_id = parts[1]
        channel_id = parts[2]
        begin_time = decimal.Decimal(parts[3])
        duration_time = decimal.Decimal(parts[4])
        name = parts[7]

        end_time = (begin_time + duration_time) * sample_rate
        begin_time = begin_time * sample_rate

        assert begin_time == int(begin_time)
        assert end_time == int(end_time)

        data[(file_id, name)][int(begin_time):int(end_time)] = 1

    return deflatten(data, sep=None)
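As an illustration (not part of the original function), the RTTM fields described in the comment above map onto the first doctest line as follows:

parts = 'SPEAKER S02 1 0 1 <NA> <NA> 1 <NA>'.split()
assert parts[0] == 'SPEAKER'   # <type>
assert parts[1] == 'S02'       # <file-id>
assert parts[2] == '1'         # <channel-id>
assert parts[3] == '0'         # <begin-time> in seconds
assert parts[4] == '1'         # <duration> in seconds
assert parts[7] == '1'         # <name>, i.e. the speaker id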
Example #4
def config():

    model = {}
    Model.get_config(  # second alternative update
        deflatten({
            'transform.kwargs.sample_rate': 44100,
        }, sep='.'),
        model,
    )
Example #5
    def __call__(self, example, random_onset=False):
        copies = flatten(
            {key: example[key] for key in self.copy_keys}
            if self.copy_keys is not None else example
        )

        if random_onset:
            start = np.random.rand()
            for fragment_step in self.fragment_steps.values():
                start = int(start*fragment_step) / fragment_step
        else:
            start = 0.

        def fragment(key, x):
            fragment_step = self.fragment_steps[key]
            fragment_length = self.fragment_lengths[key]
            start_idx = int(start * fragment_step)
            if start_idx > 0:
                slc = [slice(None)] * len(x.shape)
                slc[self.axis] = slice(
                    int(start_idx), x.shape[self.axis]
                )
                x = x[tuple(slc)]

            end_index = x.shape[self.axis]
            if self.drop_last:
                end_index -= (fragment_length - 1)
            fragments = list()
            for start_idx in np.arange(0, end_index, fragment_step):
                if fragment_length == 1 and self.squeeze:
                    fragments.append(x.take(start_idx, axis=self.axis))
                else:
                    slc = [slice(None)] * len(x.shape)
                    slc[self.axis] = slice(
                        int(start_idx), int(start_idx) + int(fragment_length)
                    )
                    fragments.append(x[tuple(slc)])
            return fragments

        features = flatten({
            key: nested_op(lambda x: fragment(key, x), example[key])
            for key in self.fragment_steps.keys()
        })
        num_fragments = np.array(
            [len(features[key]) for key in list(features.keys())]
        )
        assert all(num_fragments == num_fragments[0]), (list(features.keys()), num_fragments)
        fragments = list()
        for i in range(int(num_fragments[0])):
            fragment = deepcopy(copies)
            for key in features.keys():
                fragment[key] = features[key][i]
            fragment = deflatten(fragment)
            fragments.append(fragment)
        return fragments
Example #6
    def get_signature(self):
        defaults = super().get_signature()
        defaults['transform'] = deflatten(
            {
                'cls': Compose,
                'kwargs.sample_rate': 8000,
                'kwargs.layer1.cls': Load,
                'kwargs.layer2.cls': FeatureExtractor,
            },
            sep='.')
        return defaults
Example #7
    def load_checkpoint(
        self,
        checkpoint_path: (Path, str),
        in_checkpoint_path: str = 'model',
        map_location='cpu',
        consider_mpi=False,
    ) -> 'Module':
        """Update the module parameters from the given checkpoint.

        Args:
            checkpoint_path:
            in_checkpoint_path:
            map_location:
            consider_mpi:
                If True and mpi is used, only read config_path and
                checkpoint_path once and broadcast the content with mpi.
                Reduces the io load.

        Returns:


        """
        checkpoint_path = Path(checkpoint_path).expanduser().resolve()

        assert checkpoint_path.is_file(), checkpoint_path

        # Load weights
        if consider_mpi:
            import dlp_mpi
            if dlp_mpi.IS_MASTER:
                checkpoint_path_content = Path(checkpoint_path).read_bytes()
            else:
                checkpoint_path_content = None
            checkpoint_path_content = dlp_mpi.bcast(checkpoint_path_content)

            checkpoint = torch.load(
                io.BytesIO(checkpoint_path_content),
                map_location=map_location,
            )
        else:
            checkpoint = torch.load(checkpoint_path, map_location=map_location)

        if in_checkpoint_path:
            for part in in_checkpoint_path.split('.'):
                try:
                    checkpoint = deflatten(checkpoint, maxdepth=1)
                    checkpoint = checkpoint[part]
                except KeyError:
                    raise ValueError(part, in_checkpoint_path, checkpoint)
        self.load_state_dict(checkpoint)

        return self
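A minimal sketch of the `deflatten(..., maxdepth=1)` traversal above, using a hypothetical flat checkpoint dict: each loop iteration splits off only the first key level and descends into it.

from paderbox.utils.nested import deflatten

checkpoint = {'model.cnn.0.weight': 'w', 'model.cnn.0.bias': 'b', 'optimizer.lr': 0.1}
for part in 'model.cnn'.split('.'):
    checkpoint = deflatten(checkpoint, maxdepth=1)  # split only the first key level
    checkpoint = checkpoint[part]
assert checkpoint == {'0.weight': 'w', '0.bias': 'b'}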
Example #8
    def __call__(self, example: dict, rng=np.random) -> List[dict]:
        """

        Args:
            example: dictionary with string keys
            rng: random number generator, may be created with
                paderbox.utils.random_utils.str_to_random_state

        Returns:
        """

        example = flatten(example, sep=self.flatten_separator)

        to_segment_keys = self.get_to_segment_keys(example)
        axis = self.get_axis_list(to_segment_keys)

        to_segment = {
            key: example.pop(key) for key in to_segment_keys
        }

        if all([isinstance(key, str) for key in self.copy_keys]):
            to_copy = {key: example.pop(key) for key in self.copy_keys}
        elif self.copy_keys[0] is True:
            assert len(self.copy_keys) == 1, self.copy_keys
            to_copy = example
        elif self.copy_keys[0] is False:
            assert len(self.copy_keys) == 1, self.copy_keys
            to_copy = dict()
        else:
            raise TypeError('Unknown type for copy keys', self.copy_keys)

        if any([not isinstance(value, (np.ndarray, torch.Tensor))
                for value in to_segment.values()]):
            raise ValueError(
                'This segmenter only works on numpy arrays and torch tensors.',
                'However, the following keys point to other types:',
                '\n'.join([f'{key} points to a {type(to_segment[key])}'
                           for key in to_segment_keys])
            )

        to_segment_lengths = [
            v.shape[axis[i]] for i, v in enumerate(to_segment.values())]

        assert to_segment_lengths[1:] == to_segment_lengths[:-1], (
            'The shapes along the segment dimension of all entries to segment'
            ' must be equal!\n'
            f'segment keys: {to_segment_keys}\n'
            f'to_segment_lengths: {to_segment_lengths}'
        )
        assert len(to_segment) > 0, ('Did not find any signals to segment',
                                     self.include, self.exclude, to_segment)
        to_segment_length = to_segment_lengths[0]

        # Discard examples that are shorter than `length`
        if not self.mode == 'max' and to_segment_length < self.length:
            import lazy_dataset
            raise lazy_dataset.FilterException()

        # Shortcut if segmentation is disabled
        if self.length == -1:
            to_copy.update(to_segment)
            to_copy.update(segment_start=0, segment_stop=to_segment_length)
            return [deflatten(to_copy)]

        boundaries, segmented = self.segment(to_segment, to_segment_length,
                                             axis=axis, rng=rng)

        segmented_examples = list()

        for idx, (start, stop) in enumerate(boundaries):
            example_copy = copy(to_copy)
            example_copy.update({key: value[idx]
                                 for key, value in segmented.items()})
            example_copy.update(segment_start=start, segment_stop=stop)
            segmented_examples.append(deflatten(example_copy))
        return segmented_examples
Example #9
def train(
    _run,
    debug,
    data_provider,
    trainer,
    lr_rampup_steps,
    back_off_patience,
    lr_decay_step,
    lr_decay_factor,
    init_ckpt_path,
    frozen_cnn_2d_layers,
    frozen_cnn_1d_layers,
    resume,
    delay,
    validation_set_name,
    validation_ground_truth_filepath,
    weak_label_crnn_hyper_params_dir,
    eval_set_name,
    eval_ground_truth_filepath,
):
    print()
    print('##### Training #####')
    print()
    print_config(_run)
    assert (back_off_patience is None) or (lr_decay_step is None), (
        back_off_patience, lr_decay_step)
    if delay > 0:
        print(f'Sleep for {delay} seconds.')
        time.sleep(delay)

    data_provider = DESEDProvider.from_config(data_provider)
    data_provider.train_transform.label_encoder.initialize_labels(
        dataset=data_provider.db.get_dataset(data_provider.validate_set),
        verbose=True)
    data_provider.test_transform.label_encoder.initialize_labels()
    trainer = Trainer.from_config(trainer)
    trainer.model.label_mapping = []
    for idx, label in sorted(
            data_provider.train_transform.label_encoder
            .inverse_label_mapping.items()):
        assert idx == len(trainer.model.label_mapping), (
            idx, label, len(trainer.model.label_mapping))
        trainer.model.label_mapping.append(
            label.replace(', ', '__').replace(' ', '')
            .replace('(', '_').replace(')', '_').replace("'", '')
        )
    print('Params', sum(p.numel() for p in trainer.model.parameters()))

    if init_ckpt_path is not None:
        print('Load init params')
        state_dict = deflatten(torch.load(init_ckpt_path,
                                          map_location='cpu')['model'],
                               maxdepth=1)
        trainer.model.cnn.load_state_dict(state_dict['cnn'])
    if frozen_cnn_2d_layers:
        print(f'Freeze {frozen_cnn_2d_layers} cnn_2d layers')
        trainer.model.cnn.cnn_2d.freeze(frozen_cnn_2d_layers)
    if frozen_cnn_1d_layers:
        print(f'Freeze {frozen_cnn_1d_layers} cnn_1d layers')
        trainer.model.cnn.cnn_1d.freeze(frozen_cnn_1d_layers)

    def add_tag_condition(example):
        example["tag_condition"] = example["weak_targets"]
        return example

    train_set = data_provider.get_train_set().map(add_tag_condition)
    validate_set = data_provider.get_validate_set().map(add_tag_condition)

    if validate_set is not None:
        trainer.test_run(train_set, validate_set)
        trainer.register_validation_hook(
            validate_set,
            metric='macro_fscore_strong',
            maximize=True,
        )

    breakpoints = []
    if lr_rampup_steps is not None:
        breakpoints += [(0, 0.), (lr_rampup_steps, 1.)]
    if lr_decay_step is not None:
        breakpoints += [(lr_decay_step, 1.), (lr_decay_step, lr_decay_factor)]
    if len(breakpoints) > 0:
        if isinstance(trainer.optimizer, dict):
            names = sorted(trainer.optimizer.keys())
        else:
            names = [None]
        for name in names:
            trainer.register_hook(
                LRAnnealingHook(
                    trigger=AllTrigger(
                        (100, 'iteration'),
                        NotTrigger(
                            EndTrigger(breakpoints[-1][0] + 100, 'iteration')),
                    ),
                    breakpoints=breakpoints,
                    unit='iteration',
                    name=name,
                ))
    trainer.train(train_set, resume=resume)

    if validation_set_name:
        tuning.run(
            config_updates={
                'debug': debug,
                'weak_label_crnn_hyper_params_dir': weak_label_crnn_hyper_params_dir,
                'strong_label_crnn_dirs': [str(trainer.storage_dir)],
                'validation_set_name': validation_set_name,
                'validation_ground_truth_filepath': validation_ground_truth_filepath,
                'eval_set_name': eval_set_name,
                'eval_ground_truth_filepath': eval_ground_truth_filepath,
            })
Example #10
    def __init__(
            self,
            model: 'pt.Model',
            storage_dir,
            optimizer,
            loss_weights=None,
            summary_trigger=(1, 'epoch'),
            checkpoint_trigger=(1, 'epoch'),
            stop_trigger=(1, 'epoch'),
            virtual_minibatch_size=1,
    ):
        """

        Args:
            model: a `padertorch.base.Model` object
            storage_dir: The structure of produced storage_dir is:
                .
                ├── checkpoints
                │   ├── ckpt_7122.pth
                │   ├── ckpt_14244.pth
                │   ├── ckpt_best_loss.pth -> ckpt_7122.pth
                │   ├── ckpt_latest.pth -> ckpt_14244.pth
                │   └── ckpt_ranking.json
                ├── events.out.tfevents.1548851867.ntsim5
            optimizer: a `padertorch.train.optimizer.Optimizer` object
                or dict of Optimizers
            loss_weights: dict of weights for model with multiple losses
            summary_trigger: `padertorch.train.trigger.IntervalTrigger` object
                or tuple describing the interval when summaries
                are written to event files.
                See padertorch.train.hooks.SummaryHook for a description of
                what a summary is.
            checkpoint_trigger: `padertorch.train.trigger.IntervalTrigger`
                object or tuple describing the interval when checkpoints
                are saved. See padertorch.train.hooks.CheckpointHook and
                padertorch.train.hooks.ValidationHook for a description of
                what happens on a checkpoint.
            stop_trigger: `padertorch.train.trigger.EndTrigger` object
                or tuple describing the endpoint of the training
            virtual_minibatch_size: Runs the optimisation every
                virtual_minibatch_size steps. By default the optimisation
                runs after each review call.
                The advantage of a virtual_minibatch_size over addressing a
                minibatch dimension in forward and review is a lower memory
                footprint at the cost of CPU time.
                Note: The gradients are accumulated, not averaged.
                Note: The virtual_minibatch_size is fixed and a virtual
                    minibatch can contain data from two epochs.


        Usage:

            # For test_run we recommend running it without prefetch
            trainer = Trainer(...)  # or: Trainer.from_config(...)
            trainer.test_run(tr_ds, val_ds)
            trainer.train(tr_ds.prefetch(4, 8), val_ds.prefetch(4, 8))

        """
        if not isinstance(model, torch.nn.Module):
            raise TypeError(
                'Expected the model to be a subclass of padertorch.Module.\n'
                f'Got: type: {type(model)}\n{model}')
        self.model = model

        if isinstance(optimizer, dict):
            # Special case see Janek's example
            # TODO: Hint to example
            model_keys = set(deflatten(model.state_dict(), maxdepth=1).keys())
            assert model_keys == set(optimizer.keys()), (model_keys, optimizer)
            optimizer = optimizer.copy()
            for key, opti in list(optimizer.items()):
                if opti is None:
                    del optimizer[key]
                else:
                    assert isinstance(opti, Optimizer), opti
                    m = getattr(model, key)
                    opti.set_parameters(m.parameters())
        else:
            assert isinstance(optimizer, Optimizer), optimizer
            optimizer.set_parameters(model.parameters())

        self.optimizer = optimizer

        self.device = None  # Dummy value, will be set in Trainer.train

        self.storage_dir = Path(storage_dir).expanduser().resolve()
        self.writer = None
        self.train_timer = ContextTimerDict()
        self.validate_timer = ContextTimerDict()
        self.iteration = -1
        self.epoch = -1

        self.loss_weights = loss_weights
        self.virtual_minibatch_size = virtual_minibatch_size

        self.hooks = [
            SummaryHook(summary_trigger),
            CheckpointHook(checkpoint_trigger),
            StopTrainingHook(stop_trigger),
        ]
        self._stop_trigger = stop_trigger
        self._checkpoint_trigger = checkpoint_trigger
        self.writer_cls = tensorboardX.SummaryWriter
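For the dict-of-optimizers case above, `deflatten(model.state_dict(), maxdepth=1).keys()` yields the names of the model's top-level submodules, which the optimizer keys must match. A small sketch with a hypothetical two-branch model:

import torch
from paderbox.utils.nested import deflatten

model = torch.nn.ModuleDict({
    'encoder': torch.nn.Linear(4, 4),
    'decoder': torch.nn.Linear(4, 4),
})
# state_dict keys look like 'encoder.weight'; maxdepth=1 splits off the
# submodule name only.
top_level_keys = set(deflatten(model.state_dict(), maxdepth=1).keys())
assert top_level_keys == {'encoder', 'decoder'}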
Example #11
def train(
    _run,
    debug,
    data_provider,
    filter_desed_test_clips,
    trainer,
    lr_rampup_steps,
    back_off_patience,
    lr_decay_step,
    lr_decay_factor,
    init_ckpt_path,
    frozen_cnn_2d_layers,
    frozen_cnn_1d_layers,
    track_emissions,
    resume,
    delay,
    validation_set_name,
    validation_ground_truth_filepath,
    eval_set_name,
    eval_ground_truth_filepath,
):
    print()
    print('##### Training #####')
    print()
    print_config(_run)
    assert (back_off_patience is None) or (lr_decay_step is None), (
        back_off_patience, lr_decay_step)
    if delay > 0:
        print(f'Sleep for {delay} seconds.')
        time.sleep(delay)

    data_provider = DataProvider.from_config(data_provider)
    data_provider.train_transform.label_encoder.initialize_labels(
        dataset=data_provider.db.get_dataset(data_provider.validate_set),
        verbose=True)
    data_provider.test_transform.label_encoder.initialize_labels()
    trainer = Trainer.from_config(trainer)
    trainer.model.label_mapping = []
    for idx, label in sorted(
            data_provider.train_transform.label_encoder
            .inverse_label_mapping.items()):
        assert idx == len(trainer.model.label_mapping), (
            idx, label, len(trainer.model.label_mapping))
        trainer.model.label_mapping.append(
            label.replace(', ', '__').replace(' ', '')
            .replace('(', '_').replace(')', '_').replace("'", '')
        )
    print('Params', sum(p.numel() for p in trainer.model.parameters()))
    print('CNN Params', sum(p.numel() for p in trainer.model.cnn.parameters()))

    if init_ckpt_path is not None:
        print('Load init params')
        state_dict = deflatten(torch.load(init_ckpt_path,
                                          map_location='cpu')['model'],
                               maxdepth=2)
        trainer.model.cnn.load_state_dict(flatten(state_dict['cnn']))
        trainer.model.rnn_fwd.rnn.load_state_dict(state_dict['rnn_fwd']['rnn'])
        trainer.model.rnn_bwd.rnn.load_state_dict(state_dict['rnn_bwd']['rnn'])
        # pop output layer from checkpoint
        param_keys = sorted(state_dict['rnn_fwd']['output_net'].keys())
        layer_idx = [key.split('.')[1] for key in param_keys]
        last_layer_idx = layer_idx[-1]
        for key, layer_idx in zip(param_keys, layer_idx):
            if layer_idx == last_layer_idx:
                state_dict['rnn_fwd']['output_net'].pop(key)
                state_dict['rnn_bwd']['output_net'].pop(key)
        trainer.model.rnn_fwd.output_net.load_state_dict(
            state_dict['rnn_fwd']['output_net'], strict=False)
        trainer.model.rnn_bwd.output_net.load_state_dict(
            state_dict['rnn_bwd']['output_net'], strict=False)
    if frozen_cnn_2d_layers:
        print(f'Freeze {frozen_cnn_2d_layers} cnn_2d layers')
        trainer.model.cnn.cnn_2d.freeze(frozen_cnn_2d_layers)
    if frozen_cnn_1d_layers:
        print(f'Freeze {frozen_cnn_1d_layers} cnn_1d layers')
        trainer.model.cnn.cnn_1d.freeze(frozen_cnn_1d_layers)

    if filter_desed_test_clips:
        with (database_jsons_dir / 'desed.json').open() as fid:
            desed_json = json.load(fid)
        filter_example_ids = {
            clip_id.rsplit('_', maxsplit=2)[0][1:]
            for clip_id in (list(desed_json['datasets']['validation'].keys()) +
                            list(desed_json['datasets']['eval_public'].keys()))
        }
    else:
        filter_example_ids = None
    train_set = data_provider.get_train_set(
        filter_example_ids=filter_example_ids)
    validate_set = data_provider.get_validate_set()

    if validate_set is not None:
        trainer.test_run(train_set, validate_set)
        trainer.register_validation_hook(
            validate_set,
            metric='macro_fscore_weak',
            maximize=True,
            back_off_patience=back_off_patience,
            n_back_off=0 if back_off_patience is None else 1,
            lr_update_factor=lr_decay_factor,
            early_stopping_patience=back_off_patience,
        )

    breakpoints = []
    if lr_rampup_steps is not None:
        breakpoints += [(0, 0.), (lr_rampup_steps, 1.)]
    if lr_decay_step is not None:
        breakpoints += [(lr_decay_step, 1.), (lr_decay_step, lr_decay_factor)]
    if len(breakpoints) > 0:
        if isinstance(trainer.optimizer, dict):
            names = sorted(trainer.optimizer.keys())
        else:
            names = [None]
        for name in names:
            trainer.register_hook(
                LRAnnealingHook(
                    trigger=AllTrigger(
                        (100, 'iteration'),
                        NotTrigger(
                            EndTrigger(breakpoints[-1][0] + 100, 'iteration')),
                    ),
                    breakpoints=breakpoints,
                    unit='iteration',
                    name=name,
                ))
    trainer.train(train_set, resume=resume, track_emissions=track_emissions)

    if validation_set_name is not None:
        tuning.run(
            config_updates={
                'debug': debug,
                'crnn_dirs': [str(trainer.storage_dir)],
                'validation_set_name': validation_set_name,
                'validation_ground_truth_filepath': validation_ground_truth_filepath,
                'eval_set_name': eval_set_name,
                'eval_ground_truth_filepath': eval_ground_truth_filepath,
                'data_provider': {
                    'test_fetcher': {
                        'batch_size': data_provider.train_fetcher.batch_size,
                    }
                },
            })
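A sketch of the `maxdepth=2` split and the `flatten` round trip used when loading the init checkpoint above, with hypothetical flat parameter keys:

from paderbox.utils.nested import deflatten, flatten

ckpt = {
    'cnn.cnn_2d.0.weight': 'w1',
    'rnn_fwd.rnn.weight_ih_l0': 'w2',
}
state_dict = deflatten(ckpt, maxdepth=2)
# Only the first two key levels are split, the remainder stays flat ...
assert state_dict['rnn_fwd']['rnn'] == {'weight_ih_l0': 'w2'}
# ... and flatten() restores the dotted keys expected by load_state_dict.
assert flatten(state_dict['cnn']) == {'cnn_2d.0.weight': 'w1'}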
Example #12
    def from_config_and_checkpoint(
            cls,
            config_path: Path,
            checkpoint_path: Path,
            in_config_path: str = 'trainer.model',
            in_checkpoint_path: str = 'model',

            map_location='cpu',
            consider_mpi=False,
    ) -> 'Module':
        """Instantiate the module from given config and checkpoint.

        Args:
            config_path: 
            checkpoint_path: 
            in_config_path: 
            in_checkpoint_path: 
            map_location: 
            consider_mpi:
                If True and mpi is used, only read config_path and
                checkpoint_path once and broadcast the content with mpi.
                Reduces the io load.

        Returns:
        
        
        """
        config_path = Path(config_path).expanduser().resolve()
        checkpoint_path = Path(checkpoint_path).expanduser().resolve()

        assert config_path.is_file(), config_path
        assert checkpoint_path.is_file(), checkpoint_path
        # Load config
        module = cls.from_file(
            config_path,
            in_config_path,
            consider_mpi=False
        )

        # Load weights
        if consider_mpi:
            from paderbox.utils import mpi
            checkpoint_path_content = mpi.call_on_master_and_broadcast(
                Path(checkpoint_path).read_bytes,
            )
            checkpoint = torch.load(
                io.BytesIO(checkpoint_path_content),
                map_location=map_location,
            )
        else:
            checkpoint = torch.load(checkpoint_path, map_location=map_location)

        if in_checkpoint_path:
            for part in in_checkpoint_path.split('.'):
                try:
                    checkpoint = deflatten(checkpoint, maxdepth=1)
                    checkpoint = checkpoint[part]
                except KeyError:
                    raise ValueError(part, in_checkpoint_path, checkpoint)
        module.load_state_dict(checkpoint)

        return module