コード例 #1
0
ファイル: conv_asr.py プロジェクト: vadam5/NeMo
 def _prepare_for_export(self, **kwargs):
     """Disable masking on MaskedConv1d modules, then run the base export preparation.

     Every module yielded by ``self.modules()`` that is a ``MaskedConv1d`` gets
     ``use_mask = False``; afterwards ``Exportable._prepare_for_export`` is called.

     Args:
         **kwargs: Forwarded unchanged to ``Exportable._prepare_for_export``.
     """
     m_count = 0
     for m in self.modules():
         if isinstance(m, MaskedConv1d):
             m.use_mask = False
             m_count += 1
     Exportable._prepare_for_export(self, **kwargs)
     # Only warn when at least one convolution was actually changed, so no
     # misleading "Turned off 0 masked convolutions" message is emitted.
     if m_count > 0:
         logging.warning(f"Turned off {m_count} masked convolutions")
コード例 #2
0
ファイル: conv_asr.py プロジェクト: vadam5/NeMo
 def _prepare_for_export(self, **kwargs):
     """Switch off masking in all modules whose class is named ``MaskedConv1d``.

     Modules are matched by exact class name (not ``isinstance``), so this code
     does not reference the ``MaskedConv1d`` class itself. A warning with the
     number of changed modules is logged only if any were found, then the base
     ``Exportable._prepare_for_export`` is invoked.

     Args:
         **kwargs: Forwarded unchanged to ``Exportable._prepare_for_export``.
     """
     masked_convs = [
         module for module in self.modules() if type(module).__name__ == "MaskedConv1d"
     ]
     for conv in masked_convs:
         conv.use_mask = False
     if masked_convs:
         logging.warning(f"Turned off {len(masked_convs)} masked convolutions")
     Exportable._prepare_for_export(self, **kwargs)
コード例 #3
0
ファイル: conformer_encoder.py プロジェクト: AlexGrinch/NeMo
 def _prepare_for_export(self, **kwargs):
     """Raise the maximum audio length for export, then run the base preparation.

     The length passed to ``self.set_max_audio_length`` is ``self.pos_emb_max_len``,
     or the last-dimension size of ``input_example[0]`` if an input example is
     supplied and is longer.

     Args:
         **kwargs: Forwarded unchanged to ``Exportable._prepare_for_export``.
             An optional ``input_example`` entry is inspected; element 0 must
             support ``.size(-1)`` (presumably the audio signal tensor — TODO confirm).
     """
     # extend masks to configured maximum
     max_len = self.pos_emb_max_len
     input_example = kwargs.get('input_example')
     # The key may be present with a None value; only inspect a real example
     # instead of failing on `None[0]`.
     if input_example is not None:
         max_len = max(max_len, input_example[0].size(-1))
     logging.info(f"Extending input audio length to {max_len}")
     self.set_max_audio_length(max_len)
     Exportable._prepare_for_export(self, **kwargs)
コード例 #4
0
ファイル: conv_asr.py プロジェクト: stjordanis/NeMo
    def _prepare_for_export(self, **kwargs):
        """Adjust submodules for export, then run the base Exportable preparation.

        For each module yielded by ``self.modules()``:
        - ``MaskedConv1d``: ``use_mask`` is set to False, unless ``self._rnnt_export``
          is truthy, in which case the module is left untouched.
        - ``SqueezeExcite``: ``_se_pool_step`` is replaced by ``_se_pool_step_export``.

        Args:
            **kwargs: Forwarded unchanged to ``Exportable._prepare_for_export``.
        """
        m_count = 0
        for m in self.modules():
            # Short-circuit keeps `self._rnnt_export` from being read unless a
            # MaskedConv1d is actually encountered.
            if isinstance(m, MaskedConv1d) and not self._rnnt_export:
                m.use_mask = False
                m_count += 1
            if isinstance(m, SqueezeExcite):
                # Swap in the export-specific pooling step.
                m._se_pool_step = m._se_pool_step_export

        Exportable._prepare_for_export(self, **kwargs)
        # With RNNT export (or no masked convolutions) nothing is turned off;
        # skip the misleading "Turned off 0 ..." warning in that case.
        if m_count > 0:
            logging.warning(f"Turned off {m_count} masked convolutions")
コード例 #5
0
    def _prepare_for_export(self, **kwargs):
        """Prepare the module tree for export, then run the base Exportable preparation.

        Walks ``self.named_modules()`` in order and:
        - sets ``use_mask = False`` on every ``MaskedConv1d``;
        - multiplies a running ``stride`` by ``m.conv.stride[0]`` for each
          ``MaskedConv1d`` whose name contains ``'mconv'`` and whose stride is > 1;
        - calls ``set_max_len`` on every ``SqueezeExcite`` with one hour of frames
          divided by the stride accumulated so far in the walk.

        Args:
            **kwargs: Forwarded unchanged to ``Exportable._prepare_for_export``.
        """
        m_count = 0
        stride = 1
        # 1 sec / 0.01 window stride = 100 frames / second * 60 sec * 60 min * 1 hour
        one_hour = 100 * 60 * 60 * 1

        for name, m in self.named_modules():
            if isinstance(m, MaskedConv1d):
                m.use_mask = False
                m_count += 1
                # Accumulate the product of >1 strides of 'mconv' convolutions seen so far.
                if m.conv.stride[0] > 1 and 'mconv' in name:
                    stride *= m.conv.stride[0]

            if isinstance(m, SqueezeExcite):
                m.set_max_len(int(one_hour // stride))  # One hour divided by current stride level

        Exportable._prepare_for_export(self, **kwargs)
        # Only warn when something was actually turned off.
        if m_count > 0:
            logging.warning(f"Turned off {m_count} masked convolutions")
コード例 #6
0
 def _prepare_for_export(self, **kwargs):
     """Delegate export preparation to ``Exportable._prepare_for_export``.

     Accepts and forwards ``**kwargs`` so this override matches the
     ``(self, **kwargs)`` signature of the export-preparation hook and does not
     raise TypeError when a caller passes keyword arguments. A call with no
     arguments behaves exactly as before.

     Args:
         **kwargs: Forwarded unchanged to ``Exportable._prepare_for_export``.
     """
     Exportable._prepare_for_export(self, **kwargs)