Example #1
import os

import torch

from s3prl.upstream.interfaces import Featurizer


def test_frontend_output_size():
    # Skip on torch versions older than 1.7; is_torch_1_7_plus is assumed to be
    # defined at module level in the original test file.
    if not is_torch_1_7_plus:
        return

    # Locate a local s3prl checkout on PYTHONPATH so torch.hub can load it from source.
    s3prl_path = None
    python_path_list = os.environ.get("PYTHONPATH", "(None)").split(":")
    for p in python_path_list:
        if p.endswith("s3prl"):
            s3prl_path = p
            break
    assert s3prl_path is not None

    s3prl_upstream = torch.hub.load(
        s3prl_path,
        "mel",
        source="local",
    ).to("cpu")

    feature_selection = "last_hidden_state"
    s3prl_featurizer = Featurizer(
        upstream=s3prl_upstream,
        feature_selection=feature_selection,
        upstream_device="cpu",
    )

    wavs = [torch.randn(1600)]
    feats = s3prl_upstream(wavs)
    feats = s3prl_featurizer(wavs, feats)
    assert feats[0].shape[-1] == 80
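
A minimal standalone sketch of the same call pattern outside the test harness. It assumes a working s3prl installation whose "mel" upstream can be constructed through s3prl.hub (as in Example #3 below); the Featurizer signature is the one used throughout these examples.

import torch
import s3prl.hub
from s3prl.upstream.interfaces import Featurizer

# Hypothetical sketch: wrap the "mel" upstream with a Featurizer and extract
# 80-dimensional log-Mel features for a small batch of waveforms.
upstream = s3prl.hub.mel().to("cpu")
featurizer = Featurizer(
    upstream=upstream,
    feature_selection="last_hidden_state",
    upstream_device="cpu",
)

wavs = [torch.randn(16000), torch.randn(8000)]  # variable-length utterances
with torch.no_grad():
    feats = upstream(wavs)           # raw upstream outputs
    feats = featurizer(wavs, feats)  # per-utterance feature tensors
print(feats[0].shape)                # last dimension is 80 for the mel upstream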
Example #2
    def _get_featurizer(self):
        # Wrap the loaded upstream with a Featurizer that reads the configured
        # feature selection, then register it as a trainable module.
        model = Featurizer(
            self.upstream.model, self.args.upstream_feature_selection,
            upstream_device=self.args.device,
        ).to(self.args.device)

        return self._init_model(
            model=model,
            name='Featurizer',
            trainable=True,
            interfaces=['output_dim', 'downsample_rate'],
        )
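
The interfaces listed above correspond to attributes the Featurizer exposes after construction; a downstream runner can query them to size its model and to map feature frames back to waveform samples. A hedged sketch (the helper below is hypothetical; only the attribute names come from the example):

from s3prl.upstream.interfaces import Featurizer

def describe_featurizer(upstream, feature_selection="last_hidden_state", device="cpu"):
    # Hypothetical helper: build a Featurizer and report the two interface
    # attributes registered above.
    featurizer = Featurizer(upstream, feature_selection, upstream_device=device)
    # output_dim: feature dimension the downstream model should expect.
    # downsample_rate: waveform samples per feature frame.
    return featurizer.output_dim, featurizer.downsample_rate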
Example #3
    def __init__(self, ckpt, **kwargs):
        super(UpstreamExpert, self).__init__()
        ckpt = torch.load(ckpt, map_location='cpu')

        # Rebuild the upstream named in the checkpoint and wrap it with a
        # Featurizer that exposes the last hidden state.
        args = ckpt['Args']
        self.upstream = getattr(s3prl.hub, args.upstream)()
        self.featurizer = Featurizer(self.upstream, "last_hidden_state", "cpu")

        # Instantiate the downstream phone classifier described by the checkpoint's
        # config, sized by the featurizer's output dimension.
        config = ckpt['Config']
        modelrc = config['downstream_expert']['modelrc']
        model_cls = eval(modelrc['select'])
        model_conf = modelrc[modelrc['select']]
        self.model = model_cls(self.featurizer.output_dim,
                               output_class_num=TIMIT_PHONE_CLASSES,
                               **model_conf)
        self.model.load_state_dict(
            UpstreamExpert._fix_state_key(ckpt['Downstream']))
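
At run time the three modules built in this constructor are chained: waveforms go through the upstream, the featurizer selects the last hidden state, and the classifier maps each frame to phone logits. A minimal sketch of such a forward pass (the method name, batching, and return value are assumptions, not part of the original class):

    def forward(self, wavs):
        # Hypothetical sketch: waveforms -> upstream -> featurizer -> phone logits.
        feats = self.upstream(wavs)           # upstream hidden states
        feats = self.featurizer(wavs, feats)  # per-utterance feature tensors
        # Padding/batching of the variable-length features is omitted here.
        return [self.model(f) for f in feats]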
Example #4
    def _get_upstream(self, frontend_conf):
        """Get S3PRL upstream model."""
        s3prl_args = base_s3prl_setup(
            Namespace(**frontend_conf, device="cpu"),
        )
        self.args = s3prl_args

        # Locate a local s3prl checkout on PYTHONPATH so torch.hub can load it from source.
        s3prl_path = None
        python_path_list = os.environ.get("PYTHONPATH", "(None)").split(":")
        for p in python_path_list:
            if p.endswith("s3prl"):
                s3prl_path = p
                break
        assert s3prl_path is not None

        s3prl_upstream = torch.hub.load(
            s3prl_path,
            s3prl_args.upstream,
            ckpt=s3prl_args.upstream_ckpt,
            model_config=s3prl_args.upstream_model_config,
            refresh=s3prl_args.upstream_refresh,
            source="local",
        ).to("cpu")

        # Disable LayerDrop so feature extraction is deterministic for
        # wav2vec 2.0 / HuBERT upstreams.
        if getattr(
            s3prl_upstream, "model", None
        ) is not None and s3prl_upstream.model.__class__.__name__ in [
            "Wav2Vec2Model",
            "HubertModel",
        ]:
            s3prl_upstream.model.encoder.layerdrop = 0.0

        from s3prl.upstream.interfaces import Featurizer

        # Select all hidden states when multilayer features are requested,
        # otherwise only the last hidden state.
        if self.multilayer_feature is None:
            feature_selection = "last_hidden_state"
        else:
            feature_selection = "hidden_states"
        s3prl_featurizer = Featurizer(
            upstream=s3prl_upstream,
            feature_selection=feature_selection,
            upstream_device="cpu",
        )

        return s3prl_upstream, s3prl_featurizer
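
A hedged sketch of the frontend_conf this helper expects. The key names mirror the s3prl_args attributes read above, but the concrete values (and any defaults filled in by base_s3prl_setup) are illustrative assumptions only:

# Hypothetical configuration; only the key names are taken from the code above.
frontend_conf = dict(
    upstream="mel",               # name of the s3prl upstream to load via torch.hub
    upstream_ckpt=None,           # optional checkpoint for checkpoint-based upstreams
    upstream_model_config=None,   # optional upstream-specific model config
    upstream_refresh=False,       # refresh flag forwarded to the s3prl hub entry point
)
s3prl_upstream, s3prl_featurizer = self._get_upstream(frontend_conf)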