Example #1
import copy
import torch
from distutils.util import strtobool

# TransformerBaseWrapper, DualTransformerConfig, TransformerPhoneticEncoder,
# and TransformerSpeakerEncoder are assumed to be importable from the
# surrounding codebase; their definitions are not shown here.

class DUAL_TRANSFORMER(TransformerBaseWrapper):
    def __init__(self, options, inp_dim, config=None, mode='phone', with_recognizer=True):
        super(DUAL_TRANSFORMER, self).__init__(options, inp_dim, config)

        del self.model_config
        self.model_config = DualTransformerConfig(self.config)
        self.out_dim = 0  # This attribute is required by pytorch-kaldi and run_downstream.py
        self.mode = mode  # can be 'phone', 'speaker', or 'phone speaker'
        # Membership in a tuple, not a substring test: `self.mode in 'phone speaker'`
        # would also accept '' or fragments like 'one'.
        assert self.mode in ('phone', 'speaker', 'phone speaker')
        
        # Build model
        if 'phone' in self.mode:
            self.PhoneticTransformer = TransformerPhoneticEncoder(self.model_config, self.inp_dim, with_recognizer=with_recognizer).to(self.device)
            if self.no_grad:
                self.PhoneticTransformer.eval()
            else:
                self.PhoneticTransformer.train()
            self.model = self.PhoneticTransformer
        
        if 'speaker' in self.mode:
            self.SpeakerTransformer = TransformerSpeakerEncoder(self.model_config, self.inp_dim, with_recognizer=with_recognizer).to(self.device)
            if self.no_grad:
                self.SpeakerTransformer.eval()
            else:
                self.SpeakerTransformer.train()
            self.model = self.SpeakerTransformer
        
        # Load from a PyTorch state_dict
        load = bool(strtobool(options["load_pretrain"]))
        if load and 'phone' in self.mode:
            self.PhoneticTransformer.Transformer = self.load_model(self.PhoneticTransformer.Transformer, 
                                                                   self.all_states['PhoneticTransformer'])
            if hasattr(self.PhoneticTransformer, 'PhoneRecognizer'): 
                self.PhoneticTransformer.PhoneRecognizer.load_state_dict(self.all_states['PhoneticLayer'])
            self.out_dim += self.PhoneticTransformer.out_dim
            n_params = sum(p.numel() for p in self.PhoneticTransformer.parameters() if p.requires_grad)
            print(f'[Phonetic Transformer] - Number of parameters: {n_params}')

        if load and 'speaker' in self.mode:
            self.SpeakerTransformer.Transformer = self.load_model(self.SpeakerTransformer.Transformer, 
                                                                   self.all_states['SpeakerTransformer'])
            if hasattr(self.SpeakerTransformer, 'SpeakerRecognizer'):
                self.SpeakerTransformer.SpeakerRecognizer.load_state_dict(self.all_states['SpeakerLayer'])
            self.out_dim += self.SpeakerTransformer.out_dim
            n_params = sum(p.numel() for p in self.SpeakerTransformer.parameters() if p.requires_grad)
            print(f'[Speaker Transformer] - Number of parameters: {n_params}')


    def _dual_forward(self, x):
        if hasattr(self, 'preprocessor'):
            # Optional on-the-fly preprocessing; dims 1 and 2 are swapped before the call.
            x = self.preprocessor(x.transpose(1, 2).contiguous())[0]
        if 'phone' in self.mode and 'speaker' in self.mode:
            # Run the shared _forward once per encoder by swapping self.model.
            self.model = self.PhoneticTransformer
            phonetic_code = self._forward(copy.deepcopy(x))  # copy so the input is not shared between branches
            self.model = self.SpeakerTransformer
            speaker_code = self._forward(x)
            if self.model_config.average_pooling:
                # Pooled speaker_code is (batch, 1, dim); tile it along time to match phonetic_code.
                speaker_code = speaker_code.repeat(1, phonetic_code.size(1), 1)
            if self.model_config.combine == 'concat':
                x = torch.cat((phonetic_code, speaker_code), dim=2)
            elif self.model_config.combine == 'add':
                x = phonetic_code + speaker_code
            else:
                raise NotImplementedError

        elif ('phone' in self.mode) != ('speaker' in self.mode): # exclusive or
            x = self._forward(x)
        else:
            raise NotImplementedError
        return x


    def forward(self, x):
        if self.no_grad:
            with torch.no_grad():
                x = self._dual_forward(x)
        else:
            x = self._dual_forward(x)
        return x
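
A minimal usage sketch for the class above, assuming TransformerBaseWrapper can be
constructed from the options dict. Only 'load_pretrain' is read in this snippet;
the base wrapper (not shown) reads further keys such as a checkpoint path, and the
(batch, time, feature) input layout is likewise an assumption.

import torch

options = {'load_pretrain': 'False'}  # hypothetical minimal options
model = DUAL_TRANSFORMER(options, inp_dim=160, mode='phone speaker')
model.eval()

feats = torch.randn(4, 100, 160)  # assumed (batch, time, feature)
with torch.no_grad():
    codes = model(feats)          # (batch, time, out_dim) when combine == 'concat'

The second listing below is an alternative variant of the same class: it has no
with_recognizer switch, sets self.model only inside _dual_forward, tiles the speaker
code unconditionally, and loads the speaker head with a try/except fallback.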
class DUAL_TRANSFORMER(TransformerBaseWrapper):
    def __init__(self, options, inp_dim, config=None, mode='phone'):
        super(DUAL_TRANSFORMER, self).__init__(options, inp_dim, config)

        del self.model_config
        self.model_config = DualTransformerConfig(self.config)
        self.out_dim = 0  # This attribute is for pytorch-kaldi
        self.mode = mode  # can be 'phone', 'speaker', or 'phone speaker'

        # Build model
        if 'phone' in self.mode:
            self.PhoneticTransformer = TransformerPhoneticEncoder(
                self.model_config, self.inp_dim).to(self.device)
            if self.no_grad:
                self.PhoneticTransformer.eval()
            else:
                self.PhoneticTransformer.train()

        if 'speaker' in self.mode:
            self.SpeakerTransformer = TransformerSpeakerEncoder(
                self.model_config, self.inp_dim).to(self.device)
            if self.no_grad:
                self.SpeakerTransformer.eval()
            else:
                self.SpeakerTransformer.train()

        # Load from a PyTorch state_dict
        load = bool(strtobool(options["load_pretrain"]))
        if load and 'phone' in self.mode:
            self.PhoneticTransformer.Transformer = self.load_model(
                self.PhoneticTransformer.Transformer,
                self.all_states['PhoneticTransformer'])
            self.PhoneticTransformer.PhoneRecognizer.load_state_dict(
                self.all_states['PhoneticLayer'])
            self.out_dim += self.model_config.phone_dim
            n_params = sum(p.numel() for p in self.PhoneticTransformer.parameters()
                           if p.requires_grad)
            print(f'[Phonetic Transformer] - Number of parameters: {n_params}')

        if load and 'speaker' in self.mode:
            self.SpeakerTransformer.Transformer = self.load_model(
                self.SpeakerTransformer.Transformer,
                self.all_states['SpeakerTransformer'])
            try:
                self.SpeakerTransformer.GlobalStyleToken.load_state_dict(
                    self.all_states['SpeakerLayer'])
            except (AttributeError, RuntimeError):
                # Fall back for checkpoints whose speaker head is a
                # SpeakerRecognizer rather than a GlobalStyleToken.
                self.SpeakerTransformer.SpeakerRecognizer.load_state_dict(
                    self.all_states['SpeakerLayer'])
            self.out_dim += self.model_config.speaker_dim
            n_params = sum(p.numel() for p in self.SpeakerTransformer.parameters()
                           if p.requires_grad)
            print(f'[Speaker Transformer] - Number of parameters: {n_params}')

    def _dual_forward(self, x):
        if 'phone' in self.mode and 'speaker' in self.mode:
            # Run the shared _forward once per encoder by swapping self.model.
            self.model = self.PhoneticTransformer
            phonetic_code = self._forward(x)
            self.model = self.SpeakerTransformer
            speaker_code = self._forward(x)
            # speaker_code is assumed to be (batch, 1, dim); tile it along time.
            speaker_code = speaker_code.repeat(1, phonetic_code.size(1), 1)
            if self.model_config.combine == 'concat':
                x = torch.cat((phonetic_code, speaker_code), dim=2)
            elif self.model_config.combine == 'add':
                x = phonetic_code + speaker_code
            else:
                raise NotImplementedError

        elif 'phone' in self.mode:
            self.model = self.PhoneticTransformer
            x = self._forward(x)
        elif 'speaker' in self.mode:
            self.model = self.SpeakerTransformer
            x = self._forward(x)
        else:
            raise NotImplementedError
        return x

    def forward(self, x):
        if self.no_grad:
            with torch.no_grad():
                x = self._dual_forward(x)
        else:
            x = self._dual_forward(x)
        return x
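
A self-contained sketch of the combine step shared by both variants, showing the
shape arithmetic behind the 'concat' and 'add' options. The dimensions are made up
for illustration.

import torch

B, T, D_PHONE, D_SPK = 2, 100, 768, 128
phonetic_code = torch.randn(B, T, D_PHONE)   # frame-level phonetic features
speaker_code = torch.randn(B, 1, D_SPK)      # one pooled vector per utterance

# Tile the utterance-level speaker code along the time axis.
speaker_code = speaker_code.repeat(1, T, 1)               # (B, T, D_SPK)

# combine == 'concat': feature dims are summed.
concat = torch.cat((phonetic_code, speaker_code), dim=2)  # (B, T, D_PHONE + D_SPK)
assert concat.shape == (B, T, D_PHONE + D_SPK)

# combine == 'add': requires matching feature dims.
added = phonetic_code + torch.randn(B, T, D_PHONE)        # (B, T, D_PHONE)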