Example #1
    def _register_hook_handler(self, hook: Hook):
        # Resolve the hook's module path (a Python expression such as
        # "self.model.encoder") into the actual submodule instance.
        module = eval(hook.module_path)
        if not isinstance(module, nn.Module):
            show(
                f"[UpstreamBase] - {hook.module_path} is not a valid nn.Module. Skip.",
                file=sys.stderr,
            )
            return

        # If this hook was registered before, remove the old handler so the
        # same module is not hooked twice.
        if callable(hook.handler):
            show(
                f"[UpstreamBase] - Existing hook handler for {hook.unique_identifier} is found. Remove the existing one.",
                file=sys.stderr,
            )
            hook.handler.remove()

        def generate_hook_handler(hiddens: List, hook: Hook):
            # The closure captures the shared `hiddens` list: every forward pass
            # appends (identifier, transformed output) for later collection.
            # Note: the first argument is the hooked module (forward-hook
            # signature), not this UpstreamBase instance.
            def hook_handler(self, input, output):
                hiddens.append((hook.unique_identifier, hook.transform(input, output)))

            return hook_handler

        hook.handler = module.register_forward_hook(
            generate_hook_handler(self._hook_hiddens, hook)
        )
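Outside of this class, the same forward-hook pattern can be sketched in a few self-contained lines; the nn.Linear model and the fixed "linear" identifier below are made up for illustration only:

    import torch
    import torch.nn as nn

    hiddens = []                      # shared buffer, like self._hook_hiddens
    model = nn.Linear(4, 2)           # stand-in for an upstream submodule

    def hook_handler(module, input, output):
        # mirrors hook.transform(input, output); here we simply keep the output
        hiddens.append(("linear", output))

    handler = model.register_forward_hook(hook_handler)
    model(torch.randn(3, 4))          # the forward pass triggers the hook
    print(len(hiddens))               # -> 1
    handler.remove()                  # like hook.handler.remove()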
Example #2
 def load_short_and_turn_float(*args, **kwargs):
     wav, sr = _load_wav(*args, **kwargs)
     # Warn when the decoded audio is not 16-bit PCM (torch.short); the
     # official usage expects plain .wav input.
     if wav.dtype != torch.short:
         show(
             '[Warning] - Decoar only takes .wav files for the official usage.'
         )
         show(f'[Warning] - {args[0]} is not a .wav file')
     wav = wav.float()
     return wav, sr
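For reference, a minimal sketch of the dtype check and conversion on a fabricated tensor (no file is actually loaded here):

    import torch

    wav = torch.zeros(16000, dtype=torch.short)   # pretend 16-bit PCM samples
    assert wav.dtype == torch.short
    wav = wav.float()                             # same conversion as above
    print(wav.dtype)                              # -> torch.float32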
Example #3
    def _get_scheduler(self, optimizer):
        scheduler = get_scheduler(optimizer,
                                  self.config['runner']['total_steps'],
                                  self.config['scheduler'])

        init_scheduler = self.init_ckpt.get('Scheduler')
        if init_scheduler:
            show(
                '[Runner] - Loading scheduler weights from the previous experiment'
            )
            scheduler.load_state_dict(init_scheduler)
        return scheduler
Example #4
    def _get_optimizer(self, model_params):
        optimizer = get_optimizer(model_params,
                                  self.config['runner']['total_steps'],
                                  self.config['optimizer'])

        init_optimizer = self.init_ckpt.get('Optimizer')
        if init_optimizer:
            show(
                '[Runner] - Loading optimizer weights from the previous experiment'
            )
            optimizer.load_state_dict(init_optimizer)
        return optimizer
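Both of these helpers rely on the standard state_dict round-trip; a minimal sketch with a made-up model and plain SGD (not the project's get_optimizer) is:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    ckpt = {'Optimizer': optimizer.state_dict()}        # what a checkpoint would hold

    new_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    init_optimizer = ckpt.get('Optimizer')
    if init_optimizer:
        new_optimizer.load_state_dict(init_optimizer)   # same pattern as above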
Example #5
    def _get_upstream(self):
        Upstream = getattr(importlib.import_module('hubconf'),
                           self.args.upstream)
        upstream_refresh = self.args.upstream_refresh

        # In distributed mode only rank 0 downloads/refreshes the checkpoint;
        # the other ranks wait at the barrier and then load from cache.
        if is_initialized() and get_rank() > 0:
            torch.distributed.barrier()
            upstream_refresh = False

        upstream = Upstream(
            feature_selection=self.args.upstream_feature_selection,
            model_config=self.args.upstream_model_config,
            refresh=upstream_refresh,
            ckpt=self.args.upstream_ckpt,
        ).to(self.args.device)

        # Rank 0 releases the waiting ranks once its instantiation is done.
        if is_initialized() and get_rank() == 0:
            torch.distributed.barrier()

        # Every upstream must expose these interface methods.
        interface_fn = ['get_output_dim', 'get_downsample_rate']
        for fn in interface_fn:
            assert hasattr(upstream, fn)

        if self.args.verbose:
            show(f'[Runner] - Upstream model architecture: {upstream}')
            show(
                f'[Runner] - Upstream has {count_parameters(upstream)} parameters'
            )
            show(
                f'[Runner] - Upstream output dimension: {upstream.get_output_dim()}'
            )
            downsample = upstream.get_downsample_rate()
            show(
                f'[Runner] - Upstream downsample rate: {downsample} ({downsample / SAMPLE_RATE * 1000} ms/frame)'
            )

        init_upstream = self.init_ckpt.get('Upstream')
        if init_upstream:
            show(
                '[Runner] - Loading upstream weights from the previous experiment'
            )
            upstream.load_state_dict(init_upstream)

        if is_initialized() and self.args.upstream_trainable:
            upstream = DDP(upstream,
                           device_ids=[self.args.local_rank],
                           find_unused_parameters=True)
            # DDP does not forward arbitrary attributes of the wrapped module,
            # so re-expose the interface methods on the wrapper.
            for fn in interface_fn:
                setattr(upstream, fn, getattr(upstream.module, fn))

        return upstream
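The Upstream class above is resolved dynamically; the same getattr/import_module pattern, demonstrated on a standard-library module purely for illustration, is:

    import importlib

    # like: getattr(importlib.import_module('hubconf'), self.args.upstream)
    sqrt = getattr(importlib.import_module('math'), 'sqrt')
    print(sqrt(16))   # -> 4.0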
Example #6
    def _select_feature(self, features):
        feature = features.get(self.feature_selection)

        if isinstance(feature, dict):
            feature = list(feature.values())

        if feature is None:
            available_options = [key for key in features.keys() if key[0] != "_"]
            show(
                f"[{self.name}] - feature_selection = {self.feature_selection} is not supported for this upstream.",
                file=sys.stderr,
            )
            show(
                f"[{self.name}] - Supported options: {available_options}",
                file=sys.stderr,
            )
            raise ValueError

        # A single-element list/tuple is unwrapped so callers receive the
        # tensor directly instead of a one-item container.
        if isinstance(feature, (list, tuple)) and len(feature) == 1:
            feature = feature[0]

        return feature
Example #7
    def _get_downstream(self):
        module_path = f'downstream.{self.args.downstream}.expert'
        Downstream = getattr(importlib.import_module(module_path),
                             'DownstreamExpert')
        downstream = Downstream(
            upstream_dim=self.upstream.get_output_dim(),
            upstream_rate=self.upstream.get_downsample_rate(),
            **self.config,
            **vars(self.args)).to(self.args.device)

        if self.args.verbose:
            show(f'[Runner] - Downstream model architecture: {downstream}')
            show(
                f'[Runner] - Downstream has {count_parameters(downstream)} parameters'
            )

        interface_fn = ['get_dataloader', 'log_records']
        for fn in interface_fn:
            assert hasattr(downstream, fn)

        init_downstream = self.init_ckpt.get('Downstream')
        if init_downstream:
            show(
                '[Runner] - Loading downstream weights from the previous experiment'
            )
            downstream.load_state_dict(init_downstream)

        if is_initialized():
            downstream = DDP(downstream,
                             device_ids=[self.args.local_rank],
                             find_unused_parameters=True)
            for fn in interface_fn:
                setattr(downstream, fn, getattr(downstream.module, fn))

        return downstream
Example #8
    def __init__(
        self,
        upstream: UpstreamBase,
        feature_selection: str = "hidden_states",
        upstream_device: str = "cuda",
        **kwargs,
    ):
        super().__init__()
        self.feature_selection = feature_selection
        self.name = f"Featurizer for {upstream.__class__}"

        show(
            f"[{self.name}] - The input upstream is only for initialization and not saved in this nn.Module"
        )

        # This line is necessary as some models behave differently between train/eval
        # eg. The LayerDrop technique used in wav2vec2
        upstream.eval()

        # One dummy forward pass (1 second of noise) only to probe the shape
        # and the number of the selected features.
        paired_wavs = [torch.randn(SAMPLE_RATE).to(upstream_device)]
        paired_features = upstream(paired_wavs)

        feature = self._select_feature(paired_features)
        if isinstance(feature, (list, tuple)):
            self.layer_num = len(feature)
            show(
                f"[{self.name}] - Take a list of {self.layer_num} features and weighted sum them."
            )
            # One learnable scalar weight per layer for the weighted sum.
            self.weights = nn.Parameter(torch.zeros(self.layer_num))
            feature = self._weighted_sum([f.cpu() for f in feature])
        else:
            feature = feature.cpu()

        self.output_dim = feature.size(-1)
        # Infer the downsample rate from waveform length vs. number of output
        # frames, snapping to the closest common rate (160 or 320 samples).
        ratio = round(max(len(wav) for wav in paired_wavs) / feature.size(1))
        possible_rate = torch.LongTensor([160, 320])
        self.downsample_rate = int(
            possible_rate[(possible_rate - ratio).abs().argmin(dim=-1)]
        )
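_weighted_sum itself is not shown above; the sketch below is only a plausible softmax-weighted sum over the learnable weights, an assumption rather than the actual implementation:

    import torch
    import torch.nn.functional as F

    layer_num, batch, frames, dim = 3, 1, 50, 8
    features = [torch.randn(batch, frames, dim) for _ in range(layer_num)]
    weights = torch.zeros(layer_num)               # like nn.Parameter(torch.zeros(layer_num))

    stacked = torch.stack(features, dim=0)         # (layer_num, batch, frames, dim)
    norm_weights = F.softmax(weights, dim=-1)      # uniform right after initialization
    weighted = (norm_weights.view(-1, 1, 1, 1) * stacked).sum(dim=0)
    print(weighted.shape)                          # -> torch.Size([1, 50, 8])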
Example #9
    def __call__(self, wavs: List[Tensor], *args, **kwargs):
        # Drop leftovers from any previous forward pass before the hooks fire.
        self._hook_hiddens.clear()

        result = super().__call__(wavs, *args, **kwargs) or {}
        assert isinstance(result, dict)

        if len(self._hook_hiddens) > 0:
            # These keys are filled from the hook buffer below, so a child
            # class must not set them itself.
            if (
                result.get("_hidden_states_info") is not None
                or result.get("hidden_states") is not None
                or result.get("last_hidden_state") is not None
            ):
                show(
                    "[UpstreamBase] - If there are registered hooks, '_hidden_states_info', 'hidden_states', and "
                    "'last_hidden_state' are reserved and should not be included in child class's return dict.",
                    file=sys.stderr,
                )
                raise ValueError

            hook_hiddens = self._hook_hiddens.copy()
            self._hook_hiddens.clear()

            if callable(self.hook_postprocess):
                hook_hiddens = self.hook_postprocess(hook_hiddens)

            # Unzip the (identifier, hidden_state) pairs collected by the hooks.
            result["_hidden_states_info"], result["hidden_states"] = zip(*hook_hiddens)
            result["last_hidden_state"] = result["hidden_states"][-1]

            for layer_id, hidden_state in enumerate(result["hidden_states"]):
                result[f"hidden_state_{layer_id}"] = hidden_state

            default = result.get("default")
            if default is not None:
                assert torch.allclose(default, result["last_hidden_state"])

        return result
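The zip(*...) line above simply unzips the collected (identifier, hidden_state) pairs; in isolation, with placeholder strings standing in for tensors:

    hook_hiddens = [("layer_0", "h0"), ("layer_1", "h1"), ("layer_2", "h2")]
    info, states = zip(*hook_hiddens)
    print(info)     # -> ('layer_0', 'layer_1', 'layer_2')
    print(states)   # -> ('h0', 'h1', 'h2')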
Example #10
 def _load_weight(self, model, name):
     init_weight = self.init_ckpt.get(name)
     if init_weight:
         show(f'[Runner] - Loading {name} weights from the previous experiment')
         model.load_state_dict(init_weight)