Exemplo n.º 1
0
    def __init__(self,
                 eos,
                 unk,
                 pad,
                 blank,
                 enc_n_units,
                 attn_type,
                 attn_n_heads,
                 n_layers,
                 d_model,
                 d_ff,
                 vocab,
                 tie_embedding=False,
                 pe_type='add',
                 layer_norm_eps=1e-12,
                 dropout=0.0,
                 dropout_emb=0.0,
                 dropout_att=0.0,
                 lsm_prob=0.0,
                 focal_loss_weight=0.0,
                 focal_loss_gamma=2.0,
                 ctc_weight=0.0,
                 ctc_lsm_prob=0.0,
                 ctc_fc_list=None,
                 backward=False,
                 global_weight=1.0,
                 mtl_per_batch=False,
                 adaptive_softmax=False):
        """Build a Transformer decoder with an optional auxiliary CTC branch.

        Args:
            eos, unk, pad, blank (int): special token indices. ``pad`` is the
                embedding's ``ignore_index``; ``eos`` and ``blank`` are also
                forwarded to the CTC branch.
            enc_n_units (int): encoder output dimension given to the CTC branch.
            attn_type (str): attention type of every decoder block.
            attn_n_heads (int): number of attention heads per block.
            n_layers (int): number of ``TransformerDecoderBlock`` layers.
            d_model (int): embedding / block / output-projection width.
            d_ff (int): feed-forward width passed to each decoder block.
            vocab (int): output vocabulary size.
            tie_embedding (bool): share the output projection weight with the
                input embedding (only in the non-adaptive-softmax path).
            pe_type (str): positional-encoding type for ``PositionalEncoding``.
            layer_norm_eps (float): epsilon of the blocks' and final LayerNorm.
            dropout (float): dropout of the decoder blocks and the CTC branch.
            dropout_emb (float): dropout given to ``PositionalEncoding``.
            dropout_att (float): attention dropout given to each block.
            lsm_prob (float): stored as ``self.lsm_prob`` (presumably label
                smoothing; used outside this method).
            focal_loss_weight, focal_loss_gamma (float): stored for use elsewhere.
            ctc_weight (float): CTC branch weight; the branch is built only if > 0.
            ctc_lsm_prob (float): ``lsm_prob`` forwarded to the CTC branch.
            ctc_fc_list (list of int or None): ``fc_list`` forwarded to the CTC
                branch; ``None`` (the default) is treated as ``[]``.
            backward (bool): stored as ``self.bwd``.
            global_weight (float): stored; the attention decoder (embedding,
                blocks, output layer) is built only if ``ctc_weight < global_weight``.
            mtl_per_batch (bool): stored as ``self.mtl_per_batch``.
            adaptive_softmax (bool): use ``nn.AdaptiveLogSoftmaxWithLoss``
                instead of a plain linear output layer.
        """
        super(TransformerDecoder, self).__init__()

        # Avoid a shared mutable default argument.
        if ctc_fc_list is None:
            ctc_fc_list = []

        self.eos = eos
        self.unk = unk
        self.pad = pad
        self.blank = blank
        self.enc_n_units = enc_n_units
        self.vocab = vocab
        self.d_model = d_model
        self.n_layers = n_layers
        self.attn_n_heads = attn_n_heads
        self.pe_type = pe_type
        self.lsm_prob = lsm_prob
        self.focal_loss_weight = focal_loss_weight
        self.focal_loss_gamma = focal_loss_gamma
        self.ctc_weight = ctc_weight
        self.bwd = backward
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch

        if ctc_weight > 0:
            self.ctc = CTC(eos=eos,
                           blank=blank,
                           enc_n_units=enc_n_units,
                           vocab=vocab,
                           dropout=dropout,
                           lsm_prob=ctc_lsm_prob,
                           fc_list=ctc_fc_list,
                           param_init=0.1)

        if ctc_weight < global_weight:
            self.embed = Embedding(
                vocab,
                d_model,
                dropout=0,  # NOTE: do not apply dropout here
                ignore_index=pad)
            self.pos_enc = PositionalEncoding(d_model, dropout_emb, pe_type)
            self.layers = nn.ModuleList([
                TransformerDecoderBlock(d_model, d_ff, attn_type, attn_n_heads,
                                        dropout, dropout_att, layer_norm_eps)
                for _ in range(n_layers)
            ])
            self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)

            if adaptive_softmax:
                # Use the ``vocab`` argument directly: cutoffs are computed
                # from the vocabulary size passed to this constructor.
                head_size = round(vocab / 15)
                self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                    d_model,
                    vocab,
                    cutoffs=[head_size, 3 * head_size],
                    div_value=4.0)
                self.output = None
            else:
                self.adaptive_softmax = None
                self.output = Linear(d_model, vocab)

                # Optionally tie weights as in:
                # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
                # https://arxiv.org/abs/1608.05859
                # and
                # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
                # https://arxiv.org/abs/1611.01462
                # Both matrices are (vocab, d_model), so no size check is needed.
                if tie_embedding:
                    self.output.fc.weight = self.embed.embed.weight

        # Initialize parameters
        self.reset_parameters()
Exemplo n.º 2
0
    def __init__(self, args, save_path=None):
        """Build a gated-convolutional LM from stacked ``GLUBlock`` layers.

        The architecture is chosen by ``args.lm_type`` with its ``'gated_conv_'``
        prefix stripped. ``'custom'`` builds ``args.n_layers`` blocks of width
        ``args.n_units``. ``'8'``, ``'8B'``, ``'9'``, ``'13'``, ``'14'`` and
        ``'14B'`` build fixed-size stacks.

        Args:
            args: configuration namespace. Reads emb_dim, n_units, n_layers,
                n_projs, kernel_size, lsm_prob, vocab, lm_type, dropout_in,
                dropout_hidden, dropout_out, adaptive_softmax, tie_embedding
                and param_init.
            save_path (str, optional): stored as ``self.save_path``.

        Raises:
            NotImplementedError: the model size is not one of the above.
            ValueError: ``args.tie_embedding`` is set but the width of the
                last block (the output layer's input size) differs from
                ``args.emb_dim``, so the tied weight shapes would not match.
        """
        super(LMBase, self).__init__()
        logger = logging.getLogger('training')
        logger.info(self.__class__.__name__)

        self.save_path = save_path

        self.emb_dim = args.emb_dim
        self.n_units = args.n_units
        self.n_layers = args.n_layers
        self.lsm_prob = args.lsm_prob

        self.vocab = args.vocab
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for cache
        self.cache_theta = 0.2  # smoothing parameter
        self.cache_lambda = 0.2  # cache weight
        self.cache_ids = []
        self.cache_keys = []
        self.cache_attn = []

        self.embed = Embedding(vocab=self.vocab,
                               emb_dim=args.emb_dim,
                               dropout=args.dropout_in,
                               ignore_index=self.pad)

        model_size = args.lm_type.replace('gated_conv_', '')
        dropout = args.dropout_hidden
        # NOTE: ``bottlececk_dim`` (sic) is the keyword name GLUBlock expects.

        blocks = OrderedDict()
        if model_size == 'custom':
            blocks['conv1'] = GLUBlock(args.kernel_size,
                                       args.emb_dim,
                                       args.n_units,
                                       bottlececk_dim=args.n_projs,
                                       dropout=dropout)
            for i in range(args.n_layers - 1):
                blocks['conv%d' % (i + 2)] = GLUBlock(
                    args.kernel_size,
                    args.n_units,
                    args.n_units,
                    bottlececk_dim=args.n_projs,
                    dropout=dropout)
            last_dim = args.n_units

        elif model_size == '8':
            blocks['conv1'] = GLUBlock(4, args.emb_dim, 900, dropout=dropout)
            for i in range(1, 8):
                blocks['conv2-%d' % i] = GLUBlock(4, 900, 900, dropout=dropout)
            last_dim = 900

        elif model_size == '8B':
            blocks['conv1'] = GLUBlock(1, args.emb_dim, 512, dropout=dropout)
            for i in range(1, 4):
                blocks['conv2-%d' % i] = GLUBlock(5, 512, 512,
                                                  bottlececk_dim=128,
                                                  dropout=dropout)
            for i in range(1, 4):
                blocks['conv3-%d' % i] = GLUBlock(5, 512, 512,
                                                  bottlececk_dim=256,
                                                  dropout=dropout)
            blocks['conv4'] = GLUBlock(1, 512, 2048,
                                       bottlececk_dim=1024,
                                       dropout=dropout)
            last_dim = 2048

        elif model_size == '9':
            blocks['conv1'] = GLUBlock(4, args.emb_dim, 807, dropout=dropout)
            for i in range(1, 4):
                blocks['conv2-%d-1' % i] = GLUBlock(4, 807, 807,
                                                    dropout=dropout)
                blocks['conv2-%d-2' % i] = GLUBlock(4, 807, 807,
                                                    dropout=dropout)
            last_dim = 807

        elif model_size == '13':
            blocks['conv1'] = GLUBlock(4, args.emb_dim, 1268, dropout=dropout)
            for i in range(1, 13):
                blocks['conv2-%d' % i] = GLUBlock(4, 1268, 1268,
                                                  dropout=dropout)
            last_dim = 1268

        elif model_size == '14':
            for i in range(1, 4):
                blocks['conv1-%d' % i] = GLUBlock(
                    6,
                    args.emb_dim if i == 1 else 850,
                    850,
                    dropout=dropout)
            blocks['conv2'] = GLUBlock(1, 850, 850, dropout=dropout)
            for i in range(1, 5):
                blocks['conv3-%d' % i] = GLUBlock(5, 850, 850, dropout=dropout)
            blocks['conv4'] = GLUBlock(1, 850, 850, dropout=dropout)
            for i in range(1, 4):
                blocks['conv5-%d' % i] = GLUBlock(4, 850, 850, dropout=dropout)
            blocks['conv6'] = GLUBlock(4, 850, 1024, dropout=dropout)
            blocks['conv7'] = GLUBlock(4, 1024, 2048, dropout=dropout)
            last_dim = 2048

        elif model_size == '14B':
            blocks['conv1'] = GLUBlock(5, args.emb_dim, 512, dropout=dropout)
            for i in range(1, 4):
                blocks['conv2-%d' % i] = GLUBlock(5, 512, 512,
                                                  bottlececk_dim=128,
                                                  dropout=dropout)
            for i in range(1, 4):
                blocks['conv3-%d' % i] = GLUBlock(5,
                                                  512 if i == 1 else 1024,
                                                  1024,
                                                  bottlececk_dim=512,
                                                  dropout=dropout)
            for i in range(1, 7):
                blocks['conv4-%d' % i] = GLUBlock(5,
                                                  1024 if i == 1 else 2048,
                                                  2048,
                                                  bottlececk_dim=1024,
                                                  dropout=dropout)
            blocks['conv5'] = GLUBlock(5, 2048, 4096,
                                       bottlececk_dim=1024,
                                       dropout=dropout)
            last_dim = 4096

        else:
            raise NotImplementedError(model_size)

        self.blocks = nn.Sequential(blocks)

        if args.adaptive_softmax:
            self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                last_dim,
                self.vocab,
                cutoffs=[self.vocab // 25, self.vocab // 5],
                div_value=4.0)
            self.output = None
        else:
            self.adaptive_softmax = None
            self.output = LinearND(last_dim,
                                   self.vocab,
                                   dropout=args.dropout_out)
            # NOTE: include bias even when tying weights

            # Optionally tie weights as in:
            # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
            # https://arxiv.org/abs/1608.05859
            # and
            # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
            # https://arxiv.org/abs/1611.01462
            if args.tie_embedding:
                # The output weight is (vocab, last_dim) and the embedding
                # weight is (vocab, emb_dim). For fixed-size models last_dim
                # is not n_units, so compare against last_dim.
                if last_dim != args.emb_dim:
                    raise ValueError(
                        'When using the tied flag, the output dimension (%d) '
                        'must be equal to emb_dim (%d).' %
                        (last_dim, args.emb_dim))
                self.output.fc.weight = self.embed.embed.weight

        # Initialize parameters
        self.reset_parameters(args.param_init)
