Example #1
def build_encoder(args):

    if args.enc_type == 'tds':
        from neural_sp.models.seq2seq.encoders.tds import TDSEncoder
        encoder = TDSEncoder(
            input_dim=args.input_dim * args.n_stacks,
            in_channel=args.conv_in_channel,
            channels=args.conv_channels,
            kernel_sizes=args.conv_kernel_sizes,
            dropout=args.dropout_enc,
            bottleneck_dim=args.transformer_d_model
            if 'transformer' in args.dec_type else args.dec_n_units)

    elif args.enc_type == 'gated_conv':
        from neural_sp.models.seq2seq.encoders.gated_conv import GatedConvEncoder
        encoder = GatedConvEncoder(
            input_dim=args.input_dim * args.n_stacks,
            in_channel=args.conv_in_channel,
            channels=args.conv_channels,
            kernel_sizes=args.conv_kernel_sizes,
            dropout=args.dropout_enc,
            bottleneck_dim=args.transformer_d_model
            if 'transformer' in args.dec_type else args.dec_n_units,
            param_init=args.param_init)

    elif 'transformer' in args.enc_type:
        from neural_sp.models.seq2seq.encoders.transformer import TransformerEncoder
        encoder = TransformerEncoder(
            input_dim=args.input_dim
            if args.input_type == 'speech' else args.emb_dim,
            enc_type=args.enc_type,
            attn_type=args.transformer_attn_type,
            n_heads=args.transformer_n_heads,
            n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1,
            n_layers_sub2=args.enc_n_layers_sub2,
            d_model=args.transformer_d_model,
            d_ff=args.transformer_d_ff,
            last_proj_dim=args.transformer_d_model
            if 'transformer' in args.dec_type else 0,
            pe_type=args.transformer_enc_pe_type,
            layer_norm_eps=args.transformer_layer_norm_eps,
            ffn_activation=args.transformer_ffn_activation,
            dropout_in=args.dropout_in,
            dropout=args.dropout_enc,
            dropout_att=args.dropout_att,
            dropout_layer=args.dropout_enc_layer,
            n_stacks=args.n_stacks,
            n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel,
            conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes,
            conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings,
            conv_batch_norm=args.conv_batch_norm,
            conv_layer_norm=args.conv_layer_norm,
            conv_bottleneck_dim=args.conv_bottleneck_dim,
            conv_param_init=args.param_init,
            task_specific_layer=args.task_specific_layer,
            param_init=args.transformer_param_init,
            chunk_size_left=args.lc_chunk_size_left,
            chunk_size_current=args.lc_chunk_size_current,
            chunk_size_right=args.lc_chunk_size_right)

    else:
        subsample = [1] * args.enc_n_layers
        for l, s in enumerate(
                list(map(int,
                         args.subsample.split('_')[:args.enc_n_layers]))):
            subsample[l] = s

        from neural_sp.models.seq2seq.encoders.rnn import RNNEncoder
        encoder = RNNEncoder(
            input_dim=args.input_dim
            if args.input_type == 'speech' else args.emb_dim,
            rnn_type=args.enc_type,
            n_units=args.enc_n_units,
            n_projs=args.enc_n_projs,
            last_proj_dim=args.transformer_d_model
            if 'transformer' in args.dec_type else 0,
            n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1,
            n_layers_sub2=args.enc_n_layers_sub2,
            dropout_in=args.dropout_in,
            dropout=args.dropout_enc,
            subsample=subsample,
            subsample_type=args.subsample_type,
            n_stacks=args.n_stacks,
            n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel,
            conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes,
            conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings,
            conv_batch_norm=args.conv_batch_norm,
            conv_layer_norm=args.conv_layer_norm,
            conv_bottleneck_dim=args.conv_bottleneck_dim,
            bidirectional_sum_fwd_bwd=args.bidirectional_sum_fwd_bwd,
            task_specific_layer=args.task_specific_layer,
            param_init=args.param_init,
            chunk_size_left=args.lc_chunk_size_left,
            chunk_size_right=args.lc_chunk_size_right)
        # NOTE: pure Conv/TDS/GatedConv encoders are also included

    return encoder
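
A minimal usage sketch (not part of the snippet above; every value is a hypothetical placeholder). build_encoder only reads attributes from args, so an argparse-style namespace carrying the fields referenced in the RNN branch is enough to exercise it; whether a given combination of values is actually valid is left to RNNEncoder's own checks.

from types import SimpleNamespace

args = SimpleNamespace(
    enc_type='blstm', input_type='speech', input_dim=80, emb_dim=None,
    dec_type='lstm', dec_n_units=1024, transformer_d_model=256,
    enc_n_units=512, enc_n_projs=0, enc_n_layers=5,
    enc_n_layers_sub1=0, enc_n_layers_sub2=0,
    dropout_in=0.0, dropout_enc=0.4,
    subsample='1_2_2_1_1', subsample_type='drop',
    n_stacks=1, n_splices=1,
    conv_in_channel=1, conv_channels='', conv_kernel_sizes='',
    conv_strides='', conv_poolings='',
    conv_batch_norm=False, conv_layer_norm=False, conv_bottleneck_dim=0,
    bidirectional_sum_fwd_bwd=False, task_specific_layer=False,
    param_init=0.1, lc_chunk_size_left=-1, lc_chunk_size_right=0)

encoder = build_encoder(args)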
Example #2
def build_encoder(args):

    # safeguard: backfill the newer per-module transformer hyperparameter names
    # from the older unified ones so that configs saved with previous versions still load
    if not hasattr(args, 'transformer_enc_d_model') and hasattr(args, 'transformer_d_model'):
        args.transformer_enc_d_model = args.transformer_d_model
        args.transformer_dec_d_model = args.transformer_d_model
    if not hasattr(args, 'transformer_enc_d_ff') and hasattr(args, 'transformer_d_ff'):
        args.transformer_enc_d_ff = args.transformer_d_ff
    if not hasattr(args, 'transformer_enc_n_heads') and hasattr(args, 'transformer_n_heads'):
        args.transformer_enc_n_heads = args.transformer_n_heads

    if args.enc_type == 'tds':
        from neural_sp.models.seq2seq.encoders.tds import TDSEncoder
        encoder = TDSEncoder(
            input_dim=args.input_dim * args.n_stacks,
            in_channel=args.conv_in_channel,
            channels=args.conv_channels,
            kernel_sizes=args.conv_kernel_sizes,
            dropout=args.dropout_enc,
            last_proj_dim=args.transformer_dec_d_model if 'transformer' in args.dec_type else args.dec_n_units)

    elif args.enc_type == 'gated_conv':
        from neural_sp.models.seq2seq.encoders.gated_conv import GatedConvEncoder
        encoder = GatedConvEncoder(
            input_dim=args.input_dim * args.n_stacks,
            in_channel=args.conv_in_channel,
            channels=args.conv_channels,
            kernel_sizes=args.conv_kernel_sizes,
            dropout=args.dropout_enc,
            last_proj_dim=args.transformer_dec_d_model if 'transformer' in args.dec_type else args.dec_n_units,
            param_init=args.param_init)

    elif 'transformer' in args.enc_type:
        from neural_sp.models.seq2seq.encoders.transformer import TransformerEncoder
        encoder = TransformerEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            enc_type=args.enc_type,
            n_heads=args.transformer_enc_n_heads,
            n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1,
            n_layers_sub2=args.enc_n_layers_sub2,
            d_model=args.transformer_enc_d_model,
            d_ff=args.transformer_enc_d_ff,
            ffn_bottleneck_dim=args.transformer_ffn_bottleneck_dim,
            ffn_activation=args.transformer_ffn_activation,
            pe_type=args.transformer_enc_pe_type,
            layer_norm_eps=args.transformer_layer_norm_eps,
            last_proj_dim=args.transformer_dec_d_model if 'transformer' in args.dec_type else 0,
            dropout_in=args.dropout_in,
            dropout=args.dropout_enc,
            dropout_att=args.dropout_att,
            dropout_layer=args.dropout_enc_layer,
            subsample=args.subsample,
            subsample_type=args.subsample_type,
            n_stacks=args.n_stacks,
            n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel,
            conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes,
            conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings,
            conv_batch_norm=args.conv_batch_norm,
            conv_layer_norm=args.conv_layer_norm,
            conv_bottleneck_dim=args.conv_bottleneck_dim,
            conv_param_init=args.param_init,
            task_specific_layer=args.task_specific_layer,
            param_init=args.transformer_param_init,
            clamp_len=args.transformer_enc_clamp_len,
            lookahead=args.transformer_enc_lookaheads,
            chunk_size_left=args.lc_chunk_size_left,
            chunk_size_current=args.lc_chunk_size_current,
            chunk_size_right=args.lc_chunk_size_right,
            streaming_type=args.lc_type)

    elif 'conformer' in args.enc_type:
        from neural_sp.models.seq2seq.encoders.conformer import ConformerEncoder
        encoder = ConformerEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            enc_type=args.enc_type,
            n_heads=args.transformer_enc_n_heads,
            kernel_size=args.conformer_kernel_size,
            n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1,
            n_layers_sub2=args.enc_n_layers_sub2,
            d_model=args.transformer_enc_d_model,
            d_ff=args.transformer_enc_d_ff,
            ffn_bottleneck_dim=args.transformer_ffn_bottleneck_dim,
            ffn_activation='swish',
            pe_type=args.transformer_enc_pe_type,
            layer_norm_eps=args.transformer_layer_norm_eps,
            last_proj_dim=args.transformer_dec_d_model if 'transformer' in args.dec_type else 0,
            dropout_in=args.dropout_in,
            dropout=args.dropout_enc,
            dropout_att=args.dropout_att,
            dropout_layer=args.dropout_enc_layer,
            subsample=args.subsample,
            subsample_type=args.subsample_type,
            n_stacks=args.n_stacks,
            n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel,
            conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes,
            conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings,
            conv_batch_norm=args.conv_batch_norm,
            conv_layer_norm=args.conv_layer_norm,
            conv_bottleneck_dim=args.conv_bottleneck_dim,
            conv_param_init=args.param_init,
            task_specific_layer=args.task_specific_layer,
            param_init=args.transformer_param_init,
            clamp_len=args.transformer_enc_clamp_len,
            lookahead=args.transformer_enc_lookaheads,
            chunk_size_left=args.lc_chunk_size_left,
            chunk_size_current=args.lc_chunk_size_current,
            chunk_size_right=args.lc_chunk_size_right,
            streaming_type=args.lc_type)

    else:
        from neural_sp.models.seq2seq.encoders.rnn import RNNEncoder
        encoder = RNNEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            enc_type=args.enc_type,
            n_units=args.enc_n_units,
            n_projs=args.enc_n_projs,
            last_proj_dim=args.transformer_dec_d_model if 'transformer' in args.dec_type else 0,
            n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1,
            n_layers_sub2=args.enc_n_layers_sub2,
            dropout_in=args.dropout_in,
            dropout=args.dropout_enc,
            subsample=args.subsample,
            subsample_type=args.subsample_type,
            n_stacks=args.n_stacks,
            n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel,
            conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes,
            conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings,
            conv_batch_norm=args.conv_batch_norm,
            conv_layer_norm=args.conv_layer_norm,
            conv_bottleneck_dim=args.conv_bottleneck_dim,
            bidir_sum_fwd_bwd=args.bidirectional_sum_fwd_bwd,
            task_specific_layer=args.task_specific_layer,
            param_init=args.param_init,
            chunk_size_left=args.lc_chunk_size_left,
            chunk_size_right=args.lc_chunk_size_right,
            rsp_prob=args.rsp_prob_enc)

    return encoder
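
A small sketch of what the safeguard block at the top of this example accomplishes (the values are hypothetical): configs written before the encoder/decoder transformer hyperparameters were split only carry the unified names, and the hasattr checks backfill the newer per-module names from them before the encoder is built.

from types import SimpleNamespace

old_args = SimpleNamespace(transformer_d_model=256,
                           transformer_d_ff=2048,
                           transformer_n_heads=4)

# Equivalent to the hasattr checks above: copy the old unified values
# over to the newer per-module attribute names if they are missing.
for old_name, new_name in [('transformer_d_model', 'transformer_enc_d_model'),
                           ('transformer_d_model', 'transformer_dec_d_model'),
                           ('transformer_d_ff', 'transformer_enc_d_ff'),
                           ('transformer_n_heads', 'transformer_enc_n_heads')]:
    if not hasattr(old_args, new_name) and hasattr(old_args, old_name):
        setattr(old_args, new_name, getattr(old_args, old_name))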
Example #3
def select_encoder(args):

    if 'transformer' in args.enc_type:
        encoder = TransformerEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            attn_type=args.transformer_attn_type,
            attn_n_heads=args.transformer_attn_n_heads,
            n_layers=args.enc_n_layers,
            d_model=args.d_model,
            d_ff=args.d_ff,
            pe_type=args.pe_type,
            layer_norm_eps=args.layer_norm_eps,
            dropout_in=args.dropout_in,
            dropout=args.dropout_enc,
            dropout_att=args.dropout_att,
            last_proj_dim=args.d_model if 'transformer' in args.dec_type else args.dec_n_units,
            n_stacks=args.n_stacks,
            n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel,
            conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes,
            conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings,
            conv_batch_norm=args.conv_batch_norm,
            conv_residual=args.conv_residual,
            conv_bottleneck_dim=args.conv_bottleneck_dim,
            param_init=args.param_init)
    else:
        subsample = [1] * args.enc_n_layers
        for l, s in enumerate(list(map(int, args.subsample.split('_')[:args.enc_n_layers]))):
            subsample[l] = s
        encoder = RNNEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            rnn_type=args.enc_type,
            n_units=args.enc_n_units,
            n_projs=args.enc_n_projs,
            n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1,
            n_layers_sub2=args.enc_n_layers_sub2,
            dropout_in=args.dropout_in,
            dropout=args.dropout_enc,
            subsample=subsample,
            subsample_type=args.subsample_type,
            last_proj_dim=args.d_model if 'transformer' in args.dec_type else args.dec_n_units,
            n_stacks=args.n_stacks,
            n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel,
            conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes,
            conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings,
            conv_batch_norm=args.conv_batch_norm,
            conv_residual=args.conv_residual,
            conv_bottleneck_dim=args.conv_bottleneck_dim,
            residual=args.enc_residual,
            nin=args.enc_nin,
            task_specific_layer=args.task_specific_layer,
            param_init=args.param_init)
        # NOTE: pure Conv/TDS/GatedConv encoders are also included

    return encoder
Example #4
    def __init__(self, args):

        super(ModelBase, self).__init__()

        # for encoder
        self.input_type = args.input_type
        self.input_dim = args.input_dim
        self.n_stacks = args.n_stacks
        self.n_skips = args.n_skips
        self.n_splices = args.n_splices
        self.enc_type = args.enc_type
        self.enc_n_units = args.enc_n_units
        if args.enc_type in ['blstm', 'bgru']:
            self.enc_n_units *= 2
        self.bridge_layer = args.bridge_layer

        # for OOV resolution
        self.enc_n_layers = args.enc_n_layers
        self.enc_n_layers_sub1 = args.enc_n_layers_sub1
        self.subsample = [int(s) for s in args.subsample.split('_')]

        # for attention layer
        self.attn_n_heads = args.attn_n_heads

        # for decoder
        self.vocab = args.vocab
        self.vocab_sub1 = args.vocab_sub1
        self.vocab_sub2 = args.vocab_sub2
        self.blank = 0
        self.unk = 1
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for the sub tasks
        self.main_weight = 1 - args.sub1_weight - args.sub2_weight
        self.sub1_weight = args.sub1_weight
        self.sub2_weight = args.sub2_weight
        self.mtl_per_batch = args.mtl_per_batch
        self.task_specific_layer = args.task_specific_layer

        # for CTC
        self.ctc_weight = min(args.ctc_weight, self.main_weight)
        self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight)
        self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight)

        # for backward decoder
        self.bwd_weight = min(args.bwd_weight, self.main_weight)
        self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight
        self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1
        self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2

        # Feature extraction
        self.ssn = None
        if args.sequence_summary_network:
            assert args.input_type == 'speech'
            self.ssn = SequenceSummaryNetwork(args.input_dim,
                                              n_units=512,
                                              n_layers=3,
                                              bottleneck_dim=100,
                                              dropout=0)

        # Encoder
        if args.enc_type == 'transformer':
            self.enc = TransformerEncoder(
                input_dim=args.input_dim
                if args.input_type == 'speech' else args.emb_dim,
                attn_type=args.transformer_attn_type,
                attn_n_heads=args.transformer_attn_n_heads,
                n_layers=args.transformer_enc_n_layers,
                d_model=args.d_model,
                d_ff=args.d_ff,
                # pe_type=args.pe_type,
                pe_type=False,
                dropout_in=args.dropout_in,
                dropout=args.dropout_enc,
                dropout_att=args.dropout_att,
                layer_norm_eps=args.layer_norm_eps,
                n_stacks=args.n_stacks,
                n_splices=args.n_splices,
                conv_in_channel=args.conv_in_channel,
                conv_channels=args.conv_channels,
                conv_kernel_sizes=args.conv_kernel_sizes,
                conv_strides=args.conv_strides,
                conv_poolings=args.conv_poolings,
                conv_batch_norm=args.conv_batch_norm,
                conv_residual=args.conv_residual,
                conv_bottleneck_dim=args.conv_bottleneck_dim)
        else:
            self.enc = RNNEncoder(
                input_dim=args.input_dim
                if args.input_type == 'speech' else args.emb_dim,
                rnn_type=args.enc_type,
                n_units=args.enc_n_units,
                n_projs=args.enc_n_projs,
                n_layers=args.enc_n_layers,
                n_layers_sub1=args.enc_n_layers_sub1,
                n_layers_sub2=args.enc_n_layers_sub2,
                dropout_in=args.dropout_in,
                dropout=args.dropout_enc,
                subsample=list(map(int, args.subsample.split('_'))) + [1] *
                (args.enc_n_layers - len(args.subsample.split('_'))),
                subsample_type=args.subsample_type,
                n_stacks=args.n_stacks,
                n_splices=args.n_splices,
                conv_in_channel=args.conv_in_channel,
                conv_channels=args.conv_channels,
                conv_kernel_sizes=args.conv_kernel_sizes,
                conv_strides=args.conv_strides,
                conv_poolings=args.conv_poolings,
                conv_batch_norm=args.conv_batch_norm,
                conv_residual=args.conv_residual,
                conv_bottleneck_dim=args.conv_bottleneck_dim,
                residual=args.enc_residual,
                nin=args.enc_nin,
                task_specific_layer=args.task_specific_layer)
            # NOTE: pure CNN/TDS encoders are also included

        if args.freeze_encoder:
            for p in self.enc.parameters():
                p.requires_grad = False

        # Bridge layer between the encoder and decoder
        self.is_bridge = False
        if (args.enc_type in ['conv', 'tds', 'gated_conv', 'transformer']
                and args.ctc_weight < 1
            ) or args.dec_type == 'transformer' or args.bridge_layer:
            self.bridge = LinearND(self.enc.output_dim,
                                   args.d_model if args.dec_type
                                   == 'transformer' else args.dec_n_units,
                                   dropout=args.dropout_enc)
            self.is_bridge = True
            if self.sub1_weight > 0:
                self.bridge_sub1 = LinearND(self.enc.output_dim,
                                            args.dec_n_units,
                                            dropout=args.dropout_enc)
            if self.sub2_weight > 0:
                self.bridge_sub2 = LinearND(self.enc.output_dim,
                                            args.dec_n_units,
                                            dropout=args.dropout_enc)
            self.enc_n_units = args.dec_n_units

        # main task
        directions = []
        if self.fwd_weight > 0 or self.ctc_weight > 0:
            directions.append('fwd')
        if self.bwd_weight > 0:
            directions.append('bwd')
        for dir in directions:
            # Cold fusion
            if args.lm_fusion and dir == 'fwd':
                lm = RNNLM(args.lm_conf)
                lm, _ = load_checkpoint(lm, args.lm_fusion)
            else:
                args.lm_conf = False
                lm = None
            # TODO(hirofumi): cold fusion for backward RNNLM

            # Decoder
            if args.dec_type == 'transformer':
                dec = TransformerDecoder(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    attn_type=args.transformer_attn_type,
                    attn_n_heads=args.transformer_attn_n_heads,
                    n_layers=args.transformer_dec_n_layers,
                    d_model=args.d_model,
                    d_ff=args.d_ff,
                    pe_type=args.pe_type,
                    tie_embedding=args.tie_embedding,
                    vocab=self.vocab,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    lsm_prob=args.lsm_prob,
                    layer_norm_eps=args.layer_norm_eps,
                    ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                    ctc_fc_list=[
                        int(fc) for fc in args.ctc_fc_list.split('_')
                    ] if args.ctc_fc_list is not None
                    and len(args.ctc_fc_list) > 0 else [],
                    backward=(dir == 'bwd'),
                    global_weight=self.main_weight -
                    self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                    mtl_per_batch=args.mtl_per_batch)
            else:
                dec = RNNDecoder(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    attn_type=args.attn_type,
                    attn_dim=args.attn_dim,
                    attn_sharpening_factor=args.attn_sharpening,
                    attn_sigmoid_smoothing=args.attn_sigmoid,
                    attn_conv_out_channels=args.attn_conv_n_channels,
                    attn_conv_kernel_size=args.attn_conv_width,
                    attn_n_heads=args.attn_n_heads,
                    rnn_type=args.dec_type,
                    n_units=args.dec_n_units,
                    n_projs=args.dec_n_projs,
                    n_layers=args.dec_n_layers,
                    loop_type=args.dec_loop_type,
                    residual=args.dec_residual,
                    bottleneck_dim=args.dec_bottleneck_dim,
                    emb_dim=args.emb_dim,
                    tie_embedding=args.tie_embedding,
                    vocab=self.vocab,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    ss_prob=args.ss_prob,
                    ss_type=args.ss_type,
                    lsm_prob=args.lsm_prob,
                    fl_weight=args.focal_loss_weight,
                    fl_gamma=args.focal_loss_gamma,
                    ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                    ctc_fc_list=[
                        int(fc) for fc in args.ctc_fc_list.split('_')
                    ] if args.ctc_fc_list is not None
                    and len(args.ctc_fc_list) > 0 else [],
                    input_feeding=args.input_feeding,
                    backward=(dir == 'bwd'),
                    # lm=args.lm_conf,
                    lm=lm,  # TODO(hirofumi): load RNNLM in the model init.
                    lm_fusion_type=args.lm_fusion_type,
                    contextualize=args.contextualize,
                    lm_init=args.lm_init,
                    lmobj_weight=args.lmobj_weight,
                    share_lm_softmax=args.share_lm_softmax,
                    global_weight=self.main_weight -
                    self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                    mtl_per_batch=args.mtl_per_batch,
                    adaptive_softmax=args.adaptive_softmax)
            setattr(self, 'dec_' + dir, dec)

        # sub task
        for sub in ['sub1', 'sub2']:
            if getattr(self, sub + '_weight') > 0:
                if args.dec_type == 'transformer':
                    raise NotImplementedError
                else:
                    dec_sub = RNNDecoder(
                        eos=self.eos,
                        unk=self.unk,
                        pad=self.pad,
                        blank=self.blank,
                        enc_n_units=self.enc_n_units,
                        attn_type=args.attn_type,
                        attn_dim=args.attn_dim,
                        attn_sharpening_factor=args.attn_sharpening,
                        attn_sigmoid_smoothing=args.attn_sigmoid,
                        attn_conv_out_channels=args.attn_conv_n_channels,
                        attn_conv_kernel_size=args.attn_conv_width,
                        attn_n_heads=1,
                        rnn_type=args.dec_type,
                        n_units=args.dec_n_units,
                        n_projs=args.dec_n_projs,
                        n_layers=args.dec_n_layers,
                        loop_type=args.dec_loop_type,
                        residual=args.dec_residual,
                        bottleneck_dim=args.dec_bottleneck_dim,
                        emb_dim=args.emb_dim,
                        tie_embedding=args.tie_embedding,
                        vocab=getattr(self, 'vocab_' + sub),
                        dropout=args.dropout_dec,
                        dropout_emb=args.dropout_emb,
                        dropout_att=args.dropout_att,
                        ss_prob=args.ss_prob,
                        ss_type=args.ss_type,
                        lsm_prob=args.lsm_prob,
                        fl_weight=args.focal_loss_weight,
                        fl_gamma=args.focal_loss_gamma,
                        ctc_weight=getattr(self, 'ctc_weight_' + sub),
                        ctc_fc_list=[
                            int(fc) for fc in getattr(args, 'ctc_fc_list_' +
                                                      sub).split('_')
                        ] if getattr(args, 'ctc_fc_list_' + sub) is not None
                        and len(getattr(args, 'ctc_fc_list_' + sub)) > 0 else
                        [],
                        input_feeding=args.input_feeding,
                        global_weight=getattr(self, sub + '_weight'),
                        mtl_per_batch=args.mtl_per_batch)
                setattr(self, 'dec_fwd_' + sub, dec_sub)

        if args.input_type == 'text':
            if args.vocab == args.vocab_sub1:
                # Share the embedding layer between input and output
                self.embed_in = dec.embed
            else:
                self.embed_in = Embedding(vocab=args.vocab_sub1,
                                          emb_dim=args.emb_dim,
                                          dropout=args.dropout_emb,
                                          ignore_index=self.pad)

        # Initialize parameters in CNN layers
        self.reset_parameters(
            args.param_init,
            #   dist='xavier_uniform',
            #   dist='kaiming_uniform',
            dist='lecun',
            keys=['conv'],
            ignore_keys=['score'])

        # Initialize parameters in the encoder
        if args.enc_type == 'transformer':
            self.reset_parameters(args.param_init,
                                  dist='xavier_uniform',
                                  keys=['enc'],
                                  ignore_keys=['embed_in'])
            self.reset_parameters(args.d_model**-0.5,
                                  dist='normal',
                                  keys=['embed_in'])
        else:
            self.reset_parameters(args.param_init,
                                  dist=args.param_init_dist,
                                  keys=['enc'],
                                  ignore_keys=['conv'])

        # Initialize parameters in the decoder
        if args.dec_type == 'transformer':
            self.reset_parameters(args.param_init,
                                  dist='xavier_uniform',
                                  keys=['dec'],
                                  ignore_keys=['embed'])
            self.reset_parameters(args.d_model**-0.5,
                                  dist='normal',
                                  keys=['embed'])
        else:
            self.reset_parameters(args.param_init,
                                  dist=args.param_init_dist,
                                  keys=['dec'])

        # Initialize bias vectors with zero
        self.reset_parameters(0, dist='constant', keys=['bias'])

        # Recurrent weights are orthogonalized
        if args.rec_weight_orthogonal:
            self.reset_parameters(args.param_init,
                                  dist='orthogonal',
                                  keys=['rnn', 'weight'])

        # Initialize bias in forget gate with 1
        # self.init_forget_gate_bias_with_one()

        # Initialize bias in gating with -1 for cold fusion
        if args.lm_fusion:
            self.reset_parameters(-1,
                                  dist='constant',
                                  keys=['linear_lm_gate.fc.bias'])

        if args.lm_fusion_type == 'deep' and args.lm_fusion:
            for n, p in self.named_parameters():
                if 'output' in n or 'output_bn' in n or 'linear' in n:
                    p.requires_grad = True
                else:
                    p.requires_grad = False
Example #5
class Seq2seq(ModelBase):
    """Attention-based RNN sequence-to-sequence model (including CTC)."""
    def __init__(self, args):

        super(ModelBase, self).__init__()

        # for encoder
        self.input_type = args.input_type
        self.input_dim = args.input_dim
        self.n_stacks = args.n_stacks
        self.n_skips = args.n_skips
        self.n_splices = args.n_splices
        self.enc_type = args.enc_type
        self.enc_n_units = args.enc_n_units
        if args.enc_type in ['blstm', 'bgru']:
            self.enc_n_units *= 2
        self.bridge_layer = args.bridge_layer

        # for OOV resolution
        self.enc_n_layers = args.enc_n_layers
        self.enc_n_layers_sub1 = args.enc_n_layers_sub1
        self.subsample = [int(s) for s in args.subsample.split('_')]

        # for attention layer
        self.attn_n_heads = args.attn_n_heads

        # for decoder
        self.vocab = args.vocab
        self.vocab_sub1 = args.vocab_sub1
        self.vocab_sub2 = args.vocab_sub2
        self.blank = 0
        self.unk = 1
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for the sub tasks
        self.main_weight = 1 - args.sub1_weight - args.sub2_weight
        self.sub1_weight = args.sub1_weight
        self.sub2_weight = args.sub2_weight
        self.mtl_per_batch = args.mtl_per_batch
        self.task_specific_layer = args.task_specific_layer

        # for CTC
        self.ctc_weight = min(args.ctc_weight, self.main_weight)
        self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight)
        self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight)

        # for backward decoder
        self.bwd_weight = min(args.bwd_weight, self.main_weight)
        self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight
        self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1
        self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2

        # Feature extraction
        self.ssn = None
        if args.sequence_summary_network:
            assert args.input_type == 'speech'
            self.ssn = SequenceSummaryNetwork(args.input_dim,
                                              n_units=512,
                                              n_layers=3,
                                              bottleneck_dim=100,
                                              dropout=0)

        # Encoder
        if args.enc_type == 'transformer':
            self.enc = TransformerEncoder(
                input_dim=args.input_dim
                if args.input_type == 'speech' else args.emb_dim,
                attn_type=args.transformer_attn_type,
                attn_n_heads=args.transformer_attn_n_heads,
                n_layers=args.transformer_enc_n_layers,
                d_model=args.d_model,
                d_ff=args.d_ff,
                # pe_type=args.pe_type,
                pe_type=False,
                dropout_in=args.dropout_in,
                dropout=args.dropout_enc,
                dropout_att=args.dropout_att,
                layer_norm_eps=args.layer_norm_eps,
                n_stacks=args.n_stacks,
                n_splices=args.n_splices,
                conv_in_channel=args.conv_in_channel,
                conv_channels=args.conv_channels,
                conv_kernel_sizes=args.conv_kernel_sizes,
                conv_strides=args.conv_strides,
                conv_poolings=args.conv_poolings,
                conv_batch_norm=args.conv_batch_norm,
                conv_residual=args.conv_residual,
                conv_bottleneck_dim=args.conv_bottleneck_dim)
        else:
            self.enc = RNNEncoder(
                input_dim=args.input_dim
                if args.input_type == 'speech' else args.emb_dim,
                rnn_type=args.enc_type,
                n_units=args.enc_n_units,
                n_projs=args.enc_n_projs,
                n_layers=args.enc_n_layers,
                n_layers_sub1=args.enc_n_layers_sub1,
                n_layers_sub2=args.enc_n_layers_sub2,
                dropout_in=args.dropout_in,
                dropout=args.dropout_enc,
                subsample=list(map(int, args.subsample.split('_'))) + [1] *
                (args.enc_n_layers - len(args.subsample.split('_'))),
                subsample_type=args.subsample_type,
                n_stacks=args.n_stacks,
                n_splices=args.n_splices,
                conv_in_channel=args.conv_in_channel,
                conv_channels=args.conv_channels,
                conv_kernel_sizes=args.conv_kernel_sizes,
                conv_strides=args.conv_strides,
                conv_poolings=args.conv_poolings,
                conv_batch_norm=args.conv_batch_norm,
                conv_residual=args.conv_residual,
                conv_bottleneck_dim=args.conv_bottleneck_dim,
                residual=args.enc_residual,
                nin=args.enc_nin,
                task_specific_layer=args.task_specific_layer)
            # NOTE: pure CNN/TDS encoders are also included

        if args.freeze_encoder:
            for p in self.enc.parameters():
                p.requires_grad = False

        # Bridge layer between the encoder and decoder
        self.is_bridge = False
        if (args.enc_type in ['conv', 'tds', 'gated_conv', 'transformer']
                and args.ctc_weight < 1
            ) or args.dec_type == 'transformer' or args.bridge_layer:
            self.bridge = LinearND(self.enc.output_dim,
                                   args.d_model if args.dec_type
                                   == 'transformer' else args.dec_n_units,
                                   dropout=args.dropout_enc)
            self.is_bridge = True
            if self.sub1_weight > 0:
                self.bridge_sub1 = LinearND(self.enc.output_dim,
                                            args.dec_n_units,
                                            dropout=args.dropout_enc)
            if self.sub2_weight > 0:
                self.bridge_sub2 = LinearND(self.enc.output_dim,
                                            args.dec_n_units,
                                            dropout=args.dropout_enc)
            self.enc_n_units = args.dec_n_units

        # main task
        directions = []
        if self.fwd_weight > 0 or self.ctc_weight > 0:
            directions.append('fwd')
        if self.bwd_weight > 0:
            directions.append('bwd')
        for dir in directions:
            # Cold fusion
            if args.lm_fusion and dir == 'fwd':
                lm = RNNLM(args.lm_conf)
                lm, _ = load_checkpoint(lm, args.lm_fusion)
            else:
                args.lm_conf = False
                lm = None
            # TODO(hirofumi): cold fusion for backward RNNLM

            # Decoder
            if args.dec_type == 'transformer':
                dec = TransformerDecoder(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    attn_type=args.transformer_attn_type,
                    attn_n_heads=args.transformer_attn_n_heads,
                    n_layers=args.transformer_dec_n_layers,
                    d_model=args.d_model,
                    d_ff=args.d_ff,
                    pe_type=args.pe_type,
                    tie_embedding=args.tie_embedding,
                    vocab=self.vocab,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    lsm_prob=args.lsm_prob,
                    layer_norm_eps=args.layer_norm_eps,
                    ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                    ctc_fc_list=[
                        int(fc) for fc in args.ctc_fc_list.split('_')
                    ] if args.ctc_fc_list is not None
                    and len(args.ctc_fc_list) > 0 else [],
                    backward=(dir == 'bwd'),
                    global_weight=self.main_weight -
                    self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                    mtl_per_batch=args.mtl_per_batch)
            else:
                dec = RNNDecoder(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    attn_type=args.attn_type,
                    attn_dim=args.attn_dim,
                    attn_sharpening_factor=args.attn_sharpening,
                    attn_sigmoid_smoothing=args.attn_sigmoid,
                    attn_conv_out_channels=args.attn_conv_n_channels,
                    attn_conv_kernel_size=args.attn_conv_width,
                    attn_n_heads=args.attn_n_heads,
                    rnn_type=args.dec_type,
                    n_units=args.dec_n_units,
                    n_projs=args.dec_n_projs,
                    n_layers=args.dec_n_layers,
                    loop_type=args.dec_loop_type,
                    residual=args.dec_residual,
                    bottleneck_dim=args.dec_bottleneck_dim,
                    emb_dim=args.emb_dim,
                    tie_embedding=args.tie_embedding,
                    vocab=self.vocab,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    ss_prob=args.ss_prob,
                    ss_type=args.ss_type,
                    lsm_prob=args.lsm_prob,
                    fl_weight=args.focal_loss_weight,
                    fl_gamma=args.focal_loss_gamma,
                    ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                    ctc_fc_list=[
                        int(fc) for fc in args.ctc_fc_list.split('_')
                    ] if args.ctc_fc_list is not None
                    and len(args.ctc_fc_list) > 0 else [],
                    input_feeding=args.input_feeding,
                    backward=(dir == 'bwd'),
                    # lm=args.lm_conf,
                    lm=lm,  # TODO(hirofumi): load RNNLM in the model init.
                    lm_fusion_type=args.lm_fusion_type,
                    contextualize=args.contextualize,
                    lm_init=args.lm_init,
                    lmobj_weight=args.lmobj_weight,
                    share_lm_softmax=args.share_lm_softmax,
                    global_weight=self.main_weight -
                    self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                    mtl_per_batch=args.mtl_per_batch,
                    adaptive_softmax=args.adaptive_softmax)
            setattr(self, 'dec_' + dir, dec)

        # sub task
        for sub in ['sub1', 'sub2']:
            if getattr(self, sub + '_weight') > 0:
                if args.dec_type == 'transformer':
                    raise NotImplementedError
                else:
                    dec_sub = RNNDecoder(
                        eos=self.eos,
                        unk=self.unk,
                        pad=self.pad,
                        blank=self.blank,
                        enc_n_units=self.enc_n_units,
                        attn_type=args.attn_type,
                        attn_dim=args.attn_dim,
                        attn_sharpening_factor=args.attn_sharpening,
                        attn_sigmoid_smoothing=args.attn_sigmoid,
                        attn_conv_out_channels=args.attn_conv_n_channels,
                        attn_conv_kernel_size=args.attn_conv_width,
                        attn_n_heads=1,
                        rnn_type=args.dec_type,
                        n_units=args.dec_n_units,
                        n_projs=args.dec_n_projs,
                        n_layers=args.dec_n_layers,
                        loop_type=args.dec_loop_type,
                        residual=args.dec_residual,
                        bottleneck_dim=args.dec_bottleneck_dim,
                        emb_dim=args.emb_dim,
                        tie_embedding=args.tie_embedding,
                        vocab=getattr(self, 'vocab_' + sub),
                        dropout=args.dropout_dec,
                        dropout_emb=args.dropout_emb,
                        dropout_att=args.dropout_att,
                        ss_prob=args.ss_prob,
                        ss_type=args.ss_type,
                        lsm_prob=args.lsm_prob,
                        fl_weight=args.focal_loss_weight,
                        fl_gamma=args.focal_loss_gamma,
                        ctc_weight=getattr(self, 'ctc_weight_' + sub),
                        ctc_fc_list=[
                            int(fc) for fc in getattr(args, 'ctc_fc_list_' +
                                                      sub).split('_')
                        ] if getattr(args, 'ctc_fc_list_' + sub) is not None
                        and len(getattr(args, 'ctc_fc_list_' + sub)) > 0 else
                        [],
                        input_feeding=args.input_feeding,
                        global_weight=getattr(self, sub + '_weight'),
                        mtl_per_batch=args.mtl_per_batch)
                setattr(self, 'dec_fwd_' + sub, dec_sub)

        if args.input_type == 'text':
            if args.vocab == args.vocab_sub1:
                # Share the embedding layer between input and output
                self.embed_in = dec.embed
            else:
                self.embed_in = Embedding(vocab=args.vocab_sub1,
                                          emb_dim=args.emb_dim,
                                          dropout=args.dropout_emb,
                                          ignore_index=self.pad)

        # Initialize parameters in CNN layers
        self.reset_parameters(
            args.param_init,
            #   dist='xavier_uniform',
            #   dist='kaiming_uniform',
            dist='lecun',
            keys=['conv'],
            ignore_keys=['score'])

        # Initialize parameters in the encoder
        if args.enc_type == 'transformer':
            self.reset_parameters(args.param_init,
                                  dist='xavier_uniform',
                                  keys=['enc'],
                                  ignore_keys=['embed_in'])
            self.reset_parameters(args.d_model**-0.5,
                                  dist='normal',
                                  keys=['embed_in'])
        else:
            self.reset_parameters(args.param_init,
                                  dist=args.param_init_dist,
                                  keys=['enc'],
                                  ignore_keys=['conv'])

        # Initialize parameters in the decoder
        if args.dec_type == 'transformer':
            self.reset_parameters(args.param_init,
                                  dist='xavier_uniform',
                                  keys=['dec'],
                                  ignore_keys=['embed'])
            self.reset_parameters(args.d_model**-0.5,
                                  dist='normal',
                                  keys=['embed'])
        else:
            self.reset_parameters(args.param_init,
                                  dist=args.param_init_dist,
                                  keys=['dec'])

        # Initialize bias vectors with zero
        self.reset_parameters(0, dist='constant', keys=['bias'])

        # Recurrent weights are orthogonalized
        if args.rec_weight_orthogonal:
            self.reset_parameters(args.param_init,
                                  dist='orthogonal',
                                  keys=['rnn', 'weight'])

        # Initialize bias in forget gate with 1
        # self.init_forget_gate_bias_with_one()

        # Initialize bias in gating with -1 for cold fusion
        if args.lm_fusion:
            self.reset_parameters(-1,
                                  dist='constant',
                                  keys=['linear_lm_gate.fc.bias'])

        if args.lm_fusion_type == 'deep' and args.lm_fusion:
            for n, p in self.named_parameters():
                if 'output' in n or 'output_bn' in n or 'linear' in n:
                    p.requires_grad = True
                else:
                    p.requires_grad = False

    def scheduled_sampling_trigger(self):
        # main task
        directions = []
        if self.fwd_weight > 0:
            directions.append('fwd')
        if self.bwd_weight > 0:
            directions.append('bwd')
        for dir in directions:
            getattr(self, 'dec_' + dir).start_scheduled_sampling()

        # sub task
        for sub in ['sub1', 'sub2']:
            if getattr(self, sub + '_weight') > 0:
                directions = []
                if getattr(self, 'fwd_weight_' + sub) > 0:
                    directions.append('fwd')
                for dir_sub in directions:
                    getattr(self, 'dec_' + dir_sub + '_' +
                            sub).start_scheduled_sampling()

    def forward(self, batch, reporter=None, task='all', is_eval=False):
        """Forward computation.

        Args:
            batch (dict):
                xs (list): input data of size `[T, input_dim]`
                xlens (list): lengths of each element in xs
                ys (list): reference labels in the main task of size `[L]`
                ys_sub1 (list): reference labels in the 1st auxiliary task of size `[L_sub1]`
                ys_sub2 (list): reference labels in the 2nd auxiliary task of size `[L_sub2]`
                utt_ids (list): name of utterances
                speakers (list): name of speakers
            reporter ():
            task (str): all or ys* or ys_sub*
            is_eval (bool): if True, evaluation mode is used and the computation
                history (gradients) is not saved. This should be used during
                inference for memory efficiency.
        Returns:
            loss (FloatTensor): `[1]`
            reporter ():

        """
        if is_eval:
            self.eval()
            with torch.no_grad():
                loss, reporter = self._forward(batch, task, reporter)
        else:
            self.train()
            loss, reporter = self._forward(batch, task, reporter)

        return loss, reporter
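
    # Usage sketch (an assumption, not part of neural_sp): given a `batch` dict as
    # described in the docstring above, a Seq2seq instance `model` and an optimizer
    # over its parameters, one training step would look like:
    #     loss, reporter = model(batch, reporter=reporter, task='all')
    #     optimizer.zero_grad()
    #     loss.backward()
    #     optimizer.step()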

    def _forward(self, batch, task, reporter):
        # Encode input features
        if self.input_type == 'speech':
            if self.mtl_per_batch:
                flip = 'bwd' in task
                enc_outs = self.encode(batch['xs'], task, flip=flip)
            else:
                flip = (self.bwd_weight == 1)
                enc_outs = self.encode(batch['xs'], 'all', flip=flip)
        else:
            enc_outs = self.encode(batch['ys_sub1'])

        observation = {}
        loss = torch.zeros((1, ), dtype=torch.float32).cuda(self.device_id)

        # for the forward decoder in the main task
        if (self.fwd_weight > 0 or self.ctc_weight > 0) and task in [
                'all', 'ys', 'ys.ctc', 'ys.lmobj'
        ]:
            loss_fwd, obs_fwd = self.dec_fwd(enc_outs['ys']['xs'],
                                             enc_outs['ys']['xlens'],
                                             batch['ys'], task,
                                             batch['ys_hist'])
            loss += loss_fwd
            observation['loss.att'] = obs_fwd['loss_att']
            observation['loss.ctc'] = obs_fwd['loss_ctc']
            observation['loss.lmobj'] = obs_fwd['loss_lmobj']
            observation['acc.att'] = obs_fwd['acc_att']
            observation['acc.lmobj'] = obs_fwd['acc_lmobj']
            observation['ppl.att'] = obs_fwd['ppl_att']
            observation['ppl.lmobj'] = obs_fwd['ppl_lmobj']

        # for the backward decoder in the main task
        if self.bwd_weight > 0 and task in ['all', 'ys.bwd']:
            loss_bwd, obs_bwd = self.dec_bwd(enc_outs['ys']['xs'],
                                             enc_outs['ys']['xlens'],
                                             batch['ys'], task)
            loss += loss_bwd
            observation['loss.att-bwd'] = obs_bwd['loss_att']
            observation['loss.ctc-bwd'] = obs_bwd['loss_ctc']
            observation['loss.lmobj-bwd'] = obs_bwd['loss_lmobj']
            observation['acc.att-bwd'] = obs_bwd['acc_att']
            observation['acc.lmobj-bwd'] = obs_bwd['acc_lmobj']
            observation['ppl.att-bwd'] = obs_bwd['ppl_att']
            observation['ppl.lmobj-bwd'] = obs_bwd['ppl_lmobj']

        # only fwd for sub tasks
        for sub in ['sub1', 'sub2']:
            # for the forward decoder in the sub tasks
            if (getattr(self, 'fwd_weight_' + sub) > 0
                    or getattr(self, 'ctc_weight_' + sub) > 0) and task in [
                        'all', 'ys_' + sub, 'ys_' + sub + '.ctc',
                        'ys_' + sub + '.lmobj'
                    ]:
                loss_sub, obs_fwd_sub = getattr(self, 'dec_fwd_' + sub)(
                    enc_outs['ys_' + sub]['xs'],
                    enc_outs['ys_' + sub]['xlens'], batch['ys_' + sub], task)
                loss += loss_sub
                observation['loss.att-' + sub] = obs_fwd_sub['loss_att']
                observation['loss.ctc-' + sub] = obs_fwd_sub['loss_ctc']
                observation['loss.lmobj-' + sub] = obs_fwd_sub['loss_lmobj']
                observation['acc.att-' + sub] = obs_fwd_sub['acc_att']
                observation['acc.lmobj-' + sub] = obs_fwd_sub['acc_lmobj']
                observation['ppl.att-' + sub] = obs_fwd_sub['ppl_att']
                observation['ppl.lmobj-' + sub] = obs_fwd_sub['ppl_lmobj']

        if reporter is not None:
            is_eval = not self.training
            reporter.add(observation, is_eval)

        return loss, reporter

    def encode(self, xs, task='all', flip=False):
        """Encode acoustic or text features.

        Args:
            xs (list): A list of length `[B]`, which contains Tensor of size `[T, input_dim]`
            task (str): all or ys* or ys_sub1* or ys_sub2*
            flip (bool): if True, flip acoustic features in the time-dimension
        Returns:
            enc_outs (dict): dictionary with keys 'ys', 'ys_sub1' and 'ys_sub2',
                each mapping to a dict holding the encoder outputs 'xs' and
                their lengths 'xlens'

        """
        if 'lmobj' in task:
            eouts = {
                'ys': {
                    'xs': None,
                    'xlens': None
                },
                'ys_sub1': {
                    'xs': None,
                    'xlens': None
                },
                'ys_sub2': {
                    'xs': None,
                    'xlens': None
                }
            }
            return eouts
        else:
            if self.input_type == 'speech':
                # Frame stacking
                if self.n_stacks > 1:
                    xs = [
                        stack_frame(x, self.n_stacks, self.n_skips) for x in xs
                    ]

                # Splicing
                if self.n_splices > 1:
                    xs = [splice(x, self.n_splices, self.n_stacks) for x in xs]

                xlens = [len(x) for x in xs]
                # Reverse acoustic features along the time axis
                if flip:
                    xs = [
                        torch.from_numpy(np.flip(
                            x, axis=0).copy()).float().cuda(self.device_id)
                        for x in xs
                    ]
                else:
                    xs = [np2tensor(x, self.device_id).float() for x in xs]
                xs = pad_list(xs, 0.0)

            elif self.input_type == 'text':
                xlens = [len(x) for x in xs]
                xs = [
                    np2tensor(np.fromiter(x, dtype=np.int64),
                              self.device_id).long() for x in xs
                ]
                xs = pad_list(xs, self.pad)
                xs = self.embed_in(xs)

            # sequence summary network
            if self.ssn is not None:
                xs += self.ssn(xs, xlens)

            # encoder
            enc_outs = self.enc(xs, xlens, task.split('.')[0])

            if self.main_weight < 1 and self.enc_type in [
                    'conv', 'tds', 'gated_conv', 'transformer'
            ]:
                for sub in ['sub1', 'sub2']:
                    enc_outs['ys_' + sub]['xs'] = enc_outs['ys']['xs'].clone()
                    enc_outs['ys_' + sub]['xlens'] = enc_outs['ys']['xlens'][:]

            # Bridge between the encoder and decoder
            if self.main_weight > 0 and self.is_bridge:
                enc_outs['ys']['xs'] = self.bridge(enc_outs['ys']['xs'])
            if self.sub1_weight > 0 and self.is_bridge:
                enc_outs['ys_sub1']['xs'] = self.bridge_sub1(
                    enc_outs['ys_sub1']['xs'])
            if self.sub2_weight > 0 and self.is_bridge:
                enc_outs['ys_sub2']['xs'] = self.bridge_sub2(
                    enc_outs['ys_sub2']['xs'])

            return enc_outs

    def get_ctc_probs(self, xs, task='ys', temperature=1, topk=None):
        self.eval()
        with torch.no_grad():
            enc_outs = self.encode(xs, task)
            dir = 'fwd' if self.fwd_weight >= self.bwd_weight else 'bwd'
            if task == 'ys_sub1':
                dir += '_sub1'
            elif task == 'ys_sub2':
                dir += '_sub2'

            if task == 'ys':
                assert self.ctc_weight > 0
            elif task == 'ys_sub1':
                assert self.ctc_weight_sub1 > 0
            elif task == 'ys_sub2':
                assert self.ctc_weight_sub2 > 0
            ctc_probs, indices_topk = getattr(self,
                                              'dec_' + dir).ctc_probs_topk(
                                                  enc_outs[task]['xs'],
                                                  temperature, topk)
            return ctc_probs, indices_topk, enc_outs[task]['xlens']
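
    # Hypothetical call sketch for CTC posterior extraction on the main task
    # (`model` and `xs` as assumed above; `topk` limits the label indices
    # returned per frame):
    #
    #   ctc_probs, topk_ids, xlens = model.get_ctc_probs(xs, task='ys', topk=10)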

    def decode(self,
               xs,
               params,
               idx2token,
               nbest=1,
               exclude_eos=False,
               refs_id=None,
               refs_text=None,
               utt_ids=None,
               speakers=None,
               task='ys',
               ensemble_models=[]):
        """Decoding in the inference stage.

        Args:
            xs (list): A list of length `[B]`, which contains arrays of size `[T, input_dim]`
            params (dict): hyper-parameters for decoding
                beam_width (int): the size of beam
                min_len_ratio (float): minimum output-to-input length ratio
                max_len_ratio (float): maximum output-to-input length ratio
                len_penalty (float): length penalty
                cov_penalty (float): coverage penalty
                cov_threshold (float): threshold for coverage penalty
                lm_weight (float): the weight of RNNLM score
                resolving_unk (bool): not used (kept for compatibility)
                fwd_bwd_attention (bool): perform forward-backward attention decoding
            idx2token (callable): converter from token indices to text
            nbest (int): the number of hypotheses to return per utterance
            exclude_eos (bool): exclude <eos> from best_hyps_id
            refs_id (list): gold token IDs to compute log likelihood
            refs_text (list): gold transcriptions
            utt_ids (list): utterance IDs
            speakers (list): speaker IDs
            task (str): ys* or ys_sub1* or ys_sub2*
            ensemble_models (list): list of Seq2seq classes
        Returns:
            best_hyps_id (list): A list of length `[B]`, which contains arrays of size `[L]`
            aws (list): A list of length `[B]`, which contains arrays of size `[L, T, n_heads]`

        """
        self.eval()
        with torch.no_grad():
            if task.split('.')[0] == 'ys':
                dir = 'bwd' if self.bwd_weight > 0 and params[
                    'recog_bwd_attention'] else 'fwd'
            elif task.split('.')[0] == 'ys_sub1':
                dir = 'fwd_sub1'
            elif task.split('.')[0] == 'ys_sub2':
                dir = 'fwd_sub2'
            else:
                raise ValueError(task)

            # encode
            if self.input_type == 'speech' and self.mtl_per_batch and 'bwd' in dir:
                enc_outs = self.encode(xs, task, flip=True)
            else:
                enc_outs = self.encode(xs, task, flip=False)

            #########################
            # CTC
            #########################
            if (self.fwd_weight == 0 and self.bwd_weight == 0) or (
                    self.ctc_weight > 0 and params['recog_ctc_weight'] == 1):
                lm = None
                if params['recog_lm_weight'] > 0 and hasattr(
                        self, 'lm_fwd') and self.lm_fwd is not None:
                    lm = getattr(self, 'lm_' + dir)

                best_hyps_id = getattr(self, 'dec_' + dir).decode_ctc(
                    enc_outs[task]['xs'], enc_outs[task]['xlens'],
                    params['recog_beam_width'], lm, params['recog_lm_weight'])
                return best_hyps_id, None, (None, None)

            #########################
            # Attention
            #########################
            else:
                cache_info = (None, None)

                if params['recog_beam_width'] == 1 and not params[
                        'recog_fwd_bwd_attention']:
                    best_hyps_id, aws = getattr(self, 'dec_' + dir).greedy(
                        enc_outs[task]['xs'], enc_outs[task]['xlens'],
                        params['recog_max_len_ratio'], exclude_eos, idx2token,
                        refs_id, speakers, params['recog_oracle'])
                else:
                    assert params['recog_batch_size'] == 1

                    ctc_log_probs = None
                    if params['recog_ctc_weight'] > 0:
                        ctc_log_probs = self.dec_fwd.ctc_log_probs(
                            enc_outs[task]['xs'])

                    # forward-backward decoding
                    if params['recog_fwd_bwd_attention']:
                        # forward decoder
                        lm_fwd, lm_bwd = None, None
                        if params['recog_lm_weight'] > 0 and hasattr(
                                self, 'lm_fwd') and self.lm_fwd is not None:
                            lm_fwd = self.lm_fwd
                            if params['recog_reverse_lm_rescoring'] and hasattr(
                                    self,
                                    'lm_bwd') and self.lm_bwd is not None:
                                lm_bwd = self.lm_bwd

                        # ensemble (forward)
                        ensmbl_eouts_fwd = []
                        ensmbl_elens_fwd = []
                        ensmbl_decs_fwd = []
                        if len(ensemble_models) > 0:
                            for i_e, model in enumerate(ensemble_models):
                                enc_outs_e_fwd = model.encode(xs,
                                                              task,
                                                              flip=False)
                                ensmbl_eouts_fwd += [
                                    enc_outs_e_fwd[task]['xs']
                                ]
                                ensmbl_elens_fwd += [
                                    enc_outs_e_fwd[task]['xlens']
                                ]
                                ensmbl_decs_fwd += [model.dec_fwd]
                                # NOTE: only the main task is supported for now

                        nbest_hyps_id_fwd, aws_fwd, scores_fwd, cache_info = self.dec_fwd.beam_search(
                            enc_outs[task]['xs'], enc_outs[task]['xlens'],
                            params, idx2token, lm_fwd, lm_bwd, ctc_log_probs,
                            params['recog_beam_width'], False, refs_id,
                            utt_ids, speakers, ensmbl_eouts_fwd,
                            ensmbl_elens_fwd, ensmbl_decs_fwd)

                        # backward decoder
                        lm_bwd, lm_fwd = None, None
                        if params['recog_lm_weight'] > 0 and hasattr(
                                self, 'lm_bwd') and self.lm_bwd is not None:
                            lm_bwd = self.lm_bwd
                            if params['recog_reverse_lm_rescoring'] and hasattr(
                                    self,
                                    'lm_fwd') and self.lm_fwd is not None:
                                lm_fwd = self.lm_fwd

                        # ensemble (backward)
                        ensmbl_eouts_bwd = []
                        ensmbl_elens_bwd = []
                        ensmbl_decs_bwd = []
                        if len(ensemble_models) > 0:
                            for i_e, model in enumerate(ensemble_models):
                                if self.input_type == 'speech' and self.mtl_per_batch:
                                    enc_outs_e_bwd = model.encode(xs,
                                                                  task,
                                                                  flip=True)
                                else:
                                    enc_outs_e_bwd = model.encode(xs,
                                                                  task,
                                                                  flip=False)
                                ensmbl_eouts_bwd += [
                                    enc_outs_e_bwd[task]['xs']
                                ]
                                ensmbl_elens_bwd += [
                                    enc_outs_e_bwd[task]['xlens']
                                ]
                                ensmbl_decs_bwd += [model.dec_bwd]
                                # NOTE: only the main task is supported for now
                                # TODO(hirofumi): merge with the forward pass for efficiency

                        flip = False
                        if self.input_type == 'speech' and self.mtl_per_batch:
                            flip = True
                            enc_outs_bwd = self.encode(xs, task, flip=True)
                        else:
                            enc_outs_bwd = enc_outs
                        nbest_hyps_id_bwd, aws_bwd, scores_bwd, _ = self.dec_bwd.beam_search(
                            enc_outs_bwd[task]['xs'], enc_outs[task]['xlens'],
                            params, idx2token, lm_bwd, lm_fwd, ctc_log_probs,
                            params['recog_beam_width'], False, refs_id,
                            utt_ids, speakers, ensmbl_eouts_bwd,
                            ensmbl_elens_bwd, ensmbl_decs_bwd)

                        # forward-backward attention
                        best_hyps_id = fwd_bwd_attention(
                            nbest_hyps_id_fwd, aws_fwd, scores_fwd,
                            nbest_hyps_id_bwd, aws_bwd, scores_bwd, flip,
                            self.eos, params['recog_gnmt_decoding'],
                            params['recog_length_penalty'], idx2token, refs_id)
                        aws = None
                    else:
                        # ensemble
                        ensmbl_eouts = []
                        ensmbl_elens = []
                        ensmbl_decs = []
                        if len(ensemble_models) > 0:
                            for i_e, model in enumerate(ensemble_models):
                                if model.input_type == 'speech' and model.mtl_per_batch and 'bwd' in dir:
                                    enc_outs_e = model.encode(xs,
                                                              task,
                                                              flip=True)
                                else:
                                    enc_outs_e = model.encode(xs,
                                                              task,
                                                              flip=False)
                                ensmbl_eouts += [enc_outs_e[task]['xs']]
                                ensmbl_elens += [enc_outs_e[task]['xlens']]
                                ensmbl_decs += [getattr(model, 'dec_' + dir)]
                                # NOTE: only the main task is supported for now

                        lm, lm_rev = None, None
                        if params['recog_lm_weight'] > 0 and hasattr(
                                self, 'lm_' + dir) and getattr(
                                    self, 'lm_' + dir) is not None:
                            lm = getattr(self, 'lm_' + dir)
                            if params['recog_reverse_lm_rescoring']:
                                if dir == 'fwd':
                                    lm_rev = self.lm_bwd
                                else:
                                    raise NotImplementedError

                        nbest_hyps_id, aws, scores, cache_info = getattr(
                            self, 'dec_' + dir).beam_search(
                                enc_outs[task]['xs'], enc_outs[task]['xlens'],
                                params, idx2token, lm, lm_rev, ctc_log_probs,
                                nbest, exclude_eos, refs_id, utt_ids, speakers,
                                ensmbl_eouts, ensmbl_elens, ensmbl_decs)

                        if nbest == 1:
                            best_hyps_id = [hyp[0] for hyp in nbest_hyps_id]
                            aws = [aw[0] for aw in aws]
                        else:
                            return nbest_hyps_id, aws, scores, cache_info
                        # NOTE: nbest >= 2 is used for MWER training only

                return best_hyps_id, aws, cache_info
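
A hedged end-to-end usage sketch for the snippet above (names such as `model`, `xs`, and `idx2token` are assumptions, not part of the original code). With a beam width of 1 and forward-backward attention disabled, the greedy path is taken; otherwise beam search is used:

    best_hyps_id, aws, cache_info = model.decode(xs, params, idx2token, nbest=1)
    hyp_text = idx2token(best_hyps_id[0])  # convert the best hypothesis to text

Here `params` is a `recog_*`-keyed dictionary like the one sketched in the comment at the top of the method.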
Example #6
    def __init__(self, args):

        super(ModelBase, self).__init__()

        # for encoder
        self.input_type = args.input_type
        assert args.input_type in ['speech', 'text']
        self.input_dim = args.input_dim
        self.num_stack = args.num_stack
        self.num_skip = args.num_skip
        self.num_splice = args.num_splice
        self.enc_type = args.enc_type
        self.enc_num_units = args.enc_num_units
        if args.enc_type in ['blstm', 'bgru']:
            self.enc_num_units *= 2

        # for attention layer
        self.att_num_heads_0 = args.att_num_heads
        self.att_num_heads_1 = args.att_num_heads_sub
        self.share_attention = False

        # for decoder
        self.num_classes = args.num_classes
        self.num_classes_sub = args.num_classes_sub
        self.blank = 0
        self.unk = 1
        self.sos = 2
        self.eos = 3
        self.pad = 4
        # NOTE: these are reserved in advance

        # for CTC
        self.ctc_weight_0 = args.ctc_weight
        self.ctc_weight_1 = args.ctc_weight_sub

        # for backward decoder
        assert 0 <= args.bwd_weight <= 1
        assert 0 <= args.bwd_weight_sub <= 1
        self.fwd_weight_0 = 1 - args.bwd_weight
        self.bwd_weight_0 = args.bwd_weight
        self.fwd_weight_1 = 1 - args.bwd_weight_sub
        self.bwd_weight_1 = args.bwd_weight_sub

        # for the sub task
        self.main_task_weight = args.main_task_weight

        # Encoder
        if args.enc_type in ['blstm', 'lstm', 'bgru', 'gru']:
            self.enc = RNNEncoder(
                input_dim=args.input_dim
                if args.input_type == 'speech' else args.emb_dim,
                rnn_type=args.enc_type,
                num_units=args.enc_num_units,
                num_projs=args.enc_num_projs,
                num_layers=args.enc_num_layers,
                num_layers_sub=args.enc_num_layers_sub,
                dropout_in=args.dropout_in,
                dropout_hidden=args.dropout_enc,
                subsample=args.subsample,
                subsample_type=args.subsample_type,
                batch_first=True,
                num_stack=args.num_stack,
                num_splice=args.num_splice,
                conv_in_channel=args.conv_in_channel,
                conv_channels=args.conv_channels,
                conv_kernel_sizes=args.conv_kernel_sizes,
                conv_strides=args.conv_strides,
                conv_poolings=args.conv_poolings,
                conv_batch_norm=args.conv_batch_norm,
                residual=args.enc_residual,
                nin=0,
                num_projs_final=args.dec_num_units if args.bridge_layer else 0)
        elif args.enc_type == 'cnn':
            assert args.num_stack == 1 and args.num_splice == 1
            self.enc = CNNEncoder(input_dim=args.input_dim if args.input_type
                                  == 'speech' else args.emb_dim,
                                  in_channel=args.conv_in_channel,
                                  channels=args.conv_channels,
                                  kernel_sizes=args.conv_kernel_sizes,
                                  strides=args.conv_strides,
                                  poolings=args.conv_poolings,
                                  dropout_in=args.dropout_in,
                                  dropout_hidden=args.dropout_enc,
                                  num_projs_final=args.dec_num_units,
                                  batch_norm=args.conv_batch_norm)
        else:
            raise NotImplementedError()

        # Bridge layer between the encoder and decoder
        if args.enc_type == 'cnn':
            self.enc_num_units = args.dec_num_units
        elif args.bridge_layer:
            self.bridge_0 = LinearND(self.enc_num_units, args.dec_num_units)
            self.enc_num_units = args.dec_num_units
        else:
            self.bridge_0 = lambda x: x
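            # NOTE: when no bridge layer is used, bridge_0 is the identity, so
            # callers can apply self.bridge_0(...) unconditionally.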

        directions = []
        if self.fwd_weight_0 > 0:
            directions.append('fwd')
        if self.bwd_weight_0 > 0:
            directions.append('bwd')
        for dir in directions:
            if args.ctc_weight < 1:
                # Attention layer
                if args.att_num_heads > 1:
                    attention = MultiheadAttentionMechanism(
                        enc_num_units=self.enc_num_units,
                        dec_num_units=args.dec_num_units,
                        att_type=args.att_type,
                        att_dim=args.att_dim,
                        sharpening_factor=args.att_sharpening_factor,
                        sigmoid_smoothing=args.att_sigmoid_smoothing,
                        conv_out_channels=args.att_conv_num_channels,
                        conv_kernel_size=args.att_conv_width,
                        num_heads=args.att_num_heads)
                else:
                    attention = AttentionMechanism(
                        enc_num_units=self.enc_num_units,
                        dec_num_units=args.dec_num_units,
                        att_type=args.att_type,
                        att_dim=args.att_dim,
                        sharpening_factor=args.att_sharpening_factor,
                        sigmoid_smoothing=args.att_sigmoid_smoothing,
                        conv_out_channels=args.att_conv_num_channels,
                        conv_kernel_size=args.att_conv_width)

                # Cold fusion
                # if args.rnnlm_cf is not None and dir == 'fwd':
                #     raise NotImplementedError()
                #     # TODO(hirofumi): cold fusion for backward RNNLM
                # else:
                #     args.rnnlm_cf = None
                #
                # # RNNLM initialization
                # if args.rnnlm_config_init is not None and dir == 'fwd':
                #     raise NotImplementedError()
                #     # TODO(hirofumi): RNNLM initialization for backward RNNLM
                # else:
                #     args.rnnlm_init = None
            else:
                attention = None

            # Decoder
            decoder = Decoder(
                attention=attention,
                sos=self.sos,
                eos=self.eos,
                pad=self.pad,
                enc_num_units=self.enc_num_units,
                rnn_type=args.dec_type,
                num_units=args.dec_num_units,
                num_layers=args.dec_num_layers,
                residual=args.dec_residual,
                emb_dim=args.emb_dim,
                num_classes=self.num_classes,
                logits_temp=args.logits_temp,
                dropout_dec=args.dropout_dec,
                dropout_emb=args.dropout_emb,
                ss_prob=args.ss_prob,
                lsm_prob=args.lsm_prob,
                lsm_type=args.lsm_type,
                init_with_enc=args.init_with_enc,
                ctc_weight=args.ctc_weight if dir == 'fwd' else 0,
                ctc_fc_list=args.ctc_fc_list,
                backward=(dir == 'bwd'),
                rnnlm_cf=args.rnnlm_cf,
                cold_fusion_type=args.cold_fusion_type,
                internal_lm=args.internal_lm,
                rnnlm_init=args.rnnlm_init,
                # rnnlm_weight=args.rnnlm_weight,
                share_softmax=args.share_softmax)
            setattr(self, 'dec_' + dir + '_0', decoder)

        # NOTE: fwd only for the sub task
        if args.main_task_weight < 1:
            if args.ctc_weight_sub < 1:
                # Attention layer
                if args.att_num_heads_sub > 1:
                    attention_sub = MultiheadAttentionMechanism(
                        enc_num_units=self.enc_num_units,
                        dec_num_units=args.dec_num_units,
                        att_type=args.att_type,
                        att_dim=args.att_dim,
                        sharpening_factor=args.att_sharpening_factor,
                        sigmoid_smoothing=args.att_sigmoid_smoothing,
                        conv_out_channels=args.att_conv_num_channels,
                        conv_kernel_size=args.att_conv_width,
                        num_heads=args.att_num_heads_sub)
                else:
                    attention_sub = AttentionMechanism(
                        enc_num_units=self.enc_num_units,
                        dec_num_units=args.dec_num_units,
                        att_type=args.att_type,
                        att_dim=args.att_dim,
                        sharpening_factor=args.att_sharpening_factor,
                        sigmoid_smoothing=args.att_sigmoid_smoothing,
                        conv_out_channels=args.att_conv_num_channels,
                        conv_kernel_size=args.att_conv_width)
            else:
                attention_sub = None

            # Decoder
            self.dec_fwd_1 = Decoder(attention=attention_sub,
                                     sos=self.sos,
                                     eos=self.eos,
                                     pad=self.pad,
                                     enc_num_units=self.enc_num_units,
                                     rnn_type=args.dec_type,
                                     num_units=args.dec_num_units,
                                     num_layers=args.dec_num_layers,
                                     residual=args.dec_residual,
                                     emb_dim=args.emb_dim,
                                     num_classes=self.num_classes_sub,
                                     logits_temp=args.logits_temp,
                                     dropout_dec=args.dropout_dec,
                                     dropout_emb=args.dropout_emb,
                                     ss_prob=args.ss_prob,
                                     lsm_prob=args.lsm_prob,
                                     lsm_type=args.lsm_type,
                                     init_with_enc=args.init_with_enc,
                                     ctc_weight=args.ctc_weight_sub,
                                     ctc_fc_list=args.ctc_fc_list)  # sub??

        if args.input_type == 'text':
            if args.num_classes == args.num_classes_sub:
                # Share the embedding layer between input and output
                self.embed_in = decoder.emb
            else:
                self.embed_in = Embedding(num_classes=args.num_classes_sub,
                                          emb_dim=args.emb_dim,
                                          dropout=args.dropout_emb,
                                          ignore_index=self.pad)

        # Initialize weight matrices
        self.init_weights(args.param_init,
                          dist=args.param_init_dist,
                          ignore_keys=['bias'])

        # Initialize all biases with 0
        self.init_weights(0, dist='constant', keys=['bias'])

        # Recurrent weights are orthogonalized
        if args.rec_weight_orthogonal:
            # encoder
            if args.enc_type != 'cnn':
                self.init_weights(args.param_init,
                                  dist='orthogonal',
                                  keys=[args.enc_type, 'weight'],
                                  ignore_keys=['bias'])
            # TODO(hirofumi): in case of CNN + LSTM
            # decoder
            self.init_weights(args.param_init,
                              dist='orthogonal',
                              keys=[args.dec_type, 'weight'],
                              ignore_keys=['bias'])

        # Initialize bias in forget gate with 1
        self.init_forget_gate_bias_with_one()

        # Initialize bias in gating with -1
        if args.rnnlm_cf is not None:
            self.init_weights(-1,
                              dist='constant',
                              keys=['cf_fc_lm_gate.fc.bias'])
Example #7
    def __init__(self, args):

        super(ModelBase, self).__init__()

        # for encoder
        self.input_type = args.input_type
        self.input_dim = args.input_dim
        self.nstacks = args.nstacks
        self.nskips = args.nskips
        self.nsplices = args.nsplices
        self.enc_type = args.enc_type
        self.enc_nunits = args.enc_nunits
        if args.enc_type in ['blstm', 'bgru']:
            self.enc_nunits *= 2
        self.bridge_layer = args.bridge_layer

        # for attention layer
        self.attn_nheads = args.attn_nheads

        # for decoder
        self.vocab = args.vocab
        self.vocab_sub1 = args.vocab_sub1
        self.vocab_sub2 = args.vocab_sub2
        self.blank = 0
        self.unk = 1
        self.sos = 2  # NOTE: the same index as <eos>
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for CTC
        self.ctc_weight = args.ctc_weight
        self.ctc_weight_sub1 = args.ctc_weight_sub1
        self.ctc_weight_sub2 = args.ctc_weight_sub2

        # for backward decoder
        self.fwd_weight = 1 - args.bwd_weight
        self.fwd_weight_sub1 = 1 - args.bwd_weight_sub1
        self.fwd_weight_sub2 = 1 - args.bwd_weight_sub2
        self.bwd_weight = args.bwd_weight
        self.bwd_weight_sub1 = args.bwd_weight_sub1
        self.bwd_weight_sub2 = args.bwd_weight_sub2

        # for the sub tasks
        self.main_weight = 1 - args.sub1_weight - args.sub2_weight
        self.sub1_weight = args.sub1_weight
        self.sub2_weight = args.sub2_weight
        self.mtl_per_batch = args.mtl_per_batch
        self.task_specific_layer = args.task_specific_layer

        # Setting for the CNN encoder
        if args.conv_poolings:
            conv_channels = [int(c) for c in args.conv_channels.split('_')] if len(args.conv_channels) > 0 else []
            conv_kernel_sizes = [[int(c.split(',')[0].replace('(', '')), int(c.split(',')[1].replace(')', ''))]
                                 for c in args.conv_kernel_sizes.split('_')] if len(args.conv_kernel_sizes) > 0 else []
            conv_strides = [[int(c.split(',')[0].replace('(', '')), int(c.split(',')[1].replace(')', ''))]
                            for c in args.conv_strides.split('_')] if len(args.conv_strides) > 0 else []
            conv_poolings = [[int(c.split(',')[0].replace('(', '')), int(c.split(',')[1].replace(')', ''))]
                             for c in args.conv_poolings.split('_')] if len(args.conv_poolings) > 0 else []
        else:
            conv_channels = []
            conv_kernel_sizes = []
            conv_strides = []
            conv_poolings = []
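
        # Worked example of the parsing above (assuming the CLI string format
        # used by this snippet): with args.conv_channels = '32_32' and
        # args.conv_kernel_sizes = '(3,3)_(3,3)', the result is
        #   conv_channels == [32, 32]
        #   conv_kernel_sizes == [[3, 3], [3, 3]]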

        # Encoder
        self.enc = RNNEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            rnn_type=args.enc_type,
            nunits=args.enc_nunits,
            nprojs=args.enc_nprojs,
            nlayers=args.enc_nlayers,
            nlayers_sub1=args.enc_nlayers_sub1,
            nlayers_sub2=args.enc_nlayers_sub2,
            dropout_in=args.dropout_in,
            dropout=args.dropout_enc,
            subsample=[int(s) for s in args.subsample.split('_')],
            subsample_type=args.subsample_type,
            nstacks=args.nstacks,
            nsplices=args.nsplices,
            conv_in_channel=args.conv_in_channel,
            conv_channels=conv_channels,
            conv_kernel_sizes=conv_kernel_sizes,
            conv_strides=conv_strides,
            conv_poolings=conv_poolings,
            conv_batch_norm=args.conv_batch_norm,
            residual=args.enc_residual,
            nin=0,
            layer_norm=args.layer_norm,
            task_specific_layer=args.task_specific_layer and args.ctc_weight > 0,
            task_specific_layer_sub1=args.task_specific_layer,
            task_specific_layer_sub2=args.task_specific_layer)

        # Bridge layer between the encoder and decoder
        if args.enc_type == 'cnn':
            self.bridge = LinearND(self.enc.conv.output_dim, args.dec_nunits,
                                   dropout=args.dropout_enc)
            if self.sub1_weight > 0:
                self.bridge_sub1 = LinearND(self.enc.conv.output_dim, args.dec_nunits,
                                            dropout=args.dropout_enc)
            if self.sub2_weight > 0:
                self.bridge_sub2 = LinearND(self.enc.conv.output_dim, args.dec_nunits,
                                            dropout=args.dropout_enc)
            self.enc_nunits = args.dec_nunits
        elif args.bridge_layer:
            self.bridge = LinearND(self.enc_nunits, args.dec_nunits,
                                   dropout=args.dropout_enc)
            if self.sub1_weight > 0:
                self.bridge_sub1 = LinearND(self.enc_nunits, args.dec_nunits,
                                            dropout=args.dropout_enc)
            if self.sub2_weight > 0:
                self.bridge_sub2 = LinearND(self.enc_nunits, args.dec_nunits,
                                            dropout=args.dropout_enc)
            self.enc_nunits = args.dec_nunits

        # MAIN TASK
        directions = []
        if self.fwd_weight > 0 or self.ctc_weight > 0:
            directions.append('fwd')
        if self.bwd_weight > 0:
            directions.append('bwd')
        for dir in directions:
            # Cold fusion
            if args.rnnlm_cold_fusion and dir == 'fwd':
                logger.info('cold fusion')
                raise NotImplementedError()
                # TODO(hirofumi): cold fusion for backward RNNLM
            else:
                args.rnnlm_cold_fusion = False

            # TODO(hirofumi): remove later
            if not hasattr(args, 'focal_loss_weight'):
                args.focal_loss_weight = 0.0
                args.focal_loss_gamma = 2.0
            if not hasattr(args, 'tie_embedding'):
                args.tie_embedding = False

            # Decoder
            dec = Decoder(
                sos=self.sos,
                eos=self.eos,
                pad=self.pad,
                enc_nunits=self.enc_nunits,
                attn_type=args.attn_type,
                attn_dim=args.attn_dim,
                attn_sharpening_factor=args.attn_sharpening,
                attn_sigmoid_smoothing=args.attn_sigmoid,
                attn_conv_out_channels=args.attn_conv_nchannels,
                attn_conv_kernel_size=args.attn_conv_width,
                attn_nheads=args.attn_nheads,
                dropout_att=args.dropout_att,
                rnn_type=args.dec_type,
                nunits=args.dec_nunits,
                nlayers=args.dec_nlayers,
                residual=args.dec_residual,
                emb_dim=args.emb_dim,
                tie_embedding=args.tie_embedding,
                vocab=self.vocab,
                logits_temp=args.logits_temp,
                dropout=args.dropout_dec,
                dropout_emb=args.dropout_emb,
                ss_prob=args.ss_prob,
                lsm_prob=args.lsm_prob,
                layer_norm=args.layer_norm,
                fl_weight=args.focal_loss_weight,
                fl_gamma=args.focal_loss_gamma,
                init_with_enc=args.init_with_enc,
                ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                ctc_fc_list=[int(fc) for fc in args.ctc_fc_list.split('_')] if len(args.ctc_fc_list) > 0 else [],
                input_feeding=args.input_feeding,
                backward=(dir == 'bwd'),
                rnnlm_cold_fusion=args.rnnlm_cold_fusion,
                cold_fusion=args.cold_fusion,
                internal_lm=args.internal_lm,
                rnnlm_init=args.rnnlm_init,
                lmobj_weight=args.lmobj_weight,
                share_lm_softmax=args.share_lm_softmax,
                global_weight=self.main_weight - self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                mtl_per_batch=args.mtl_per_batch,
                vocab_char=args.vocab_sub1)
            setattr(self, 'dec_' + dir, dec)

        # sub task (only for fwd)
        for sub in ['sub1', 'sub2']:
            if getattr(self, sub + '_weight') > 0:
                # Decoder
                dec_fwd_sub = Decoder(
                    sos=self.sos,
                    eos=self.eos,
                    pad=self.pad,
                    enc_nunits=self.enc_nunits,
                    attn_type=args.attn_type,
                    attn_dim=args.attn_dim,
                    attn_sharpening_factor=args.attn_sharpening,
                    attn_sigmoid_smoothing=args.attn_sigmoid,
                    attn_conv_out_channels=args.attn_conv_nchannels,
                    attn_conv_kernel_size=args.attn_conv_width,
                    attn_nheads=1,
                    dropout_att=args.dropout_att,
                    rnn_type=args.dec_type,
                    nunits=args.dec_nunits,
                    nlayers=args.dec_nlayers,
                    residual=args.dec_residual,
                    emb_dim=args.emb_dim,
                    tie_embedding=args.tie_embedding,
                    vocab=getattr(self, 'vocab_' + sub),
                    logits_temp=args.logits_temp,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    ss_prob=args.ss_prob,
                    lsm_prob=args.lsm_prob,
                    layer_norm=args.layer_norm,
                    fl_weight=args.focal_loss_weight,
                    fl_gamma=args.focal_loss_gamma,
                    init_with_enc=args.init_with_enc,
                    ctc_weight=getattr(self, 'ctc_weight_' + sub),
                    ctc_fc_list=[int(fc) for fc in getattr(args, 'ctc_fc_list_' + sub).split('_')
                                 ] if len(getattr(args, 'ctc_fc_list_' + sub)) > 0 else [],
                    input_feeding=args.input_feeding,
                    internal_lm=args.internal_lm,
                    lmobj_weight=getattr(args, 'lmobj_weight_' + sub),
                    share_lm_softmax=args.share_lm_softmax,
                    global_weight=getattr(self, sub + '_weight'),
                    mtl_per_batch=args.mtl_per_batch)
                setattr(self, 'dec_fwd_' + sub, dec_fwd_sub)

        if args.input_type == 'text':
            if args.vocab == args.vocab_sub1:
                # Share the embedding layer between input and output
                self.embed_in = dec.embed
            else:
                self.embed_in = Embedding(vocab=args.vocab_sub1,
                                          emb_dim=args.emb_dim,
                                          dropout=args.dropout_emb,
                                          ignore_index=self.pad)

        # Initialize weight matrices
        self.init_weights(args.param_init, dist=args.param_init_dist, ignore_keys=['bias'])

        # Initialize CNN layers like chainer
        self.init_weights(args.param_init, dist='lecun', keys=['conv'], ignore_keys=['score'])

        # Initialize all biases with 0
        self.init_weights(0, dist='constant', keys=['bias'])

        # Recurrent weights are orthogonalized
        if args.rec_weight_orthogonal:
            # encoder
            if args.enc_type != 'cnn':
                self.init_weights(args.param_init, dist='orthogonal',
                                  keys=[args.enc_type, 'weight'], ignore_keys=['bias'])
            # TODO(hirofumi): in case of CNN + LSTM
            # decoder
            self.init_weights(args.param_init, dist='orthogonal',
                              keys=[args.dec_type, 'weight'], ignore_keys=['bias'])

        # Initialize bias in forget gate with 1
        self.init_forget_gate_bias_with_one()

        # Initialize bias in gating with -1
        if args.rnnlm_cold_fusion:
            self.init_weights(-1, dist='constant', keys=['cf_linear_lm_gate.fc.bias'])
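
As a closing note on this snippet: the decoders registered with setattr above are later looked up by name, which is how decode() and get_ctc_probs() in the earlier example find them. A minimal retrieval sketch (attribute names follow the code above; `model` is a hypothetical instance):

    dec_main = getattr(model, 'dec_fwd')       # main task, forward decoder
    dec_sub1 = getattr(model, 'dec_fwd_sub1')  # sub task 1, present only if sub1_weight > 0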