def reset_parameters(self, args):
    """Initialize the model parameters.

    The whole model is first passed to ``initialize`` together with
    ``args.transformer_init``.  When ``self.mt_weight`` is positive,
    ``self.encoder_mt.embed[0].weight`` is additionally redrawn from a
    zero-mean normal distribution with std ``args.adim ** -0.5`` and the
    row indexed by ``self.pad`` is then filled with zeros.
    """
    initialize(self, args.transformer_init)

    # The MT embedding is only touched when the MT weight is positive.
    if not self.mt_weight > 0:
        return

    mt_embedding = self.encoder_mt.embed[0].weight
    torch.nn.init.normal_(mt_embedding, mean=0, std=args.adim ** -0.5)
    torch.nn.init.constant_(mt_embedding[self.pad], 0)
예제 #2
0
    def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_alpha=1.0):
        """Initialize the parameters and, if enabled, the positional-encoding alphas.

        Args:
            init_type: Forwarded unchanged to ``initialize``.
            init_enc_alpha: Value written to ``self.encoder.embed[-1].alpha``.
            init_dec_alpha: Value written to ``self.decoder.embed[-1].alpha``.

        """
        initialize(self, init_type)

        # The alphas only exist for scaled positional encoding.
        if not self.use_scaled_pos_enc:
            return

        encoder_alpha = torch.tensor(init_enc_alpha)
        self.encoder.embed[-1].alpha.data = encoder_alpha
        decoder_alpha = torch.tensor(init_dec_alpha)
        self.decoder.embed[-1].alpha.data = decoder_alpha
예제 #3
0
    def _reset_parameters(self, args):
        """Initialize the parameters and, if enabled, the positional-encoding alphas.

        ``args.transformer_init`` is forwarded to ``initialize``.  When
        ``self.use_scaled_pos_enc`` is set, the ``alpha`` of the last entry in
        the encoder and decoder ``embed`` stacks is overwritten with
        ``args.initial_encoder_alpha`` / ``args.initial_decoder_alpha``.
        """
        initialize(self, args.transformer_init)

        # The alphas only exist for scaled positional encoding.
        if not self.use_scaled_pos_enc:
            return

        encoder_alpha = torch.tensor(args.initial_encoder_alpha)
        self.encoder.embed[-1].alpha.data = encoder_alpha
        decoder_alpha = torch.tensor(args.initial_decoder_alpha)
        self.decoder.embed[-1].alpha.data = decoder_alpha
예제 #4
0
 def reset_parameters(self, args):
     """Initialize parameters, from a pretrained checkpoint when one is given.

     If ``args.pretrained_model`` is falsy the model is handed to
     ``initialize`` with ``args.transformer_init``.  Otherwise the file at
     that path is loaded onto the CPU-side storages and applied with
     ``strict=False``; when the path contains ``'snapshot'`` the state dict
     is taken from the ``'model'`` entry of the loaded object.
     """
     if not args.pretrained_model:
         initialize(self, args.transformer_init)
         return

     path = args.pretrained_model
     logging.warning("load pretrained asr model from {}".format(path))

     # Snapshot files wrap the state dict under the 'model' key.
     is_snapshot = 'snapshot' in path
     checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
     if is_snapshot:
         checkpoint = checkpoint['model']

     self.load_state_dict(checkpoint, strict=False)
예제 #5
0
 def reset_parameters(self, args):
     """Initialize parameters, from a pretrained CTC checkpoint when one is given.

     If ``args.pretrained_ctc_model`` is falsy the model is handed to
     ``initialize`` with ``args.transformer_init``.  Otherwise the file at
     that path is loaded and applied with ``strict=False``; when the path
     contains ``'snapshot'`` the state dict is taken from the ``'model'``
     entry of the loaded object.
     """
     if not args.pretrained_ctc_model:
         initialize(self, args.transformer_init)
         return

     path = args.pretrained_ctc_model
     logging.warning("load pretrained asr model from {}".format(path))

     # Snapshot files wrap the state dict under the 'model' key.
     is_snapshot = 'snapshot' in path
     checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
     if is_snapshot:
         checkpoint = checkpoint['model']

     # In lid-classifier mode the encoder + lid_lo params are used, so no
     # keys are removed from the checkpoint; strict=False tolerates the rest.
     self.load_state_dict(checkpoint, strict=False)