Exemplo n.º 3
0
    def __init__(self, args, save_path=None):
        """Build the encoder, main-task decoders and optional sub-task decoders.

        Main-task decoders:
            * A forward decoder (``self.dec_fwd``) is built when the forward
              weight or the CTC weight is positive.
            * A backward decoder (``self.dec_bwd``) is built when
              ``args.bwd_weight`` is positive.
            The decoder class is chosen by ``args.dec_type``: Transformer,
            RNN transducer, or attention-based RNN decoder.

        Sub-task decoders:
            ``self.dec_fwd_sub1`` and ``self.dec_fwd_sub2`` are attention RNN
            decoders, built when the corresponding sub-task weight is positive.

        Args:
            args: configuration namespace holding all model hyperparameters.
            save_path (str, optional): stored as ``self.save_path``.

        Raises:
            NotImplementedError: a sub task is requested with a Transformer
                decoder.
        """
        super(ModelBase, self).__init__()

        def _parse_fc_list(spec):
            # '512_512' -> [512, 512]; None or '' -> []
            if spec is None or len(spec) == 0:
                return []
            return [int(fc) for fc in spec.split('_')]

        self.save_path = save_path

        # for encoder, decoder
        self.input_type = args.input_type
        self.input_dim = args.input_dim
        self.enc_type = args.enc_type
        self.enc_n_units = args.enc_n_units
        if args.enc_type in ['blstm', 'bgru', 'conv_blstm', 'conv_bgru']:
            self.enc_n_units *= 2
        self.dec_type = args.dec_type

        # for OOV resolution
        self.enc_n_layers = args.enc_n_layers
        self.enc_n_layers_sub1 = args.enc_n_layers_sub1
        self.subsample = [int(s) for s in args.subsample.split('_')]

        # for decoder
        self.vocab = args.vocab
        self.vocab_sub1 = args.vocab_sub1
        self.vocab_sub2 = args.vocab_sub2
        self.blank = 0
        self.unk = 1
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for the sub tasks
        self.main_weight = 1 - args.sub1_weight - args.sub2_weight
        self.sub1_weight = args.sub1_weight
        self.sub2_weight = args.sub2_weight
        self.mtl_per_batch = args.mtl_per_batch
        self.task_specific_layer = args.task_specific_layer

        # for CTC
        self.ctc_weight = min(args.ctc_weight, self.main_weight)
        self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight)
        self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight)

        # for backward decoder
        self.bwd_weight = min(args.bwd_weight, self.main_weight)
        self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight
        self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1
        self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2

        # Feature extraction
        self.gaussian_noise = args.gaussian_noise
        self.n_stacks = args.n_stacks
        self.n_skips = args.n_skips
        self.n_splices = args.n_splices
        self.is_specaug = args.n_freq_masks > 0 or args.n_time_masks > 0
        self.specaug = None
        if self.is_specaug:
            assert args.n_stacks == 1 and args.n_skips == 1
            assert args.n_splices == 1
            self.specaug = SpecAugment(F=args.freq_width,
                                       T=args.time_width,
                                       n_freq_masks=args.n_freq_masks,
                                       n_time_masks=args.n_time_masks,
                                       p=args.time_width_upper)

        # Frontend
        self.ssn = None
        if args.sequence_summary_network:
            assert args.input_type == 'speech'
            self.ssn = SequenceSummaryNetwork(args.input_dim,
                                              n_units=512,
                                              n_layers=3,
                                              bottleneck_dim=100,
                                              dropout=0,
                                              param_init=args.param_init)

        # Encoder
        self.enc = select_encoder(args)
        if args.freeze_encoder:
            for p in self.enc.parameters():
                p.requires_grad = False

        # main task
        directions = []
        if self.fwd_weight > 0 or self.ctc_weight > 0:
            directions.append('fwd')
        if self.bwd_weight > 0:
            directions.append('bwd')
        for direction in directions:
            is_fwd = direction == 'fwd'

            # LM fusion / LM initialization are only loaded for the forward
            # decoder.
            # TODO(hirofumi): for backward RNNLM
            lm_fusion = None
            if args.lm_fusion and is_fwd:
                lm_fusion = RNNLM(args.lm_conf)
                lm_fusion, _ = load_checkpoint(lm_fusion, args.lm_fusion)
            lm_init = None
            if args.lm_init and is_fwd:
                lm_init = RNNLM(args.lm_conf)
                lm_init, _ = load_checkpoint(lm_init, args.lm_init)

            # Only the forward decoder carries the CTC branch.
            # The forward decoder's global weight excludes the backward share.
            dec_ctc_weight = self.ctc_weight if is_fwd else 0
            dec_global_weight = ((self.main_weight - self.bwd_weight)
                                 if is_fwd else self.bwd_weight)

            # Decoder
            if args.dec_type == 'transformer':
                dec = TransformerDecoder(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    attn_type=args.transformer_attn_type,
                    attn_n_heads=args.transformer_attn_n_heads,
                    n_layers=args.dec_n_layers,
                    d_model=args.d_model,
                    d_ff=args.d_ff,
                    vocab=self.vocab,
                    tie_embedding=args.tie_embedding,
                    pe_type=args.pe_type,
                    layer_norm_eps=args.layer_norm_eps,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    lsm_prob=args.lsm_prob,
                    focal_loss_weight=args.focal_loss_weight,
                    focal_loss_gamma=args.focal_loss_gamma,
                    ctc_weight=dec_ctc_weight,
                    ctc_lsm_prob=args.ctc_lsm_prob,
                    ctc_fc_list=_parse_fc_list(args.ctc_fc_list),
                    backward=not is_fwd,
                    global_weight=dec_global_weight,
                    mtl_per_batch=args.mtl_per_batch)
            elif 'transducer' in args.dec_type:
                dec = RNNTransducer(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    rnn_type=args.dec_type,
                    n_units=args.dec_n_units,
                    n_projs=args.dec_n_projs,
                    n_layers=args.dec_n_layers,
                    residual=args.dec_residual,
                    bottleneck_dim=args.dec_bottleneck_dim,
                    emb_dim=args.emb_dim,
                    vocab=self.vocab,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    lsm_prob=args.lsm_prob,
                    ctc_weight=dec_ctc_weight,
                    ctc_lsm_prob=args.ctc_lsm_prob,
                    ctc_fc_list=_parse_fc_list(args.ctc_fc_list),
                    lm_init=lm_init,
                    lmobj_weight=args.lmobj_weight,
                    share_lm_softmax=args.share_lm_softmax,
                    global_weight=dec_global_weight,
                    mtl_per_batch=args.mtl_per_batch,
                    param_init=args.param_init)
            else:
                dec = RNNDecoder(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    attn_type=args.attn_type,
                    attn_dim=args.attn_dim,
                    attn_sharpening_factor=args.attn_sharpening,
                    attn_sigmoid_smoothing=args.attn_sigmoid,
                    attn_conv_out_channels=args.attn_conv_n_channels,
                    attn_conv_kernel_size=args.attn_conv_width,
                    attn_n_heads=args.attn_n_heads,
                    rnn_type=args.dec_type,
                    n_units=args.dec_n_units,
                    n_projs=args.dec_n_projs,
                    n_layers=args.dec_n_layers,
                    loop_type=args.dec_loop_type,
                    residual=args.dec_residual,
                    bottleneck_dim=args.dec_bottleneck_dim,
                    emb_dim=args.emb_dim,
                    vocab=self.vocab,
                    tie_embedding=args.tie_embedding,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    zoneout=args.zoneout,
                    ss_prob=args.ss_prob,
                    ss_type=args.ss_type,
                    lsm_prob=args.lsm_prob,
                    focal_loss_weight=args.focal_loss_weight,
                    focal_loss_gamma=args.focal_loss_gamma,
                    ctc_weight=dec_ctc_weight,
                    ctc_lsm_prob=args.ctc_lsm_prob,
                    ctc_fc_list=_parse_fc_list(args.ctc_fc_list),
                    input_feeding=args.input_feeding,
                    backward=not is_fwd,
                    lm_fusion=lm_fusion,
                    lm_fusion_type=args.lm_fusion_type,
                    discourse_aware=args.discourse_aware,
                    lm_init=lm_init,
                    lmobj_weight=args.lmobj_weight,
                    share_lm_softmax=args.share_lm_softmax,
                    global_weight=dec_global_weight,
                    mtl_per_batch=args.mtl_per_batch,
                    adaptive_softmax=args.adaptive_softmax,
                    param_init=args.param_init,
                    replace_sos=args.replace_sos)
            setattr(self, 'dec_' + direction, dec)

        # sub task
        for sub in ['sub1', 'sub2']:
            if getattr(self, sub + '_weight') > 0:
                if args.dec_type == 'transformer':
                    raise NotImplementedError
                dec_sub = RNNDecoder(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc_n_units,
                    attn_type=args.attn_type,
                    attn_dim=args.attn_dim,
                    attn_sharpening_factor=args.attn_sharpening,
                    attn_sigmoid_smoothing=args.attn_sigmoid,
                    attn_conv_out_channels=args.attn_conv_n_channels,
                    attn_conv_kernel_size=args.attn_conv_width,
                    attn_n_heads=1,
                    rnn_type=args.dec_type,
                    n_units=args.dec_n_units,
                    n_projs=args.dec_n_projs,
                    n_layers=args.dec_n_layers,
                    loop_type=args.dec_loop_type,
                    residual=args.dec_residual,
                    bottleneck_dim=args.dec_bottleneck_dim,
                    emb_dim=args.emb_dim,
                    tie_embedding=args.tie_embedding,
                    vocab=getattr(self, 'vocab_' + sub),
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    ss_prob=args.ss_prob,
                    ss_type=args.ss_type,
                    lsm_prob=args.lsm_prob,
                    focal_loss_weight=args.focal_loss_weight,
                    focal_loss_gamma=args.focal_loss_gamma,
                    ctc_weight=getattr(self, 'ctc_weight_' + sub),
                    ctc_lsm_prob=args.ctc_lsm_prob,
                    ctc_fc_list=_parse_fc_list(
                        getattr(args, 'ctc_fc_list_' + sub)),
                    input_feeding=args.input_feeding,
                    global_weight=getattr(self, sub + '_weight'),
                    mtl_per_batch=args.mtl_per_batch,
                    param_init=args.param_init)
                setattr(self, 'dec_fwd_' + sub, dec_sub)

        if args.input_type == 'text':
            if args.vocab == args.vocab_sub1:
                # Share the embedding layer between input and output.
                # Use the forward decoder's embedding when it exists, and the
                # backward decoder's otherwise. Do not use whichever decoder
                # happened to be built last.
                self.embed = getattr(self, 'dec_' + directions[0]).embed
            else:
                self.embed = Embedding(vocab=args.vocab_sub1,
                                       emb_dim=args.emb_dim,
                                       dropout=args.dropout_emb,
                                       ignore_index=self.pad)

        # Recurrent weights are orthogonalized
        if args.rec_weight_orthogonal:
            self.reset_parameters(args.param_init,
                                  dist='orthogonal',
                                  keys=['rnn', 'weight'])

        # Initialize bias in forget gate with 1
        # self.init_forget_gate_bias_with_one()

        # Fix all parameters except for the gating parts in deep fusion.
        # NOTE: 'output' also matches names containing 'output_bn'.
        if args.lm_fusion_type == 'deep' and args.lm_fusion:
            for n, p in self.named_parameters():
                p.requires_grad = 'output' in n or 'linear' in n
