Example #1
    def __init__(self, args, save_path=None):

        super(ModelBase, self).__init__()

        self.save_path = save_path

        # for encoder, decoder
        self.input_type = args.input_type
        self.input_dim = args.input_dim
        self.enc_type = args.enc_type
        self.enc_n_units = args.enc_n_units
        if args.enc_type in ['blstm', 'bgru', 'conv_blstm', 'conv_bgru']:
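            # bidirectional encoders concatenate both directions, doubling the output size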
            self.enc_n_units *= 2
        self.dec_type = args.dec_type

        # for OOV resolution
        self.enc_n_layers = args.enc_n_layers
        self.enc_n_layers_sub1 = args.enc_n_layers_sub1
        self.subsample = [int(s) for s in args.subsample.split('_')]

        # for decoder
        self.vocab = args.vocab
        self.vocab_sub1 = args.vocab_sub1
        self.vocab_sub2 = args.vocab_sub2
        self.blank = 0
        self.unk = 1
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for the sub tasks
        self.main_weight = 1 - args.sub1_weight - args.sub2_weight
        self.sub1_weight = args.sub1_weight
        self.sub2_weight = args.sub2_weight
        self.mtl_per_batch = args.mtl_per_batch
        self.task_specific_layer = args.task_specific_layer

        # for CTC
        self.ctc_weight = min(args.ctc_weight, self.main_weight)
        self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight)
        self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight)

        # for backward decoder
        self.bwd_weight = min(args.bwd_weight, self.main_weight)
        self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight
        self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1
        self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2

        # Feature extraction
        self.gaussian_noise = args.gaussian_noise
        self.n_stacks = args.n_stacks
        self.n_skips = args.n_skips
        self.n_splices = args.n_splices
        self.is_specaug = args.n_freq_masks > 0 or args.n_time_masks > 0
        self.specaug = None
        if self.is_specaug:
            assert args.n_stacks == 1 and args.n_skips == 1
            assert args.n_splices == 1
            self.specaug = SpecAugment(F=args.freq_width,
                                       T=args.time_width,
                                       n_freq_masks=args.n_freq_masks,
                                       n_time_masks=args.n_time_masks,
                                       p=args.time_width_upper)

        # Frontend
        self.ssn = None
        if args.sequence_summary_network:
            assert args.input_type == 'speech'
            self.ssn = SequenceSummaryNetwork(args.input_dim,
                                              n_units=512,
                                              n_layers=3,
                                              bottleneck_dim=100,
                                              dropout=0,
                                              param_init=args.param_init)

        # Encoder
        self.enc = build_encoder(args)
        if args.freeze_encoder:
            for p in self.enc.parameters():
                p.requires_grad = False

        # main task
        directions = []
        if self.fwd_weight > 0 or self.ctc_weight > 0:
            directions.append('fwd')
        if self.bwd_weight > 0:
            directions.append('bwd')
        for dir in directions:
            # Load the LM for LM fusion
            if args.lm_fusion and dir == 'fwd':
                lm_fusion = RNNLM(args.lm_conf)
                lm_fusion = load_checkpoint(lm_fusion, args.lm_fusion)[0]
            else:
                lm_fusion = None
                # TODO(hirofumi): for backward RNNLM

            # Load the LM for LM initialization
            if args.lm_init and dir == 'fwd':
                lm_init = RNNLM(args.lm_conf)
                lm_init = load_checkpoint(lm_init, args.lm_init)[0]
            else:
                lm_init = None
                # TODO(hirofumi): for backward RNNLM

            # Decoder
            if args.dec_type == 'transformer':
                dec = TransformerDecoder(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    attn_type=args.transformer_attn_type,
                    attn_n_heads=args.transformer_attn_n_heads,
                    n_layers=args.dec_n_layers,
                    d_model=args.d_model,
                    d_ff=args.d_ff,
                    vocab=self.vocab,
                    tie_embedding=args.tie_embedding,
                    pe_type=args.pe_type,
                    layer_norm_eps=args.layer_norm_eps,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    lsm_prob=args.lsm_prob,
                    focal_loss_weight=args.focal_loss_weight,
                    focal_loss_gamma=args.focal_loss_gamma,
                    ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                    ctc_lsm_prob=args.ctc_lsm_prob,
                    ctc_fc_list=([int(fc) for fc in args.ctc_fc_list.split('_')]
                                 if args.ctc_fc_list is not None and len(args.ctc_fc_list) > 0
                                 else []),
                    backward=(dir == 'bwd'),
                    global_weight=self.main_weight - self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                    mtl_per_batch=args.mtl_per_batch)
            elif 'transducer' in args.dec_type:
                dec = RNNTransducer(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    rnn_type=args.dec_type,
                    n_units=args.dec_n_units,
                    n_projs=args.dec_n_projs,
                    n_layers=args.dec_n_layers,
                    residual=args.dec_residual,
                    bottleneck_dim=args.dec_bottleneck_dim,
                    emb_dim=args.emb_dim,
                    vocab=self.vocab,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    lsm_prob=args.lsm_prob,
                    ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                    ctc_lsm_prob=args.ctc_lsm_prob,
                    ctc_fc_list=([int(fc) for fc in args.ctc_fc_list.split('_')]
                                 if args.ctc_fc_list is not None and len(args.ctc_fc_list) > 0
                                 else []),
                    lm_init=lm_init,
                    lmobj_weight=args.lmobj_weight,
                    share_lm_softmax=args.share_lm_softmax,
                    global_weight=self.main_weight - self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                    mtl_per_batch=args.mtl_per_batch,
                    param_init=args.param_init)
            else:
                dec = RNNDecoder(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    attn_type=args.attn_type,
                    attn_dim=args.attn_dim,
                    attn_sharpening_factor=args.attn_sharpening,
                    attn_sigmoid_smoothing=args.attn_sigmoid,
                    attn_conv_out_channels=args.attn_conv_n_channels,
                    attn_conv_kernel_size=args.attn_conv_width,
                    attn_n_heads=args.attn_n_heads,
                    rnn_type=args.dec_type,
                    n_units=args.dec_n_units,
                    n_projs=args.dec_n_projs,
                    n_layers=args.dec_n_layers,
                    loop_type=args.dec_loop_type,
                    residual=args.dec_residual,
                    bottleneck_dim=args.dec_bottleneck_dim,
                    emb_dim=args.emb_dim,
                    vocab=self.vocab,
                    tie_embedding=args.tie_embedding,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    zoneout=args.zoneout,
                    ss_prob=args.ss_prob,
                    ss_type=args.ss_type,
                    lsm_prob=args.lsm_prob,
                    focal_loss_weight=args.focal_loss_weight,
                    focal_loss_gamma=args.focal_loss_gamma,
                    ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                    ctc_lsm_prob=args.ctc_lsm_prob,
                    ctc_fc_list=([int(fc) for fc in args.ctc_fc_list.split('_')]
                                 if args.ctc_fc_list is not None and len(args.ctc_fc_list) > 0
                                 else []),
                    input_feeding=args.input_feeding,
                    backward=(dir == 'bwd'),
                    lm_fusion=lm_fusion,
                    lm_fusion_type=args.lm_fusion_type,
                    discourse_aware=args.discourse_aware,
                    lm_init=lm_init,
                    lmobj_weight=args.lmobj_weight,
                    share_lm_softmax=args.share_lm_softmax,
                    global_weight=self.main_weight - self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                    mtl_per_batch=args.mtl_per_batch,
                    adaptive_softmax=args.adaptive_softmax,
                    param_init=args.param_init,
                    mocha_chunk_size=args.mocha_chunk_size,
                    replace_sos=args.replace_sos,
                    soft_label_weight=args.soft_label_weight)
            setattr(self, 'dec_' + dir, dec)

        # sub task
        for sub in ['sub1', 'sub2']:
            if getattr(self, sub + '_weight') > 0:
                if args.dec_type == 'transformer':
                    raise NotImplementedError
                else:
                    dec_sub = RNNDecoder(
                        eos=self.eos,
                        unk=self.unk,
                        pad=self.pad,
                        blank=self.blank,
                        enc_n_units=self.enc_n_units,
                        attn_type=args.attn_type,
                        attn_dim=args.attn_dim,
                        attn_sharpening_factor=args.attn_sharpening,
                        attn_sigmoid_smoothing=args.attn_sigmoid,
                        attn_conv_out_channels=args.attn_conv_n_channels,
                        attn_conv_kernel_size=args.attn_conv_width,
                        attn_n_heads=1,
                        rnn_type=args.dec_type,
                        n_units=args.dec_n_units,
                        n_projs=args.dec_n_projs,
                        n_layers=args.dec_n_layers,
                        loop_type=args.dec_loop_type,
                        residual=args.dec_residual,
                        bottleneck_dim=args.dec_bottleneck_dim,
                        emb_dim=args.emb_dim,
                        tie_embedding=args.tie_embedding,
                        vocab=getattr(self, 'vocab_' + sub),
                        dropout=args.dropout_dec,
                        dropout_emb=args.dropout_emb,
                        dropout_att=args.dropout_att,
                        ss_prob=args.ss_prob,
                        ss_type=args.ss_type,
                        lsm_prob=args.lsm_prob,
                        focal_loss_weight=args.focal_loss_weight,
                        focal_loss_gamma=args.focal_loss_gamma,
                        ctc_weight=getattr(self, 'ctc_weight_' + sub),
                        ctc_lsm_prob=args.ctc_lsm_prob,
                        ctc_fc_list=([int(fc) for fc in getattr(args, 'ctc_fc_list_' + sub).split('_')]
                                     if getattr(args, 'ctc_fc_list_' + sub) is not None
                                     and len(getattr(args, 'ctc_fc_list_' + sub)) > 0
                                     else []),
                        input_feeding=args.input_feeding,
                        global_weight=getattr(self, sub + '_weight'),
                        mtl_per_batch=args.mtl_per_batch,
                        param_init=args.param_init)
                setattr(self, 'dec_fwd_' + sub, dec_sub)

        if args.input_type == 'text':
            if args.vocab == args.vocab_sub1:
                # Share the embedding layer between input and output
                self.embed = dec.embed
            else:
                self.embed = Embedding(vocab=args.vocab_sub1,
                                       emb_dim=args.emb_dim,
                                       dropout=args.dropout_emb,
                                       ignore_index=self.pad)

        # Recurrent weights are orthogonalized
        if args.rec_weight_orthogonal:
            self.reset_parameters(args.param_init, dist='orthogonal',
                                  keys=['rnn', 'weight'])

        # Initialize bias in forget gate with 1
        # self.init_forget_gate_bias_with_one()

        # Fix all parameters except for the gating parts in deep fusion
        if args.lm_fusion_type == 'deep' and args.lm_fusion:
            for n, p in self.named_parameters():
                if 'output' in n or 'output_bn' in n or 'linear' in n:
                    p.requires_grad = True
                else:
                    p.requires_grad = False
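
For reference, the task-weight bookkeeping above is a small budget calculation: the main task receives whatever the sub tasks do not claim, the CTC and backward-decoder weights are clipped to that budget, and the forward decoder gets the remainder. A minimal worked sketch with hypothetical values (0.2, 0.3, and 0.1 are illustrative, not defaults):

# Hypothetical task weights, for illustration only
sub1_weight, sub2_weight = 0.2, 0.0
ctc_weight, bwd_weight = 0.3, 0.1

main_weight = 1 - sub1_weight - sub2_weight         # 0.8
ctc_weight = min(ctc_weight, main_weight)           # 0.3
bwd_weight = min(bwd_weight, main_weight)           # 0.1
fwd_weight = main_weight - bwd_weight - ctc_weight  # ~0.4
# fwd + bwd + ctc + sub1 + sub2 together account for the full loss weight of 1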
Example #2
    def __init__(self, args, save_path=None, idx2token=None):

        super(ModelBase, self).__init__()

        self.save_path = save_path

        # for encoder, decoder
        self.input_type = args.input_type
        self.input_dim = args.input_dim
        self.enc_type = args.enc_type
        self.enc_n_units = args.enc_n_units
        if args.enc_type in ['blstm', 'bgru', 'conv_blstm', 'conv_bgru']:
            self.enc_n_units *= 2
        self.dec_type = args.dec_type

        # for OOV resolution
        self.enc_n_layers = args.enc_n_layers
        self.enc_n_layers_sub1 = args.enc_n_layers_sub1
        self.subsample = [int(s) for s in args.subsample.split('_')]

        # for decoder
        self.vocab = args.vocab
        self.vocab_sub1 = args.vocab_sub1
        self.vocab_sub2 = args.vocab_sub2
        self.blank = 0
        self.unk = 1
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for the sub tasks
        self.main_weight = 1 - args.sub1_weight - args.sub2_weight
        self.sub1_weight = args.sub1_weight
        self.sub2_weight = args.sub2_weight
        self.mtl_per_batch = args.mtl_per_batch
        self.task_specific_layer = args.task_specific_layer

        # for CTC
        self.ctc_weight = min(args.ctc_weight, self.main_weight)
        self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight)
        self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight)

        # for backward decoder
        self.bwd_weight = min(args.bwd_weight, self.main_weight)
        self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight
        self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1
        self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2

        # for MBR
        self.mbr_training = args.mbr_training
        self.recog_params = vars(args)
        self.idx2token = idx2token

        # Feature extraction
        self.gaussian_noise = args.gaussian_noise
        self.n_stacks = args.n_stacks
        self.n_skips = args.n_skips
        self.n_splices = args.n_splices
        self.use_specaug = args.n_freq_masks > 0 or args.n_time_masks > 0
        self.specaug = None
        self.flip_time_prob = args.flip_time_prob
        self.flip_freq_prob = args.flip_freq_prob
        self.weight_noise = args.weight_noise
        if self.use_specaug:
            assert args.n_stacks == 1 and args.n_skips == 1
            assert args.n_splices == 1
            self.specaug = SpecAugment(F=args.freq_width,
                                       T=args.time_width,
                                       n_freq_masks=args.n_freq_masks,
                                       n_time_masks=args.n_time_masks,
                                       p=args.time_width_upper)

        # Frontend
        self.ssn = None
        if args.sequence_summary_network:
            assert args.input_type == 'speech'
            self.ssn = SequenceSummaryNetwork(args.input_dim,
                                              n_units=512,
                                              n_layers=3,
                                              bottleneck_dim=100,
                                              dropout=0,
                                              param_init=args.param_init)

        # Encoder
        self.enc = build_encoder(args)
        if args.freeze_encoder:
            for p in self.enc.parameters():
                p.requires_grad = False

        # main task
        external_lm = None
        directions = []
        if self.fwd_weight > 0 or (self.bwd_weight == 0
                                   and self.ctc_weight > 0):
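            # NOTE: a CTC-only model still takes the 'fwd' path, since the CTC
            # branch is attached to the forward decoder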
            directions.append('fwd')
        if self.bwd_weight > 0:
            directions.append('bwd')
        for dir in directions:
            # Load the LM for LM fusion and decoder initialization
            if args.external_lm and dir == 'fwd':
                external_lm = RNNLM(args.lm_conf)
                load_checkpoint(external_lm, args.external_lm)
                # freeze LM parameters
                for n, p in external_lm.named_parameters():
                    p.requires_grad = False

            # Decoder
            special_symbols = {
                'blank': self.blank,
                'unk': self.unk,
                'eos': self.eos,
                'pad': self.pad,
            }
            dec = build_decoder(
                args, special_symbols, self.enc.output_dim, args.vocab,
                self.ctc_weight, args.ctc_fc_list,
                self.main_weight - self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                external_lm)
            setattr(self, 'dec_' + dir, dec)

        # sub task
        for sub in ['sub1', 'sub2']:
            if getattr(self, sub + '_weight') > 0:
                dec_sub = build_decoder(args, special_symbols,
                                        self.enc.output_dim,
                                        getattr(self, 'vocab_' + sub),
                                        getattr(self, 'ctc_weight_' + sub),
                                        getattr(args, 'ctc_fc_list_' + sub),
                                        getattr(self, sub + '_weight'),
                                        external_lm)
                setattr(self, 'dec_fwd_' + sub, dec_sub)

        if args.input_type == 'text':
            if args.vocab == args.vocab_sub1:
                # Share the embedding layer between input and output
                self.embed = dec.embed
            else:
                self.embed = nn.Embedding(args.vocab_sub1,
                                          args.emb_dim,
                                          padding_idx=self.pad)
                self.dropout_emb = nn.Dropout(p=args.dropout_emb)

        # Recurrent weights are orthogonalized
        if args.rec_weight_orthogonal:
            self.reset_parameters(args.param_init,
                                  dist='orthogonal',
                                  keys=['rnn', 'weight'])

        # Initialize bias in forget gate with 1
        # self.init_forget_gate_bias_with_one()

        # Fix all parameters except for the gating parts in deep fusion
        if args.lm_fusion == 'deep' and external_lm is not None:
            for n, p in self.named_parameters():
                if 'output' in n or 'output_bn' in n or 'linear' in n:
                    p.requires_grad = True
                else:
                    p.requires_grad = False
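
Both examples end by freezing every parameter except the gating parts used in deep LM fusion, selecting parameters purely by name. A self-contained sketch of that name-based freezing; the two-layer ModuleDict is a hypothetical stand-in for the real model:

import torch.nn as nn

# Hypothetical stand-in: a 'linear' gating layer plus an unrelated encoder layer
model = nn.ModuleDict({'linear_gate': nn.Linear(8, 8),
                       'enc_rnn': nn.Linear(8, 8)})

for n, p in model.named_parameters():
    # same name test as in the examples: only gating-related parts stay trainable
    p.requires_grad = 'output' in n or 'output_bn' in n or 'linear' in n

print([n for n, p in model.named_parameters() if p.requires_grad])
# ['linear_gate.weight', 'linear_gate.bias']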
Example #3
    def __init__(self, args, save_path=None, idx2token=None):

        super(ModelBase, self).__init__()

        self.save_path = save_path

        # for encoder, decoder
        self.input_type = args.input_type
        self.input_dim = args.input_dim
        self.enc_type = args.enc_type
        self.dec_type = args.dec_type

        # for OOV resolution
        self.enc_n_layers = args.enc_n_layers
        self.enc_n_layers_sub1 = args.enc_n_layers_sub1
        self.subsample = [int(s) for s in args.subsample.split('_')]

        # for decoder
        self.vocab = args.vocab
        self.vocab_sub1 = args.vocab_sub1
        self.vocab_sub2 = args.vocab_sub2
        self.blank = 0
        self.unk = 1
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for the sub tasks
        self.main_weight = args.total_weight - args.sub1_weight - args.sub2_weight
        self.sub1_weight = args.sub1_weight
        self.sub2_weight = args.sub2_weight
        self.mtl_per_batch = args.mtl_per_batch
        self.task_specific_layer = args.task_specific_layer

        # for CTC
        self.ctc_weight = min(args.ctc_weight, self.main_weight)
        self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight)
        self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight)

        # for backward decoder
        self.bwd_weight = min(args.bwd_weight, self.main_weight)
        self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight
        self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1
        self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2

        # for MBR
        self.mbr_training = args.mbr_training
        self.recog_params = vars(args)
        self.idx2token = idx2token

        # for discourse-aware model
        self.utt_id_prev = None

        # Feature extraction
        self.input_noise_std = args.input_noise_std
        self.n_stacks = args.n_stacks
        self.n_skips = args.n_skips
        self.n_splices = args.n_splices
        self.weight_noise_std = args.weight_noise_std
        self.specaug = None
        if args.n_freq_masks > 0 or args.n_time_masks > 0:
            assert args.n_stacks == 1 and args.n_skips == 1
            assert args.n_splices == 1
            self.specaug = SpecAugment(
                F=args.freq_width,
                T=args.time_width,
                n_freq_masks=args.n_freq_masks,
                n_time_masks=args.n_time_masks,
                p=args.time_width_upper,
                adaptive_number_ratio=args.adaptive_number_ratio,
                adaptive_size_ratio=args.adaptive_size_ratio,
                max_n_time_masks=args.max_n_time_masks)

        # Frontend
        self.ssn = None
        if args.sequence_summary_network:
            assert args.input_type == 'speech'
            self.ssn = SequenceSummaryNetwork(args.input_dim,
                                              n_units=512,
                                              n_layers=3,
                                              bottleneck_dim=100,
                                              dropout=0,
                                              param_init=args.param_init)

        # Encoder
        self.enc = build_encoder(args)
        if args.freeze_encoder:
            for n, p in self.enc.named_parameters():
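                # 'bridge' projections and the sub1 branch are skipped here,
                # so they remain trainable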
                if 'bridge' in n or 'sub1' in n:
                    continue
                p.requires_grad = False
                logger.info('freeze %s' % n)

        special_symbols = {
            'blank': self.blank,
            'unk': self.unk,
            'eos': self.eos,
            'pad': self.pad,
        }

        # main task
        external_lm = None
        directions = []
        if self.fwd_weight > 0 or (self.bwd_weight == 0
                                   and self.ctc_weight > 0):
            directions.append('fwd')
        if self.bwd_weight > 0:
            directions.append('bwd')

        for dir in directions:
            # Load the LM for LM fusion and decoder initialization
            if args.external_lm and dir == 'fwd':
                external_lm = RNNLM(args.lm_conf)
                load_checkpoint(args.external_lm, external_lm)
                # freeze LM parameters
                for n, p in external_lm.named_parameters():
                    p.requires_grad = False

            # Decoder
            dec = build_decoder(
                args, special_symbols, self.enc.output_dim, args.vocab,
                self.ctc_weight,
                self.main_weight - self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                external_lm)
            setattr(self, 'dec_' + dir, dec)

        # sub task
        for sub in ['sub1', 'sub2']:
            if getattr(self, sub + '_weight') > 0:
                args_sub = copy.deepcopy(args)
                if hasattr(args, 'dec_config_' + sub):
                    for k, v in getattr(args, 'dec_config_' + sub).items():
                        setattr(args_sub, k, v)
                # NOTE: Other parameters are the same as the main decoder
                dec_sub = build_decoder(args_sub, special_symbols,
                                        getattr(self.enc, 'output_dim_' + sub),
                                        getattr(self, 'vocab_' + sub),
                                        getattr(self, 'ctc_weight_' + sub),
                                        getattr(self, sub + '_weight'),
                                        external_lm)
                setattr(self, 'dec_fwd_' + sub, dec_sub)

        if args.input_type == 'text':
            if args.vocab == args.vocab_sub1:
                # Share the embedding layer between input and output
                self.embed = dec.embed
            else:
                self.embed = nn.Embedding(args.vocab_sub1,
                                          args.emb_dim,
                                          padding_idx=self.pad)
                self.dropout_emb = nn.Dropout(p=args.dropout_emb)

        # Initialize bias in forget gate with 1
        # self.init_forget_gate_bias_with_one()

        # Fix all parameters except for the gating parts in deep fusion
        if args.lm_fusion == 'deep' and external_lm is not None:
            for n, p in self.named_parameters():
                if 'output' in n or 'output_bn' in n or 'linear' in n:
                    p.requires_grad = True
                else:
                    p.requires_grad = False
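
Example #3 additionally lets each sub task override decoder hyperparameters through an optional dec_config_sub* mapping, deep-copying args so the main decoder is unaffected. A minimal illustration of that override pattern (SimpleNamespace and the field values are hypothetical stand-ins for the real args):

import copy
from types import SimpleNamespace

# Hypothetical args: sub1 shrinks its decoder to a single layer
args = SimpleNamespace(dec_n_units=1024, dec_n_layers=2,
                       dec_config_sub1={'dec_n_layers': 1})

args_sub = copy.deepcopy(args)
for k, v in getattr(args, 'dec_config_sub1').items():
    setattr(args_sub, k, v)

assert args_sub.dec_n_layers == 1    # overridden for the sub task
assert args_sub.dec_n_units == 1024  # everything else inherited
assert args.dec_n_layers == 2        # the main-task args are untouched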