    def __init__(self,
                 options,
                 inp_dim,
                 config=None,
                 mode='phone',
                 with_recognizer=True):
        super(DUAL_TRANSFORMER, self).__init__(options, inp_dim, config)

        # Swap out the parent class's model config for the dual-transformer config
        del self.model_config
        self.model_config = DualTransformerConfig(self.config)
        self.out_dim = 0  # This attribute is necessary for pytorch-kaldi and run_downstream.py
        self.mode = mode  # can be 'phone', 'speaker', or 'phone speaker'
        assert self.mode in ('phone', 'speaker', 'phone speaker')

        # Build model
        if 'phone' in self.mode:
            self.PhoneticTransformer = TransformerPhoneticEncoder(
                self.model_config,
                self.inp_dim,
                with_recognizer=with_recognizer).to(self.device)
            if self.no_grad:
                self.PhoneticTransformer.eval()
            else:
                self.PhoneticTransformer.train()
            self.model = self.PhoneticTransformer

        if 'speaker' in self.mode:
            self.SpeakerTransformer = TransformerSpeakerEncoder(
                self.model_config,
                self.inp_dim,
                with_recognizer=with_recognizer).to(self.device)
            if self.no_grad:
                self.SpeakerTransformer.eval()
            else:
                self.SpeakerTransformer.train()
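            # NOTE: in 'phone speaker' mode the assignment below overwrites
            # self.model set in the phonetic branch above, so self.model ends
            # up pointing at the speaker encoder.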
            self.model = self.SpeakerTransformer

        # Load from a PyTorch state_dict
        load = bool(strtobool(options["load_pretrain"]))
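        # strtobool (from distutils.util, assumed to be imported at module
        # level) maps 'y'/'yes'/'true'/'on'/'1' to 1 and 'n'/'no'/'false'/
        # 'off'/'0' to 0; bool() then turns that int into a proper boolean.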
        if load and 'phone' in self.mode:
            self.PhoneticTransformer.Transformer = self.load_model(
                self.PhoneticTransformer.Transformer,
                self.all_states['PhoneticTransformer'])
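            # The recognizer head presumably exists only when the encoder was
            # built with with_recognizer=True, hence the hasattr guards below.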
            if hasattr(self.PhoneticTransformer, 'PhoneRecognizer'):
                self.PhoneticTransformer.PhoneRecognizer.load_state_dict(
                    self.all_states['PhoneticLayer'])
            self.out_dim += self.PhoneticTransformer.out_dim
            num_params = sum(p.numel()
                             for p in self.PhoneticTransformer.parameters()
                             if p.requires_grad)
            print(f'[Phonetic Transformer] - Number of parameters: {num_params}')

        if load and 'speaker' in self.mode:
            self.SpeakerTransformer.Transformer = self.load_model(
                self.SpeakerTransformer.Transformer,
                self.all_states['SpeakerTransformer'])
            if hasattr(self.SpeakerTransformer, 'SpeakerRecognizer'):
                self.SpeakerTransformer.SpeakerRecognizer.load_state_dict(
                    self.all_states['SpeakerLayer'])
            self.out_dim += self.SpeakerTransformer.out_dim
            num_params = sum(p.numel()
                             for p in self.SpeakerTransformer.parameters()
                             if p.requires_grad)
            print(f'[Speaker Transformer] - Number of parameters: {num_params}')
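
    # Usage sketch (not from the source): a minimal way to construct the
    # extractor. The option keys besides 'load_pretrain', the input dimension,
    # and the checkpoint contents are assumptions for illustration only.
    #
    #   options = {'load_pretrain': 'True'}  # other keys follow the base class
    #   extractor = DUAL_TRANSFORMER(options, inp_dim=160,
    #                                mode='phone speaker')
    #   print(extractor.out_dim)  # phonetic + speaker output dims combined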