Exemplo n.º 4
0
    def __init__(self, args, save_path=None):
        """Build an RNN (LSTM/GRU) language model.

        The model is a stack of ``args.n_layers`` single-layer RNNs, each
        followed by its own dropout module. A per-layer linear projection is
        added when ``args.n_projs > 0``. It ends with an optional GLU
        projection and either an adaptive log-softmax or a linear output layer.

        Args:
            args: configuration namespace. Reads emb_dim, lm_type ('lstm' or
                'gru'), n_units, n_projs, n_layers, residual, use_glu,
                n_units_null_context, lsm_prob, vocab, dropout_in,
                dropout_hidden, dropout_out, adaptive_softmax, tie_embedding,
                param_init and rec_weight_orthogonal.
            save_path (str, optional): stored as ``self.save_path``.

        Raises:
            ValueError: ``args.tie_embedding`` is set but the output layer's
                input size (``n_projs`` if projections are used, else
                ``n_units``) differs from ``args.emb_dim``.
        """
        super(LMBase, self).__init__()
        logger = logging.getLogger('training')
        logger.info(self.__class__.__name__)

        self.save_path = save_path

        self.emb_dim = args.emb_dim
        self.rnn_type = args.lm_type
        assert args.lm_type in ['lstm', 'gru']
        self.n_units = args.n_units
        self.n_projs = args.n_projs
        self.n_layers = args.n_layers
        self.residual = args.residual
        self.use_glu = args.use_glu
        self.n_units_cv = args.n_units_null_context
        self.lsm_prob = args.lsm_prob

        self.vocab = args.vocab
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for cache
        self.cache_theta = 0.2  # smoothing parameter
        self.cache_lambda = 0.2  # cache weight
        self.cache_ids = []
        self.cache_keys = []
        self.cache_attn = []

        self.embed = Embedding(vocab=self.vocab,
                               emb_dim=args.emb_dim,
                               dropout=args.dropout_in,
                               ignore_index=self.pad)

        rnn = nn.LSTM if args.lm_type == 'lstm' else nn.GRU
        self.rnn = nn.ModuleList()
        self.dropout = nn.ModuleList(
            [nn.Dropout(p=args.dropout_hidden) for _ in range(args.n_layers)])
        if args.n_projs > 0:
            self.proj = nn.ModuleList([
                Linear(args.n_units, args.n_projs)
                for _ in range(args.n_layers)
            ])
        # The first layer consumes the embedding concatenated with the
        # null-context vector. Later layers consume the previous layer's
        # (possibly projected) output.
        rnn_idim = args.emb_dim + args.n_units_null_context
        for _ in range(args.n_layers):
            self.rnn += [
                rnn(rnn_idim,
                    args.n_units,
                    1,
                    bias=True,
                    batch_first=True,
                    dropout=0,
                    bidirectional=False)
            ]
            rnn_idim = args.n_projs if args.n_projs > 0 else args.n_units

        if self.use_glu:
            self.fc_glu = Linear(rnn_idim,
                                 rnn_idim * 2,
                                 dropout=args.dropout_hidden)

        if args.adaptive_softmax:
            self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                rnn_idim,
                self.vocab,
                cutoffs=[self.vocab // 25, self.vocab // 5],
                div_value=4.0)
            self.output = None
        else:
            self.adaptive_softmax = None
            self.output = Linear(rnn_idim,
                                 self.vocab,
                                 dropout=args.dropout_out)
            # NOTE: include bias even when tying weights

            # Optionally tie weights as in:
            # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
            # https://arxiv.org/abs/1608.05859
            # and
            # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
            # https://arxiv.org/abs/1611.01462
            if args.tie_embedding:
                # The output weight is (vocab, rnn_idim) and the embedding
                # weight is (vocab, emb_dim). rnn_idim equals n_projs when
                # projections are used, so n_units is the wrong thing to check.
                if rnn_idim != args.emb_dim:
                    raise ValueError(
                        'When using the tied flag, the output dimension (%d) '
                        'must be equal to emb_dim (%d).' %
                        (rnn_idim, args.emb_dim))
                self.output.fc.weight = self.embed.embed.weight

        # Initialize parameters
        self.reset_parameters(args.param_init)

        # Recurrent weights are orthogonalized
        if args.rec_weight_orthogonal:
            self.reset_parameters(args.param_init,
                                  dist='orthogonal',
                                  keys=['rnn', 'weight'])