예제 #6
0
def initializer(model, args):
    """Initialize transducer model.

    The scheme is chosen from ``args.dtype`` (decoder type) and
    ``args.etype`` (encoder type): parts whose type is ``'transformer'`` go
    through ``initialize`` with ``args.transformer_init``, the others through
    ``lecun_normal_init_parameters``.

    Args:
        model (torch.nn.Module): transducer instance
        args (Namespace): argument Namespace containing options

    """
    if args.dtype == 'transformer':
        if args.etype == 'transformer':
            # Both halves are transformers: one pass over the whole model.
            initialize(model, args.transformer_init)
        else:
            lecun_normal_init_parameters(model.encoder)
            initialize(model.decoder, args.transformer_init)
        return

    # Non-transformer decoder from here on.
    if args.etype == 'transformer':
        initialize(model.encoder, args.transformer_init)
        lecun_normal_init_parameters(model.dec)
    else:
        lecun_normal_init_parameters(model)

    # Redraw the decoder embedding from N(0, 1).
    model.dec.embed.weight.data.normal_(0, 1)

    for layer in model.dec.decoder:
        set_forget_bias_to_one(layer.bias_ih)
예제 #7
0
파일: initializer.py 프로젝트: xmpx/espnet
def initializer(model, args):
    """Initialize transducer model.

    The scheme is chosen by whether ``"custom"`` occurs in ``args.dtype``
    (decoder type) and ``args.etype`` (encoder type): custom parts go through
    ``initialize`` with ``args.transformer_init``, the others through
    ``lecun_normal_init_parameters``.

    Args:
        model (torch.nn.Module): transducer instance
        args (Namespace): argument Namespace containing options

    """
    if "custom" in args.dtype:
        if "custom" in args.etype:
            # Both halves are custom: one pass over the whole model.
            initialize(model, args.transformer_init)
        else:
            lecun_normal_init_parameters(model.enc)
            initialize(model.decoder, args.transformer_init)
        return

    # Non-custom decoder from here on.
    if "custom" in args.etype:
        initialize(model.encoder, args.transformer_init)
        lecun_normal_init_parameters(model.dec)
    else:
        lecun_normal_init_parameters(model)

    # Redraw the decoder embedding from N(0, 1).
    model.dec.embed.weight.data.normal_(0, 1)

    # Both layer-0 bias vectors of each of the first ``dlayers`` decoder
    # entries are handed to the forget-bias helper.
    for idx in range(model.dec.dlayers):
        set_forget_bias_to_one(model.dec.decoder[idx].bias_ih_l0)
        set_forget_bias_to_one(model.dec.decoder[idx].bias_hh_l0)
예제 #8
0
    def reset_parameters(self, args):
        """Initialize parameters, optionally seeding both encoders from CTC checkpoints.

        The model is always passed to ``initialize`` with
        ``args.transformer_init``.  When both ``args.pretrained_cn_ctc_model``
        and ``args.pretrained_en_ctc_model`` are set, the encoder weights of
        those checkpoints are then loaded (``strict=False``) under the
        ``cn_`` and ``en_`` key prefixes, in that order.
        """

        def load_state_dict_encoder(path, prefix=''):
            """Load a checkpoint and keep only its encoder entries.

            Keys that do not contain ``'encoder'`` are dropped (this also
            removes ``ctc.ctc_lo``).  In the remaining keys ``'encoder.'`` is
            rewritten to ``prefix + 'encoder.'`` to match the MoE structure.
            The loaded mapping is edited in place and returned.
            """
            # Snapshot files wrap the state dict under the 'model' key.
            is_snapshot = 'snapshot' in path
            state = torch.load(path, map_location=lambda storage, loc: storage)
            if is_snapshot:
                state = state['model']

            for old_key in list(state.keys()):
                value = state.pop(old_key)
                if 'encoder' in old_key:
                    renamed = old_key.replace('encoder.', prefix + 'encoder.')
                    state[renamed] = value
            return state

        use_pretrained = (args.pretrained_cn_ctc_model
                          and args.pretrained_en_ctc_model)
        if use_pretrained:
            logging.warning(
                "loading pretrained ctc model for parallel encoder")

        # The parameters not covered by the checkpoints still need initializing.
        initialize(self, args.transformer_init)
        if not use_pretrained:
            return

        checkpoints = ((args.pretrained_cn_ctc_model, 'cn_'),
                       (args.pretrained_en_ctc_model, 'en_'))
        for ckpt_path, key_prefix in checkpoints:
            encoder_state = load_state_dict_encoder(ckpt_path, prefix=key_prefix)
            self.load_state_dict(encoder_state, strict=False)
            del encoder_state
