def test_constructor(self, asr_model):
    """Round-trip the model through its config dict and check the rebuilt type.

    The model is switched to training mode, exported with
    ``to_config_dict()`` and re-created via
    ``EncDecCTCModelBPE.from_config_dict``; the result must be an
    ``EncDecCTCModelBPE`` instance.
    """
    asr_model.train()
    # TODO: make proper config and assert correct number of weights
    # Export the configuration and build a second instance from it.
    exported_config = asr_model.to_config_dict()
    rebuilt_model = EncDecCTCModelBPE.from_config_dict(exported_config)
    assert isinstance(rebuilt_model, EncDecCTCModelBPE)
示例#2
0
    def __init__(
        self,
        asr_model,
        frame_len=1.6,
        total_buffer=4.0,
        batch_size=4,
    ):
        """Set up frame buffering, decoding state and a streaming preprocessor.

        Args:
          asr_model: model used for inference; must expose ``tokenizer``,
            ``device``, ``_cfg`` (with a ``preprocessor`` section) and a
            ``vocabulary`` on either ``decoder`` or ``joint``
          frame_len: frame's duration, seconds
          total_buffer: forwarded to ``FeatureFrameBufferer``
            (presumably the total buffer duration in seconds -- TODO confirm)
          batch_size: forwarded to ``FeatureFrameBufferer`` and kept on the
            instance
        """
        self.frame_bufferer = FeatureFrameBufferer(
            asr_model=asr_model,
            frame_len=frame_len,
            batch_size=batch_size,
            total_buffer=total_buffer,
        )

        self.asr_model = asr_model
        self.batch_size = batch_size

        # Accumulators; initialised here before reset() is invoked below.
        self.all_logits = []
        self.all_preds = []
        self.unmerged = []

        # The blank id is the vocabulary size, i.e. one past the last token
        # index. The vocabulary is read from the decoder when it has one,
        # otherwise from the joint module.
        vocab_owner = (
            asr_model.decoder
            if hasattr(asr_model.decoder, "vocabulary")
            else asr_model.joint
        )
        self.blank_id = len(vocab_owner.vocabulary)

        self.tokenizer = asr_model.tokenizer
        self.toks_unmerged = []
        self.frame_buffers = []
        self.reset()

        self.frame_len = frame_len

        # Work on a private copy so the model's own config is left untouched.
        streaming_cfg = copy.deepcopy(asr_model._cfg)
        preprocessor_cfg = streaming_cfg.preprocessor
        OmegaConf.set_struct(preprocessor_cfg, False)

        # some changes for streaming scenario
        preprocessor_cfg.dither = 0.0
        preprocessor_cfg.pad_to = 0
        preprocessor_cfg.normalize = "None"

        self.raw_preprocessor = EncDecCTCModelBPE.from_config_dict(
            preprocessor_cfg
        )
        self.raw_preprocessor.to(asr_model.device)