Exemplo n.º 5
0
    def __init__(self, args):
        """Build an RNN language model from the hyperparameters in ``args``.

        Submodules are created in this order: token embedding, RNN stack
        (plus dropout / optional projections), optional GLU layer, output
        layer (plain linear or adaptive softmax). All parameters are then
        re-initialized by ``reset_parameters``.

        Args:
            args: configuration namespace. Attributes read here: ``emb_dim``,
                ``lm_type`` ('lstm' or 'gru'), ``n_units``, ``n_layers``,
                ``n_projs``, ``residual``, ``use_glu``, ``vocab``,
                ``dropout_emb``, ``dropout_hidden``, ``dropout_out``,
                ``adaptive_softmax``, ``tie_embedding``, ``param_init`` and
                ``rec_weight_orthogonal``.

        Raises:
            AssertionError: if ``args.lm_type`` is not 'lstm' or 'gru'.
            ValueError: if ``args.tie_embedding`` is set (without adaptive
                softmax) and ``args.n_units != args.emb_dim``.
        """

        super(ModelBase, self).__init__()

        self.emb_dim = args.emb_dim
        self.rnn_type = args.lm_type
        assert args.lm_type in ['lstm', 'gru']
        self.n_units = args.n_units
        self.n_layers = args.n_layers
        self.residual = args.residual
        self.use_glu = args.use_glu

        self.vocab = args.vocab
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for cache
        self.cache_theta = 0.2  # smoothing parameter
        self.cache_lambda = 0.2  # cache weight
        self.cache_ids = []
        self.cache_keys = []
        self.cache_attn = []

        self.embed = Embedding(vocab=self.vocab,
                               emb_dim=args.emb_dim,
                               dropout=args.dropout_emb,
                               ignore_index=self.pad)

        # Without projections or residual connections, a single multi-layer
        # nn.LSTM / nn.GRU module is used ("fast" implementation).
        self.fast_impl = False
        if args.n_projs == 0 and not args.residual:
            self.fast_impl = True
            # NOTE(review): this else-branch is unreachable because of the
            # assert on lm_type above.
            if 'lstm' in args.lm_type:
                rnn = nn.LSTM
            elif 'gru' in args.lm_type:
                rnn = nn.GRU
            else:
                raise ValueError('rnn_type must be "(b)lstm" or "(b)gru".')

            self.rnn = rnn(args.emb_dim, args.n_units, args.n_layers,
                           bias=True,
                           batch_first=True,
                           dropout=args.dropout_hidden,
                           bidirectional=False)
            # NOTE: pytorch introduces a dropout layer on the outputs of each layer EXCEPT the last layer
            rnn_idim = args.n_units
            self.dropout_top = nn.Dropout(p=args.dropout_hidden)
        else:
            # One single-layer RNN per layer, so that per-layer dropout and
            # projections (and presumably residual connections in forward —
            # not visible here) can be inserted between layers.
            self.rnn = torch.nn.ModuleList()
            self.dropout = torch.nn.ModuleList()
            if args.n_projs > 0:
                self.proj = torch.nn.ModuleList()
            rnn_idim = args.emb_dim
            for l in range(args.n_layers):
                self.rnn += [getattr(nn, args.lm_type.upper())(
                    rnn_idim, args.n_units, 1,
                    bias=True,
                    batch_first=True,
                    dropout=0,
                    bidirectional=False)]
                self.dropout += [nn.Dropout(p=args.dropout_hidden)]
                rnn_idim = args.n_units

                # No projection after the last layer, so the final output
                # dimension is always n_units.
                if l != self.n_layers - 1 and args.n_projs > 0:
                    self.proj += [LinearND(rnn_idim, args.n_projs)]
                    rnn_idim = args.n_projs

        # fc_glu doubles the width; presumably a GLU in forward halves it back
        # to rnn_idim — TODO confirm.
        if self.use_glu:
            self.fc_glu = LinearND(rnn_idim, rnn_idim * 2,
                                   dropout=args.dropout_hidden)

        if args.adaptive_softmax:
            self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                rnn_idim, self.vocab,
                # cutoffs=[self.vocab // 10, 3 * self.vocab // 10],
                cutoffs=[self.vocab // 25, self.vocab // 5],
                div_value=4.0)
            self.output = None
        else:
            self.adaptive_softmax = None
            self.output = LinearND(rnn_idim, self.vocab,
                                   dropout=args.dropout_out)
            # NOTE: include bias even when tying weights

            # Optionally tie weights as in:
            # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
            # https://arxiv.org/abs/1608.05859
            # and
            # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
            # https://arxiv.org/abs/1611.01462
            if args.tie_embedding:
                if args.n_units != args.emb_dim:
                    raise ValueError('When using the tied flag, n_units must be equal to emb_dim.')
                self.output.fc.weight = self.embed.embed.weight

        # Initialize parameters
        self.reset_parameters(args.param_init)

        # Recurrent weights are orthogonalized
        if args.rec_weight_orthogonal:
            self.reset_parameters(args.param_init, dist='orthogonal',
                                  keys=['rnn', 'weight'])
Exemplo n.º 6
0
    def __init__(self, args, save_path=None):
        """Construct a Transformer language model (self-attention only).

        Submodules are built in this order: token embedding, positional
        encoding, ``args.n_layers`` TransformerDecoderBlocks without source
        attention, final LayerNorm, and either an adaptive softmax or a plain
        output projection (optionally tied to the embedding).

        Args:
            args: configuration namespace providing ``d_model``, ``d_ff``,
                ``pe_type``, ``n_layers``, ``attn_type``, ``attn_n_heads``,
                ``tie_embedding``, ``vocab``, ``dropout_emb``,
                ``dropout_hidden``, ``dropout_att``, ``layer_norm_eps`` and
                ``adaptive_softmax``.
            save_path: optional path kept on the instance as ``save_path``.
        """
        super(LMBase, self).__init__()
        logging.getLogger('training').info(self.__class__.__name__)

        self.save_path = save_path

        # Architecture hyperparameters
        self.d_model = args.d_model
        self.d_ff = args.d_ff
        self.pe_type = args.pe_type
        self.n_layers = args.n_layers
        self.n_heads = args.attn_n_heads
        self.tie_embedding = args.tie_embedding

        # Vocabulary; the eos / pad indices are reserved in advance
        self.vocab = args.vocab
        self.eos, self.pad = 2, 3

        # Cache state: smoothing parameter, cache weight, and stored entries
        self.cache_theta = 0.2
        self.cache_lambda = 0.2
        self.cache_ids, self.cache_keys, self.cache_attn = [], [], []

        # No dropout on the raw embedding; the embedding dropout rate is
        # handed to the positional encoding instead.
        self.embed = Embedding(vocab=self.vocab,
                               emb_dim=self.d_model,
                               dropout=0,
                               ignore_index=self.pad)
        self.pos_enc = PositionalEncoding(args.d_model, args.dropout_emb,
                                          args.pe_type)

        # Self-attention blocks (no source attention in a language model)
        blocks = []
        for _ in range(self.n_layers):
            blocks.append(TransformerDecoderBlock(args.d_model,
                                                  args.d_ff,
                                                  args.attn_type,
                                                  args.attn_n_heads,
                                                  args.dropout_hidden,
                                                  args.dropout_att,
                                                  args.layer_norm_eps,
                                                  src_attention=False))
        self.layers = nn.ModuleList(blocks)
        self.norm_out = nn.LayerNorm(args.d_model, eps=args.layer_norm_eps)

        if not args.adaptive_softmax:
            self.adaptive_softmax = None
            self.output = LinearND(self.d_model, self.vocab)
            # Weight tying: Press & Wolf 2016 (https://arxiv.org/abs/1608.05859)
            # and Inan et al. 2016 (https://arxiv.org/abs/1611.01462)
            if args.tie_embedding:
                self.output.fc.weight = self.embed.embed.weight
        else:
            cutoff = round(self.vocab / 15)
            self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                args.d_model,
                self.vocab,
                cutoffs=[cutoff, 3 * cutoff],
                div_value=4.0)
            self.output = None

        self.reset_parameters()