예제 #9
0
def initializer(model, args):
    """Initialize transducer model.

    The scheme is chosen from ``args.dtype`` (decoder type) and
    ``args.etype`` (encoder type): parts whose type is ``'transformer'`` go
    through ``initialize`` with ``args.transformer_init``, the others through
    ``lecun_normal_init_parameters``.

    Args:
        model (torch.nn.Module): transducer instance
        args (Namespace): argument Namespace containing options

    """
    if args.dtype == 'transformer':
        if args.etype == 'transformer':
            # Both halves are transformers: one pass over the whole model.
            initialize(model, args.transformer_init)
        else:
            lecun_normal_init_parameters(model.encoder)
            initialize(model.decoder, args.transformer_init)
        return

    # Non-transformer decoder from here on.
    if args.etype == 'transformer':
        initialize(model.encoder, args.transformer_init)
        lecun_normal_init_parameters(model.decoder)
    else:
        lecun_normal_init_parameters(model)

    # Redraw the decoder embedding from N(0, 1).
    model.decoder.embed.weight.data.normal_(0, 1)
 def reset_parameters(self, args):
     """Initialize the model by delegating to ``initialize``.

     The scheme is read from ``args.transformer_init`` and passed through
     unchanged.
     """
     init_scheme = args.transformer_init
     initialize(self, init_scheme)
    def __init__(self,
                 num_time_mask=2,
                 num_freq_mask=2,
                 freq_mask_length=15,
                 time_mask_length=15,
                 feature_dim=320,
                 model_size=512,
                 feed_forward_size=1024,
                 hidden_size=64,
                 dropout=0.1,
                 num_head=8,
                 num_encoder_layer=6,
                 num_decoder_layer=6,
                 vocab_path='testing_vocab.model',
                 max_feature_length=1024,
                 max_token_length=50,
                 enable_spec_augment=True,
                 share_weight=True,
                 smoothing=0.1,
                 restrict_left_length=20,
                 restrict_right_length=20,
                 mtlalpha=0.2,
                 report_wer=True):
        """Build the encoder/decoder model, its losses and optional helpers.

        Creates the vocabulary, an optional ``SpecAugment`` module, the
        ``Encoder`` and ``Decoder``, two output projections (vocabulary-sized
        and 4-way "switch"), the label-smoothing criteria, an optional ``CTC``
        module and an optional ``ErrorCalculator``, then calls
        ``initialize(self, init_type='xavier_normal')``.

        Args:
            num_time_mask: Forwarded to ``SpecAugment``.
            num_freq_mask: Forwarded to ``SpecAugment``.
            freq_mask_length: Forwarded to ``SpecAugment``.
            time_mask_length: Forwarded to ``SpecAugment``.
            feature_dim: ``idim`` of the encoder.
            model_size: ``attention_dim`` of encoder and decoder, input size
                of both linear projections and ``eprojs`` of the CTC module;
                stored as ``self.adim``.
            feed_forward_size: ``linear_units`` of encoder and decoder.
            hidden_size: Not referenced in this constructor.
            dropout: Dropout rate used for the encoder, the decoder (except
                its source attention, which is fixed to 0) and CTC.
            num_head: ``attention_heads`` of encoder and decoder.
            num_encoder_layer: ``num_blocks`` of the encoder.
            num_decoder_layer: ``num_blocks`` of the decoder.
            vocab_path: Path handed to ``Vocab``; the token list for WER
                reporting is read from the same path with ``'.model'``
                replaced by ``'.vocab'``.
            max_feature_length: ``max_sequence_length`` of ``SpecAugment``.
            max_token_length: Stored as ``self.max_token_length``.
            enable_spec_augment: Whether to create ``self.spec_augment``.
            share_weight: Not referenced in this constructor.
            smoothing: Smoothing value of the main criterion.
            restrict_left_length: Stored as ``self.restrict_left_length``.
            restrict_right_length: Stored as ``self.restrict_right_length``.
            mtlalpha: Stored as ``self.mtlalpha``; a CTC module is only
                created when it is greater than 0.
            report_wer: Whether to create an ``ErrorCalculator``.
        """
        super(Transformer, self).__init__()

        self.enable_spec_augment = enable_spec_augment
        self.max_token_length = max_token_length
        self.restrict_left_length = restrict_left_length
        self.restrict_right_length = restrict_right_length
        # Special-token ids and the output size all come from the vocabulary.
        self.vocab = Vocab(vocab_path)
        self.sos = self.vocab.bos_id
        self.eos = self.vocab.eos_id
        self.adim = model_size
        self.odim = self.vocab.vocab_size
        self.ignore_id = self.vocab.pad_id

        # NOTE(review): self.spec_augment is left undefined when
        # enable_spec_augment is false; callers must check the flag first.
        if enable_spec_augment:
            self.spec_augment = SpecAugment(
                num_time_mask=num_time_mask,
                num_freq_mask=num_freq_mask,
                freq_mask_length=freq_mask_length,
                time_mask_length=time_mask_length,
                max_sequence_length=max_feature_length)

        self.encoder = Encoder(idim=feature_dim,
                               attention_dim=model_size,
                               attention_heads=num_head,
                               linear_units=feed_forward_size,
                               num_blocks=num_encoder_layer,
                               dropout_rate=dropout,
                               positional_dropout_rate=dropout,
                               attention_dropout_rate=dropout,
                               input_layer='linear',
                               padding_idx=self.vocab.pad_id)

        # The decoder is built without its own output layer; the projections
        # below are separate modules on this model.
        self.decoder = Decoder(odim=self.vocab.vocab_size,
                               attention_dim=model_size,
                               attention_heads=num_head,
                               linear_units=feed_forward_size,
                               num_blocks=num_decoder_layer,
                               dropout_rate=dropout,
                               positional_dropout_rate=dropout,
                               self_attention_dropout_rate=dropout,
                               src_attention_dropout_rate=0,
                               input_layer='embed',
                               use_output_layer=False)
        # Vocabulary-sized projection and a 4-way "switch" projection.
        self.decoder_linear = t.nn.Linear(model_size,
                                          self.vocab.vocab_size,
                                          bias=True)
        self.decoder_switch_linear = t.nn.Linear(model_size, 4, bias=True)

        self.criterion = LabelSmoothingLoss(size=self.odim,
                                            smoothing=smoothing,
                                            padding_idx=self.vocab.pad_id,
                                            normalize_length=True)
        # NOTE(review): this 4-class criterion uses the vocabulary pad id as
        # its padding index, while self.switch_loss further down uses 0.
        # Confirm which one is intended.
        self.switch_criterion = LabelSmoothingLoss(
            size=4,
            smoothing=0,
            padding_idx=self.vocab.pad_id,
            normalize_length=True)
        self.mtlalpha = mtlalpha
        # The CTC module is only created when its weight is positive.
        if mtlalpha > 0.0:
            self.ctc = CTC(self.odim,
                           eprojs=self.adim,
                           dropout_rate=dropout,
                           ctc_type='builtin',
                           reduce=False)
        else:
            self.ctc = None

        if report_wer:
            from espnet.nets.e2e_asr_common import ErrorCalculator

            def load_token_list(path=vocab_path.replace('.model', '.vocab')):
                # Returns the first tab-separated field of every line in the file.
                # NOTE(review): the file is opened with the platform default
                # encoding; confirm that matches how the vocab file is written.
                with open(path) as reader:
                    data = reader.readlines()
                    data = [i.split('\t')[0] for i in data]
                return data

            self.char_list = load_token_list()
            self.error_calculator = ErrorCalculator(
                char_list=self.char_list,
                sym_space=' ',
                sym_blank=self.vocab.blank_token,
                report_wer=True)
        else:
            self.error_calculator = None
        self.rnnlm = None
        self.reporter = Reporter()

        # Second 4-class loss; unlike switch_criterion it uses padding_idx=0
        # and does not pass normalize_length.
        self.switch_loss = LabelSmoothingLoss(size=4,
                                              smoothing=0,
                                              padding_idx=0)
        # Progress prints around the Xavier-normal initialisation call.
        print('initing')
        initialize(self, init_type='xavier_normal')
        print('inited')