Exemplo n.º 7
0
    def __init__(self, args):
        """Build a gated-convolutional LM: embedding -> GLUBlock stack -> output.

        The GLUBlock stack is chosen by ``args.lm_type`` with the
        ``'gated_conv_'`` prefix stripped: 'small', '8', '13', '14' or '14B'
        ('8B' and '9' are not implemented).

        Args:
            args: configuration namespace. Attributes read here: ``emb_dim``,
                ``n_units``, ``n_layers``, ``tie_embedding``, ``backward``,
                ``vocab``, ``lm_type``, ``dropout_emb``, ``dropout_hidden``,
                ``dropout_out``, ``adaptive_softmax``, ``param_init`` and
                ``param_init_dist``.

        Raises:
            NotImplementedError: for an unsupported model size.
            ValueError: if ``args.tie_embedding`` is set (without adaptive
                softmax) and the width of the last conv block differs from
                ``args.emb_dim``.
        """
        super(ModelBase, self).__init__()

        self.emb_dim = args.emb_dim
        self.n_units = args.n_units
        self.n_layers = args.n_layers
        self.tie_embedding = args.tie_embedding
        self.backward = args.backward

        self.vocab = args.vocab
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for cache
        self.cache_theta = 0.2  # smoothing parameter
        self.cache_lambda = 0.2  # cache weight
        self.cache_ids = []
        self.cache_keys = []
        self.cache_attn = []

        self.embed = Embedding(vocab=self.vocab,
                               emb_dim=args.emb_dim,
                               dropout=args.dropout_emb,
                               ignore_index=self.pad)

        # Insertion order of `layers` is the execution order of nn.Sequential.
        layers = OrderedDict()

        model_size = args.lm_type.replace('gated_conv_', '')

        # NOTE: the keyword `bottlececk_dim` (sic) is spelled this way at every
        # GLUBlock call site; presumably it matches GLUBlock's signature.
        if model_size == 'small':
            # conv1-1 ... conv5-1, all kernel 4, width 600, bottleneck 300
            for i in range(1, 6):
                layers['conv%d-1' % i] = GLUBlock(4,
                                                  args.emb_dim if i == 1 else 600,
                                                  600,
                                                  bottlececk_dim=300,
                                                  dropout=args.dropout_hidden)
            last_dim = 600

        elif model_size == '8':
            layers['conv1-1'] = GLUBlock(4,
                                         args.emb_dim,
                                         900,
                                         dropout=args.dropout_hidden)
            for i in range(1, 8, 1):
                layers['conv2-%d' % i] = GLUBlock(4,
                                                  900,
                                                  900,
                                                  dropout=args.dropout_hidden)
            last_dim = 900

        elif model_size == '8B':
            raise NotImplementedError

        elif model_size == '9':
            raise NotImplementedError

        elif model_size == '13':
            layers['conv1-1'] = GLUBlock(4,
                                         args.emb_dim,
                                         1268,
                                         dropout=args.dropout_hidden)
            for i in range(1, 13, 1):
                layers['conv2-%d' % i] = GLUBlock(4,
                                                  1268,
                                                  1268,
                                                  dropout=args.dropout_hidden)
            last_dim = 1268

        elif model_size == '14':
            for i in range(1, 4, 1):
                layers['conv1-%d' % i] = GLUBlock(
                    6,
                    args.emb_dim if i == 1 else 850,
                    850,
                    dropout=args.dropout_hidden)
            layers['conv2-1'] = GLUBlock(1,
                                         850,
                                         850,
                                         dropout=args.dropout_hidden)
            for i in range(1, 5, 1):
                layers['conv3-%d' % i] = GLUBlock(5,
                                                  850,
                                                  850,
                                                  dropout=args.dropout_hidden)
            layers['conv4-1'] = GLUBlock(1,
                                         850,
                                         850,
                                         dropout=args.dropout_hidden)
            for i in range(1, 4, 1):
                layers['conv5-%d' % i] = GLUBlock(4,
                                                  850,
                                                  850,
                                                  dropout=args.dropout_hidden)
            layers['conv6-1'] = GLUBlock(4,
                                         850,
                                         1024,
                                         dropout=args.dropout_hidden)
            layers['conv7-1'] = GLUBlock(4,
                                         1024,
                                         2048,
                                         dropout=args.dropout_hidden)
            last_dim = 2048

        elif model_size == '14B':
            layers['conv1-1'] = GLUBlock(5,
                                         args.emb_dim,
                                         512,
                                         dropout=args.dropout_hidden)
            for i in range(1, 4, 1):
                layers['conv2-%d' % i] = GLUBlock(5,
                                                  512,
                                                  512,
                                                  bottlececk_dim=128,
                                                  dropout=args.dropout_hidden)
            for i in range(1, 4, 1):
                layers['conv3-%d' % i] = GLUBlock(5,
                                                  512 if i == 1 else 1024,
                                                  1024,
                                                  bottlececk_dim=512,
                                                  dropout=args.dropout_hidden)
            for i in range(1, 7, 1):
                layers['conv4-%d' % i] = GLUBlock(5,
                                                  1024 if i == 1 else 2048,
                                                  2048,
                                                  bottlececk_dim=1024,
                                                  dropout=args.dropout_hidden)
            layers['conv5-1'] = GLUBlock(5,
                                         2048,
                                         4096,
                                         bottlececk_dim=1024,
                                         dropout=args.dropout_hidden)
            last_dim = 4096

        else:
            raise NotImplementedError(model_size)

        self.layers = nn.Sequential(layers)

        if args.adaptive_softmax:
            self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                last_dim,
                self.vocab,
                # cutoffs=[self.vocab // 10, 3 * self.vocab // 10],
                cutoffs=[self.vocab // 25, self.vocab // 5],
                div_value=4.0)
            self.output = None
        else:
            self.adaptive_softmax = None
            self.output = LinearND(last_dim,
                                   self.vocab,
                                   dropout=args.dropout_out)
            # NOTE: include bias even when tying weights

            # Optionally tie weights as in:
            # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
            # https://arxiv.org/abs/1608.05859
            # and
            # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
            # https://arxiv.org/abs/1611.01462
            if args.tie_embedding:
                # The output layer consumes `last_dim` features, so it is
                # last_dim (not n_units, unused by this model) that must match
                # the embedding size for the weight matrices to be swappable.
                if last_dim != args.emb_dim:
                    raise ValueError(
                        'When using the tied flag, the output dimension of the '
                        'last conv layer (%d) must be equal to emb_dim (%d).'
                        % (last_dim, args.emb_dim))
                self.output.fc.weight = self.embed.embed.weight

        # Initialize parameters
        self.reset_parameters(args.param_init, dist=args.param_init_dist)

        # Initialize bias vectors with zero
        self.reset_parameters(0, dist='constant', keys=['bias'])
Exemplo n.º 8
0
    def __init__(self,
                 eos,
                 unk,
                 pad,
                 blank,
                 enc_n_units,
                 attn_type,
                 attn_n_heads,
                 n_layers,
                 d_model,
                 d_ff,
                 pe_type,
                 tie_embedding,
                 vocab,
                 dropout=0.0,
                 dropout_emb=0.0,
                 dropout_att=0.0,
                 lsm_prob=0.0,
                 layer_norm_eps=1e-6,
                 ctc_weight=0.0,
                 ctc_fc_list=[],
                 backward=False,
                 global_weight=1.0,
                 mtl_per_batch=False,
                 adaptive_softmax=False):
        """Transformer decoder with an optional CTC branch.

        The CTC branch is built when ``ctc_weight > 0``; the attention-based
        decoder (blocks, embedding, output layer, final LayerNorm) is built
        when ``ctc_weight < global_weight``.

        Args:
            eos, unk, pad, blank (int): special token indices.
            enc_n_units (int): encoder output dimension.
            attn_type (str): attention type passed to each TransformerDecoderBlock.
            attn_n_heads (int): number of attention heads.
            n_layers (int): number of decoder blocks.
            d_model (int): model dimension.
            d_ff (int): inner dimension of the feed-forward sublayers.
            pe_type: positional-encoding type; falsy disables positional encoding.
            tie_embedding (bool): share the output projection weight with the
                token embedding (ignored with adaptive softmax).
            vocab (int): vocabulary size.
            dropout, dropout_emb, dropout_att (float): dropout rates.
            lsm_prob (float): presumably the label-smoothing probability.
            layer_norm_eps (float): epsilon of the LayerNorm layers.
            ctc_weight (float): weight of the CTC objective.
            ctc_fc_list (list of int): hidden sizes of FC layers before the
                CTC output layer (read only, never mutated).
            backward (bool): stored as ``self.backward``.
            global_weight (float): total weight of this task.
            mtl_per_batch (bool): stored as-is.
            adaptive_softmax (bool): use nn.AdaptiveLogSoftmaxWithLoss instead
                of a plain output projection.
        """
        super(TransformerDecoder, self).__init__()

        self.eos = eos
        self.unk = unk
        self.pad = pad
        self.blank = blank
        self.vocab = vocab
        self.enc_n_units = enc_n_units
        self.d_model = d_model
        self.n_layers = n_layers
        self.pe_type = pe_type
        self.lsm_prob = lsm_prob
        self.ctc_weight = ctc_weight
        self.ctc_fc_list = ctc_fc_list
        self.backward = backward
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch

        if ctc_weight > 0:
            # Fully-connected layers for CTC
            # NOTE(review): the CTC input size is d_model; if CTC is applied to
            # encoder outputs this assumes enc_n_units == d_model — confirm.
            if len(ctc_fc_list) > 0:
                fc_layers = OrderedDict()
                for i in range(len(ctc_fc_list)):
                    input_dim = d_model if i == 0 else ctc_fc_list[i - 1]
                    fc_layers['fc' + str(i)] = LinearND(input_dim, ctc_fc_list[i], dropout=dropout)
                fc_layers['fc' + str(len(ctc_fc_list))] = LinearND(ctc_fc_list[-1], vocab, dropout=0)
                self.output_ctc = nn.Sequential(fc_layers)
            else:
                self.output_ctc = LinearND(d_model, vocab)
            self.decode_ctc_greedy = GreedyDecoder(blank=blank)
            self.decode_ctc_beam = BeamSearchDecoder(blank=blank)
            import warpctc_pytorch
            self.warpctc_loss = warpctc_pytorch.CTCLoss(size_average=True)

        if ctc_weight < global_weight:
            self.layers = nn.ModuleList(
                [TransformerDecoderBlock(d_model, d_ff, attn_type, attn_n_heads,
                                         dropout, dropout_att, layer_norm_eps)
                 for _ in range(n_layers)])

            self.embed = Embedding(vocab, d_model,
                                   dropout=0,  # NOTE: do not apply dropout here
                                   ignore_index=pad)
            if pe_type:
                self.pos_emb_out = PositionalEncoding(d_model, dropout_emb, pe_type)

            if adaptive_softmax:
                # Use the `vocab` argument: `self.vocab` was previously never
                # assigned in this constructor, so this branch raised
                # AttributeError.
                self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                    d_model, vocab,
                    cutoffs=[round(vocab / 15), 3 * round(vocab / 15)],
                    # cutoffs=[vocab // 25, 3 * vocab // 5],
                    div_value=4.0)
                self.output = None
            else:
                self.adaptive_softmax = None
                self.output = LinearND(d_model, vocab)

                # Optionally tie weights as in:
                # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
                # https://arxiv.org/abs/1608.05859
                # and
                # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
                # https://arxiv.org/abs/1611.01462
                if tie_embedding:
                    self.output.fc.weight = self.embed.embed.weight

            self.norm_top = nn.LayerNorm(d_model, eps=layer_norm_eps)

        # Initialize parameters
        self.reset_parameters()
Exemplo n.º 9
0
    def __init__(self,
                 eos,
                 unk,
                 pad,
                 blank,
                 enc_n_units,
                 rnn_type,
                 n_units,
                 n_projs,
                 n_layers,
                 bottleneck_dim,
                 emb_dim,
                 vocab,
                 tie_embedding=False,
                 attn_conv_kernel_size=0,
                 dropout=0.0,
                 dropout_emb=0.0,
                 lsm_prob=0.0,
                 ctc_weight=0.0,
                 ctc_lsm_prob=0.0,
                 ctc_fc_list=[],
                 backward=False,
                 lm_fusion=None,
                 lm_fusion_type='cold',
                 discourse_aware='',
                 lm_init=None,
                 global_weight=1.0,
                 mtl_per_batch=False,
                 param_init=0.1,
                 replace_sos=False,
                 soft_label_weight=0.0):
        """RNN decoder driven by a CIF module, with optional CTC and LM fusion.

        Args:
            eos, unk, pad, blank (int): special token indices.
            enc_n_units (int): encoder output dimension.
            rnn_type (str): 'lstm' or 'gru'.
            n_units (int): hidden size of each decoder RNN layer.
            n_projs (int): projection size after each layer (0 disables it).
            n_layers (int): number of decoder RNN layers.
            bottleneck_dim (int): input size of the output layer.
            emb_dim (int): token embedding size.
            vocab (int): vocabulary size.
            tie_embedding (bool): share the output weight with the embedding;
                requires ``emb_dim == bottleneck_dim``.
            attn_conv_kernel_size (int): kernel size passed to CIF.
            dropout, dropout_emb (float): dropout rates.
            lsm_prob, ctc_lsm_prob (float): presumably label-smoothing rates.
            ctc_weight (float): weight of the CTC objective (>0 builds CTC).
            ctc_fc_list (list of int): FC sizes before the CTC output (read only).
            backward (bool): stored as ``self.bwd``.
            lm_fusion: external LM for fusion; its parameters are frozen here.
            lm_fusion_type (str): 'cold', 'deep' or 'cold_prob'.
            discourse_aware (str): 'hierarchical' is not implemented.
            lm_init: pre-trained RNNLM whose RNN and embedding weights are
                copied into this decoder.
            global_weight (float): total weight of this task.
            mtl_per_batch, replace_sos, soft_label_weight: stored as-is.
            param_init (float): parameter-initialization value.

        Raises:
            ValueError: for an unknown ``lm_fusion_type``, or if tying is
                requested with ``emb_dim != bottleneck_dim``.
            NotImplementedError: for ``discourse_aware == 'hierarchical'``
                without LM fusion.
            AssertionError: if ``rnn_type`` is invalid or ``lm_init`` is
                incompatible with this decoder.
        """
        super(CIFRNNDecoder, self).__init__()
        logger = logging.getLogger('training')

        self.eos = eos
        self.unk = unk
        self.pad = pad
        self.blank = blank
        self.vocab = vocab
        self.rnn_type = rnn_type
        assert rnn_type in ['lstm', 'gru']
        self.enc_n_units = enc_n_units
        self.dec_n_units = n_units
        self.n_projs = n_projs
        self.n_layers = n_layers
        self.lsm_prob = lsm_prob
        self.ctc_weight = ctc_weight
        self.bwd = backward
        self.lm_fusion_type = lm_fusion_type
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch
        self.replace_sos = replace_sos
        self.soft_label_weight = soft_label_weight

        self.quantity_loss_weight = 1.0

        # for contextualization
        self.discourse_aware = discourse_aware
        self.dstate_prev = None

        # for cache
        self.prev_spk = ''
        self.total_step = 0
        self.dstates_final = None
        self.lmstate_final = None

        if ctc_weight > 0:
            self.ctc = CTC(eos=eos,
                           blank=blank,
                           enc_n_units=enc_n_units,
                           vocab=vocab,
                           dropout=dropout,
                           lsm_prob=ctc_lsm_prob,
                           fc_list=ctc_fc_list,
                           param_init=param_init)

        if ctc_weight < global_weight:
            # Attention layer
            self.score = CIF(enc_dim=self.enc_n_units,
                             conv_kernel_size=attn_conv_kernel_size,
                             conv_out_channels=self.enc_n_units)

            # Decoder: one single-layer RNN per layer; the first layer reads
            # [context; embedding], hence enc_n_units + emb_dim inputs.
            self.rnn = nn.ModuleList()
            if self.n_projs > 0:
                self.proj = nn.ModuleList(
                    [Linear(n_units, n_projs) for _ in range(n_layers)])
            self.dropout = nn.ModuleList(
                [nn.Dropout(p=dropout) for _ in range(n_layers)])
            rnn = nn.LSTM if rnn_type == 'lstm' else nn.GRU
            dec_odim = enc_n_units + emb_dim
            for l in range(n_layers):
                self.rnn += [rnn(dec_odim, n_units, 1)]
                dec_odim = n_units
                if self.n_projs > 0:
                    dec_odim = n_projs

            # LM fusion
            if lm_fusion is not None:
                self.linear_dec_feat = Linear(dec_odim + enc_n_units, n_units)
                if lm_fusion_type in ['cold', 'deep']:
                    self.linear_lm_feat = Linear(lm_fusion.n_units, n_units)
                    self.linear_lm_gate = Linear(n_units * 2, n_units)
                elif lm_fusion_type == 'cold_prob':
                    self.linear_lm_feat = Linear(lm_fusion.vocab, n_units)
                    self.linear_lm_gate = Linear(n_units * 2, n_units)
                else:
                    raise ValueError(lm_fusion_type)
                self.output_bn = Linear(n_units * 2, bottleneck_dim)

                # fix LM parameters
                for p in lm_fusion.parameters():
                    p.requires_grad = False
            elif discourse_aware == 'hierarchical':
                raise NotImplementedError
            else:
                self.output_bn = Linear(dec_odim + enc_n_units, bottleneck_dim)

            self.embed = Embedding(vocab,
                                   emb_dim,
                                   dropout=dropout_emb,
                                   ignore_index=pad)

            self.output = Linear(bottleneck_dim, vocab)
            # NOTE: include bias even when tying weights

            # Optionally tie weights as in:
            # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
            # https://arxiv.org/abs/1608.05859
            # and
            # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
            # https://arxiv.org/abs/1611.01462
            if tie_embedding:
                # The output layer consumes bottleneck_dim features, so that
                # is the size that must match the embedding.
                if emb_dim != bottleneck_dim:
                    raise ValueError(
                        'When using the tied flag, bottleneck_dim must be equal to emb_dim.'
                    )
                self.output.fc.weight = self.embed.embed.weight

        # Initialize parameters
        self.reset_parameters(param_init)

        # resister the external LM
        self.lm = lm_fusion

        # decoder initialization with pre-trained LM
        if lm_init is not None:
            assert lm_init.vocab == vocab
            assert lm_init.n_units == n_units
            assert lm_init.emb_dim == emb_dim
            logger.info('===== Initialize the decoder with pre-trained RNNLM')
            assert lm_init.n_projs == 0  # TODO(hirofumi): fix later
            # The decoder's first layer takes enc_n_units extra inputs; the LM
            # must have the same-sized null-context input for shapes to match.
            assert lm_init.n_units_null_context == enc_n_units

            # RNN
            for l in range(lm_init.n_layers):
                for n, p in lm_init.rnn[l].named_parameters():
                    assert getattr(self.rnn[l], n).size() == p.size()
                    getattr(self.rnn[l], n).data = p.data
                    logger.info('Overwrite %s', n)

            # embedding
            assert self.embed.embed.weight.size(
            ) == lm_init.embed.embed.weight.size()
            self.embed.embed.weight.data = lm_init.embed.embed.weight.data
            logger.info('Overwrite %s', 'embed.embed.weight')
Exemplo n.º 10
0
    def __init__(self,
                 eos,
                 unk,
                 pad,
                 blank,
                 enc_n_units,
                 rnn_type,
                 n_units,
                 n_projs,
                 n_layers,
                 residual,
                 bottleneck_dim,
                 emb_dim,
                 vocab,
                 tie_embedding=False,
                 dropout=0.0,
                 dropout_emb=0.0,
                 lsm_prob=0.0,
                 ctc_weight=0.0,
                 ctc_lsm_prob=0.0,
                 ctc_fc_list=[],
                 lm_init=None,
                 lmobj_weight=0.0,
                 share_lm_softmax=False,
                 global_weight=1.0,
                 mtl_per_batch=False,
                 param_init=0.1,
                 start_pointing=False,
                 end_pointing=True):
        """RNN transducer: prediction network + joint network, optional CTC.

        Args:
            eos, unk, pad, blank (int): special token indices.
            enc_n_units (int): encoder output dimension.
            rnn_type (str): 'lstm_transducer' or 'gru_transducer'.
            n_units (int): hidden size of the prediction-network RNN.
            n_projs (int): projection size after each layer (0 disables it).
            n_layers (int): number of prediction-network layers.
            residual (bool): stored; with n_projs == 0 and not residual a
                single multi-layer RNN module is used.
            bottleneck_dim (int): joint-network hidden size.
            emb_dim (int): token embedding size.
            vocab (int): vocabulary size.
            tie_embedding (bool): accepted for interface compatibility; not
                used in this constructor.
            dropout, dropout_emb (float): dropout rates.
            lsm_prob, ctc_lsm_prob (float): presumably label-smoothing rates.
            ctc_weight (float): weight of the CTC objective (>0 builds CTC).
            ctc_fc_list (list of int): FC sizes before the CTC output (read only).
            lm_init: pre-trained LM whose matching, non-output parameters are
                copied into this module.
            lmobj_weight (float): weight of the auxiliary LM objective; >0
                builds ``output_lmobj``.
            share_lm_softmax (bool): make ``output_lmobj`` an alias of ``output``.
            global_weight (float): total weight of this task.
            mtl_per_batch (bool): stored as-is.
            param_init (float): parameter-initialization value.
            start_pointing, end_pointing (bool): stored VAD flags.

        Raises:
            AssertionError: if ``rnn_type`` is invalid or ``lm_init`` is
                incompatible with this module.
        """
        super(RNNTransducer, self).__init__()
        logger = logging.getLogger('training')

        self.eos = eos
        self.unk = unk
        self.pad = pad
        self.blank = blank
        self.vocab = vocab
        self.rnn_type = rnn_type
        assert rnn_type in ['lstm_transducer', 'gru_transducer']
        self.enc_n_units = enc_n_units
        self.dec_n_units = n_units
        self.n_projs = n_projs
        self.n_layers = n_layers
        self.residual = residual
        self.lsm_prob = lsm_prob
        self.ctc_weight = ctc_weight
        self.lmobj_weight = lmobj_weight
        self.share_lm_softmax = share_lm_softmax
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch

        # VAD
        self.start_pointing = start_pointing
        self.end_pointing = end_pointing

        # for cache
        self.prev_spk = ''
        self.lmstate_final = None
        self.state_cache = OrderedDict()

        if ctc_weight > 0:
            self.ctc = CTC(eos=eos,
                           blank=blank,
                           enc_n_units=enc_n_units,
                           vocab=vocab,
                           dropout=dropout,
                           lsm_prob=ctc_lsm_prob,
                           fc_list=ctc_fc_list,
                           param_init=param_init)

        if ctc_weight < global_weight:
            import warprnnt_pytorch
            self.warprnnt_loss = warprnnt_pytorch.RNNTLoss()

            # for MTL with LM objective (separate softmax). The shared case is
            # handled below, once `self.output` exists; creation of the separate
            # layer stays here to keep the original registration order.
            if lmobj_weight > 0 and not share_lm_softmax:
                self.output_lmobj = Linear(n_units, vocab)

            # Prediction network
            self.fast_impl = False
            rnn = nn.LSTM if rnn_type == 'lstm_transducer' else nn.GRU
            if n_projs == 0 and not residual:
                self.fast_impl = True
                self.rnn = rnn(emb_dim, n_units, n_layers,
                               bias=True,
                               batch_first=True,
                               dropout=dropout,
                               bidirectional=False)
                # NOTE: pytorch introduces a dropout layer on the outputs of each layer EXCEPT the last layer
                dec_idim = n_units
                self.dropout_top = nn.Dropout(p=dropout)
            else:
                self.rnn = nn.ModuleList()
                self.dropout = nn.ModuleList([nn.Dropout(p=dropout) for _ in range(n_layers)])
                if n_projs > 0:
                    # Projections act on RNN outputs (n_units wide). Previously
                    # this read `dec_idim` before it was assigned.
                    self.proj = nn.ModuleList([Linear(n_units, n_projs) for _ in range(n_layers)])
                dec_idim = emb_dim
                for l in range(n_layers):
                    self.rnn += [rnn(dec_idim, n_units, 1,
                                     bias=True,
                                     batch_first=True,
                                     dropout=0,
                                     bidirectional=False)]
                    dec_idim = n_projs if n_projs > 0 else n_units

            self.embed = Embedding(vocab, emb_dim,
                                   dropout=dropout_emb,
                                   ignore_index=pad)

            # Joint network
            self.w_enc = Linear(enc_n_units, bottleneck_dim, bias=True)
            self.w_dec = Linear(dec_idim, bottleneck_dim, bias=False)
            self.output = Linear(bottleneck_dim, vocab)

            if lmobj_weight > 0 and share_lm_softmax:
                # Alias only after `self.output` exists (previously an
                # AttributeError). NOTE(review): `output` expects bottleneck_dim
                # inputs — confirm the LM-objective features match that width.
                self.output_lmobj = self.output  # share paramters

        # Initialize parameters
        self.reset_parameters(param_init)

        # prediction network initialization with pre-trained LM
        if lm_init is not None:
            assert lm_init.vocab == vocab
            assert lm_init.n_units == n_units
            assert lm_init.n_projs == n_projs
            assert lm_init.n_layers == n_layers
            assert lm_init.residual == residual

            # Copy every same-named, same-shaped parameter except output layers.
            param_dict = dict(lm_init.named_parameters())
            for n, p in self.named_parameters():
                if n in param_dict and p.size() == param_dict[n].size():
                    if 'output' in n:
                        continue
                    p.data = param_dict[n].data
                    logger.info('Overwrite %s', n)
Exemplo n.º 11
0
    def __init__(self, args):

        super(ModelBase, self).__init__()

        # for encoder
        self.input_type = args.input_type
        self.input_dim = args.input_dim
        self.n_stacks = args.n_stacks
        self.n_skips = args.n_skips
        self.n_splices = args.n_splices
        self.enc_type = args.enc_type
        self.enc_n_units = args.enc_n_units
        if args.enc_type in ['blstm', 'bgru']:
            self.enc_n_units *= 2
        self.bridge_layer = args.bridge_layer

        # for OOV resolution
        self.enc_n_layers = args.enc_n_layers
        self.enc_n_layers_sub1 = args.enc_n_layers_sub1
        self.subsample = [int(s) for s in args.subsample.split('_')]

        # for attention layer
        self.attn_n_heads = args.attn_n_heads

        # for decoder
        self.vocab = args.vocab
        self.vocab_sub1 = args.vocab_sub1
        self.vocab_sub2 = args.vocab_sub2
        self.blank = 0
        self.unk = 1
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for the sub tasks
        self.main_weight = 1 - args.sub1_weight - args.sub2_weight
        self.sub1_weight = args.sub1_weight
        self.sub2_weight = args.sub2_weight
        self.mtl_per_batch = args.mtl_per_batch
        self.task_specific_layer = args.task_specific_layer

        # for CTC
        self.ctc_weight = min(args.ctc_weight, self.main_weight)
        self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight)
        self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight)

        # for backward decoder
        self.bwd_weight = min(args.bwd_weight, self.main_weight)
        self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight
        self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1
        self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2

        # Feature extraction
        self.ssn = None
        if args.sequence_summary_network:
            assert args.input_type == 'speech'
            self.ssn = SequenceSummaryNetwork(args.input_dim,
                                              n_units=512,
                                              n_layers=3,
                                              bottleneck_dim=100,
                                              dropout=0)

        # Encoder
        if args.enc_type == 'transformer':
            self.enc = TransformerEncoder(
                input_dim=args.input_dim
                if args.input_type == 'speech' else args.emb_dim,
                attn_type=args.transformer_attn_type,
                attn_n_heads=args.transformer_attn_n_heads,
                n_layers=args.transformer_enc_n_layers,
                d_model=args.d_model,
                d_ff=args.d_ff,
                # pe_type=args.pe_type,
                pe_type=False,
                dropout_in=args.dropout_in,
                dropout=args.dropout_enc,
                dropout_att=args.dropout_att,
                layer_norm_eps=args.layer_norm_eps,
                n_stacks=args.n_stacks,
                n_splices=args.n_splices,
                conv_in_channel=args.conv_in_channel,
                conv_channels=args.conv_channels,
                conv_kernel_sizes=args.conv_kernel_sizes,
                conv_strides=args.conv_strides,
                conv_poolings=args.conv_poolings,
                conv_batch_norm=args.conv_batch_norm,
                conv_residual=args.conv_residual,
                conv_bottleneck_dim=args.conv_bottleneck_dim)
        else:
            self.enc = RNNEncoder(
                input_dim=args.input_dim
                if args.input_type == 'speech' else args.emb_dim,
                rnn_type=args.enc_type,
                n_units=args.enc_n_units,
                n_projs=args.enc_n_projs,
                n_layers=args.enc_n_layers,
                n_layers_sub1=args.enc_n_layers_sub1,
                n_layers_sub2=args.enc_n_layers_sub2,
                dropout_in=args.dropout_in,
                dropout=args.dropout_enc,
                subsample=list(map(int, args.subsample.split('_'))) + [1] *
                (args.enc_n_layers - len(args.subsample.split('_'))),
                subsample_type=args.subsample_type,
                n_stacks=args.n_stacks,
                n_splices=args.n_splices,
                conv_in_channel=args.conv_in_channel,
                conv_channels=args.conv_channels,
                conv_kernel_sizes=args.conv_kernel_sizes,
                conv_strides=args.conv_strides,
                conv_poolings=args.conv_poolings,
                conv_batch_norm=args.conv_batch_norm,
                conv_residual=args.conv_residual,
                conv_bottleneck_dim=args.conv_bottleneck_dim,
                residual=args.enc_residual,
                nin=args.enc_nin,
                task_specific_layer=args.task_specific_layer)
            # NOTE: pure CNN/TDS encoders are also included

        if args.freeze_encoder:
            for p in self.enc.parameters():
                p.requires_grad = False

        # Bridge layer between the encoder and decoder
        self.is_bridge = False
        if (args.enc_type in ['conv', 'tds', 'gated_conv', 'transformer']
                and args.ctc_weight < 1
            ) or args.dec_type == 'transformer' or args.bridge_layer:
            self.bridge = LinearND(self.enc.output_dim,
                                   args.d_model if args.dec_type
                                   == 'transformer' else args.dec_n_units,
                                   dropout=args.dropout_enc)
            self.is_bridge = True
            if self.sub1_weight > 0:
                self.bridge_sub1 = LinearND(self.enc.output_dim,
                                            args.dec_n_units,
                                            dropout=args.dropout_enc)
            if self.sub2_weight > 0:
                self.bridge_sub2 = LinearND(self.enc.output_dim,
                                            args.dec_n_units,
                                            dropout=args.dropout_enc)
            self.enc_n_units = args.dec_n_units

        # main task
        directions = []
        if self.fwd_weight > 0 or self.ctc_weight > 0:
            directions.append('fwd')
        if self.bwd_weight > 0:
            directions.append('bwd')
        for dir in directions:
            # Cold fusion
            if args.lm_fusion and dir == 'fwd':
                lm = RNNLM(args.lm_conf)
                lm, _ = load_checkpoint(lm, args.lm_fusion)
            else:
                args.lm_conf = False
                lm = None
            # TODO(hirofumi): cold fusion for backward RNNLM

            # Decoder
            if args.dec_type == 'transformer':
                dec = TransformerDecoder(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    attn_type=args.transformer_attn_type,
                    attn_n_heads=args.transformer_attn_n_heads,
                    n_layers=args.transformer_dec_n_layers,
                    d_model=args.d_model,
                    d_ff=args.d_ff,
                    pe_type=args.pe_type,
                    tie_embedding=args.tie_embedding,
                    vocab=self.vocab,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    lsm_prob=args.lsm_prob,
                    layer_norm_eps=args.layer_norm_eps,
                    ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                    ctc_fc_list=[
                        int(fc) for fc in args.ctc_fc_list.split('_')
                    ] if args.ctc_fc_list is not None
                    and len(args.ctc_fc_list) > 0 else [],
                    backward=(dir == 'bwd'),
                    global_weight=self.main_weight -
                    self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                    mtl_per_batch=args.mtl_per_batch)
            else:
                dec = RNNDecoder(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    attn_type=args.attn_type,
                    attn_dim=args.attn_dim,
                    attn_sharpening_factor=args.attn_sharpening,
                    attn_sigmoid_smoothing=args.attn_sigmoid,
                    attn_conv_out_channels=args.attn_conv_n_channels,
                    attn_conv_kernel_size=args.attn_conv_width,
                    attn_n_heads=args.attn_n_heads,
                    rnn_type=args.dec_type,
                    n_units=args.dec_n_units,
                    n_projs=args.dec_n_projs,
                    n_layers=args.dec_n_layers,
                    loop_type=args.dec_loop_type,
                    residual=args.dec_residual,
                    bottleneck_dim=args.dec_bottleneck_dim,
                    emb_dim=args.emb_dim,
                    tie_embedding=args.tie_embedding,
                    vocab=self.vocab,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    ss_prob=args.ss_prob,
                    ss_type=args.ss_type,
                    lsm_prob=args.lsm_prob,
                    fl_weight=args.focal_loss_weight,
                    fl_gamma=args.focal_loss_gamma,
                    ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                    ctc_fc_list=[
                        int(fc) for fc in args.ctc_fc_list.split('_')
                    ] if args.ctc_fc_list is not None
                    and len(args.ctc_fc_list) > 0 else [],
                    input_feeding=args.input_feeding,
                    backward=(dir == 'bwd'),
                    # lm=args.lm_conf,
                    lm=lm,  # TODO(hirofumi): load RNNLM in the model init.
                    lm_fusion_type=args.lm_fusion_type,
                    contextualize=args.contextualize,
                    lm_init=args.lm_init,
                    lmobj_weight=args.lmobj_weight,
                    share_lm_softmax=args.share_lm_softmax,
                    global_weight=self.main_weight -
                    self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                    mtl_per_batch=args.mtl_per_batch,
                    adaptive_softmax=args.adaptive_softmax)
            setattr(self, 'dec_' + dir, dec)

        # sub task
        for sub in ['sub1', 'sub2']:
            if getattr(self, sub + '_weight') > 0:
                if args.dec_type == 'transformer':
                    raise NotImplementedError
                else:
                    dec_sub = RNNDecoder(
                        eos=self.eos,
                        unk=self.unk,
                        pad=self.pad,
                        blank=self.blank,
                        enc_n_units=self.enc_n_units,
                        attn_type=args.attn_type,
                        attn_dim=args.attn_dim,
                        attn_sharpening_factor=args.attn_sharpening,
                        attn_sigmoid_smoothing=args.attn_sigmoid,
                        attn_conv_out_channels=args.attn_conv_n_channels,
                        attn_conv_kernel_size=args.attn_conv_width,
                        attn_n_heads=1,
                        rnn_type=args.dec_type,
                        n_units=args.dec_n_units,
                        n_projs=args.dec_n_projs,
                        n_layers=args.dec_n_layers,
                        loop_type=args.dec_loop_type,
                        residual=args.dec_residual,
                        bottleneck_dim=args.dec_bottleneck_dim,
                        emb_dim=args.emb_dim,
                        tie_embedding=args.tie_embedding,
                        vocab=getattr(self, 'vocab_' + sub),
                        dropout=args.dropout_dec,
                        dropout_emb=args.dropout_emb,
                        dropout_att=args.dropout_att,
                        ss_prob=args.ss_prob,
                        ss_type=args.ss_type,
                        lsm_prob=args.lsm_prob,
                        fl_weight=args.focal_loss_weight,
                        fl_gamma=args.focal_loss_gamma,
                        ctc_weight=getattr(self, 'ctc_weight_' + sub),
                        ctc_fc_list=[
                            int(fc) for fc in getattr(args, 'ctc_fc_list_' +
                                                      sub).split('_')
                        ] if getattr(args, 'ctc_fc_list_' + sub) is not None
                        and len(getattr(args, 'ctc_fc_list_' + sub)) > 0 else
                        [],
                        input_feeding=args.input_feeding,
                        global_weight=getattr(self, sub + '_weight'),
                        mtl_per_batch=args.mtl_per_batch)
                setattr(self, 'dec_fwd_' + sub, dec_sub)

        if args.input_type == 'text':
            if args.vocab == args.vocab_sub1:
                # Share the embedding layer between input and output
                self.embed_in = dec.embed
            else:
                self.embed_in = Embedding(vocab=args.vocab_sub1,
                                          emb_dim=args.emb_dim,
                                          dropout=args.dropout_emb,
                                          ignore_index=self.pad)

        # Initialize parameters in CNN layers
        self.reset_parameters(
            args.param_init,
            #   dist='xavier_uniform',
            #   dist='kaiming_uniform',
            dist='lecun',
            keys=['conv'],
            ignore_keys=['score'])

        # Initialize parameters in the encoder
        if args.enc_type == 'transformer':
            self.reset_parameters(args.param_init,
                                  dist='xavier_uniform',
                                  keys=['enc'],
                                  ignore_keys=['embed_in'])
            self.reset_parameters(args.d_model**-0.5,
                                  dist='normal',
                                  keys=['embed_in'])
        else:
            self.reset_parameters(args.param_init,
                                  dist=args.param_init_dist,
                                  keys=['enc'],
                                  ignore_keys=['conv'])

        # Initialize parameters in the decoder
        if args.dec_type == 'transformer':
            self.reset_parameters(args.param_init,
                                  dist='xavier_uniform',
                                  keys=['dec'],
                                  ignore_keys=['embed'])
            self.reset_parameters(args.d_model**-0.5,
                                  dist='normal',
                                  keys=['embed'])
        else:
            self.reset_parameters(args.param_init,
                                  dist=args.param_init_dist,
                                  keys=['dec'])

        # Initialize bias vectors with zero
        self.reset_parameters(0, dist='constant', keys=['bias'])

        # Recurrent weights are orthogonalized
        if args.rec_weight_orthogonal:
            self.reset_parameters(args.param_init,
                                  dist='orthogonal',
                                  keys=['rnn', 'weight'])

        # Initialize bias in forget gate with 1
        # self.init_forget_gate_bias_with_one()

        # Initialize bias in gating with -1 for cold fusion
        if args.lm_fusion:
            self.reset_parameters(-1,
                                  dist='constant',
                                  keys=['linear_lm_gate.fc.bias'])

        if args.lm_fusion_type == 'deep' and args.lm_fusion:
            for n, p in self.named_parameters():
                if 'output' in n or 'output_bn' in n or 'linear' in n:
                    p.requires_grad = True
                else:
                    p.requires_grad = False