Example 1
from neural_sp.models.lm.gated_convlm import GatedConvLM
from neural_sp.models.lm.rnnlm import RNNLM
from neural_sp.models.lm.transformerlm import TransformerLM


def build_lm(args, save_path=None, wordlm=False, lm_dict_path=None, asr_dict_path=None):
    """Select an LM class.

    Args:
        args (argparse.Namespace): program arguments
        save_path (str): directory in which to save the model
        wordlm (bool): whether to wrap the LM as a word-level LM
        lm_dict_path (str): path to the LM-level dictionary
        asr_dict_path (str): path to the ASR-level dictionary
    Returns:
        lm: the selected language model

    """
    if 'gated_conv' in args.lm_type:
        lm = GatedConvLM(args, save_path)
    elif args.lm_type == 'transformer':
        lm = TransformerLM(args, save_path)
    else:
        lm = RNNLM(args, save_path)

        # Word-level RNNLM
        # if wordlm:
        #     lm = LookAheadWordLM(lm, lm_dict_path, asr_dict_path)

    return lm
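
The function above is a plain string-dispatched factory: it inspects `args.lm_type` and returns an instance of the matching LM class. As a self-contained illustration of the same pattern (stub classes stand in for the real LMs, which need a fully populated `args`), consider this sketch:

from types import SimpleNamespace


class StubLM:
    """Minimal stand-in for GatedConvLM / TransformerLM / RNNLM."""

    def __init__(self, args, save_path=None):
        self.save_path = save_path


class StubGatedConvLM(StubLM):
    pass


class StubTransformerLM(StubLM):
    pass


class StubRNNLM(StubLM):
    pass


def build_stub_lm(args, save_path=None):
    # Same dispatch rules as build_lm: substring match for gated conv,
    # exact match for 'transformer', RNNLM as the fallback.
    if 'gated_conv' in args.lm_type:
        return StubGatedConvLM(args, save_path)
    elif args.lm_type == 'transformer':
        return StubTransformerLM(args, save_path)
    return StubRNNLM(args, save_path)


assert isinstance(build_stub_lm(SimpleNamespace(lm_type='lstm')), StubRNNLM)
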
Example 2
def build_lm(args, save_path=None, wordlm=False, lm_dict_path=None, asr_dict_path=None):
    """Select an LM class.

    Args:
        args (argparse.Namespace): program arguments
        save_path (str): directory in which to save the model
        wordlm (bool): whether to wrap the LM as a word-level LM
        lm_dict_path (str): path to the LM-level dictionary
        asr_dict_path (str): path to the ASR-level dictionary
    Returns:
        lm: the selected language model

    """
    if 'gated_conv' in args.lm_type:
        from neural_sp.models.lm.gated_convlm import GatedConvLM
        lm = GatedConvLM(args, save_path)
    elif args.lm_type == 'transformer':
        from neural_sp.models.lm.transformerlm import TransformerLM
        lm = TransformerLM(args, save_path)
    elif args.lm_type == 'transformer_xl':
        from neural_sp.models.lm.transformer_xl import TransformerXL
        lm = TransformerXL(args, save_path)
    else:
        from neural_sp.models.lm.rnnlm import RNNLM
        lm = RNNLM(args, save_path)

    return lm
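
This second variant defers each import to the branch that selects it, so constructing one LM type never imports the modules behind the others; this keeps start-up light and avoids circular-import problems between model modules. Note also that `transformer_xl` is matched exactly while `gated_conv` is matched as a substring, presumably so that named gated-conv variants all reach the same class.
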
Example 3
from neural_sp.models.lm.gated_convlm import GatedConvLM
from neural_sp.models.lm.rnnlm import RNNLM
from neural_sp.models.lm.transformerlm import TransformerLM


def select_lm(args, save_path=None):
    if args.lm_type == 'gated_cnn':
        lm = GatedConvLM(args, save_path)
    elif args.lm_type == 'transformer':
        lm = TransformerLM(args, save_path)
    else:
        lm = RNNLM(args, save_path)
    return lm
Example 4
from neural_sp.models.lm.gated_convlm import GatedConvLM
from neural_sp.models.lm.rnnlm import RNNLM
from neural_sp.models.lm.transformerlm import TransformerLM


def build_lm(args, save_path=None):
    """Select an LM class.

    Args:
        args (argparse.Namespace): program arguments
        save_path (str): directory in which to save the model
    Returns:
        lm: the selected language model

    """
    if 'gated_conv' in args.lm_type:
        lm = GatedConvLM(args, save_path)
    elif args.lm_type == 'transformer':
        lm = TransformerLM(args, save_path)
    else:
        lm = RNNLM(args, save_path)
    return lm
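
An equivalent, table-driven way to express the same dispatch; this is only a sketch of an alternative layout, not how the original code is organized:

_LM_CLASSES = {
    'transformer': TransformerLM,
}


def build_lm_from_table(args, save_path=None):
    # Substring rule for gated-conv variants, exact lookup otherwise,
    # with RNNLM as the default -- mirrors the if/elif chain above.
    if 'gated_conv' in args.lm_type:
        cls = GatedConvLM
    else:
        cls = _LM_CLASSES.get(args.lm_type, RNNLM)
    return cls(args, save_path)
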
Example 5
    def __init__(self, args, save_path=None, idx2token=None):

        super(ModelBase, self).__init__()

        self.save_path = save_path

        # for encoder, decoder
        self.input_type = args.input_type
        self.input_dim = args.input_dim
        self.enc_type = args.enc_type
        self.enc_n_units = args.enc_n_units
        if args.enc_type in ['blstm', 'bgru', 'conv_blstm', 'conv_bgru']:
            self.enc_n_units *= 2
        self.dec_type = args.dec_type

        # for OOV resolution
        self.enc_n_layers = args.enc_n_layers
        self.enc_n_layers_sub1 = args.enc_n_layers_sub1
        self.subsample = [int(s) for s in args.subsample.split('_')]

        # for decoder
        self.vocab = args.vocab
        self.vocab_sub1 = args.vocab_sub1
        self.vocab_sub2 = args.vocab_sub2
        self.blank = 0
        self.unk = 1
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for the sub tasks
        self.main_weight = 1 - args.sub1_weight - args.sub2_weight
        self.sub1_weight = args.sub1_weight
        self.sub2_weight = args.sub2_weight
        self.mtl_per_batch = args.mtl_per_batch
        self.task_specific_layer = args.task_specific_layer

        # for CTC
        self.ctc_weight = min(args.ctc_weight, self.main_weight)
        self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight)
        self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight)

        # for backward decoder
        self.bwd_weight = min(args.bwd_weight, self.main_weight)
        self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight
        self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1
        self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2

        # for MBR
        self.mbr_training = args.mbr_training
        self.recog_params = vars(args)
        self.idx2token = idx2token

        # Feature extraction
        self.gaussian_noise = args.gaussian_noise
        self.n_stacks = args.n_stacks
        self.n_skips = args.n_skips
        self.n_splices = args.n_splices
        self.use_specaug = args.n_freq_masks > 0 or args.n_time_masks > 0
        self.specaug = None
        self.flip_time_prob = args.flip_time_prob
        self.flip_freq_prob = args.flip_freq_prob
        self.weight_noise = args.weight_noise
        if self.use_specaug:
            assert args.n_stacks == 1 and args.n_skips == 1
            assert args.n_splices == 1
            self.specaug = SpecAugment(F=args.freq_width,
                                       T=args.time_width,
                                       n_freq_masks=args.n_freq_masks,
                                       n_time_masks=args.n_time_masks,
                                       p=args.time_width_upper)

        # Frontend
        self.ssn = None
        if args.sequence_summary_network:
            assert args.input_type == 'speech'
            self.ssn = SequenceSummaryNetwork(args.input_dim,
                                              n_units=512,
                                              n_layers=3,
                                              bottleneck_dim=100,
                                              dropout=0,
                                              param_init=args.param_init)

        # Encoder
        self.enc = build_encoder(args)
        if args.freeze_encoder:
            for p in self.enc.parameters():
                p.requires_grad = False

        # main task
        external_lm = None
        directions = []
        if self.fwd_weight > 0 or (self.bwd_weight == 0
                                   and self.ctc_weight > 0):
            directions.append('fwd')
        if self.bwd_weight > 0:
            directions.append('bwd')
        for dir in directions:
            # Load the LM for LM fusion and decoder initialization
            if args.external_lm and dir == 'fwd':
                external_lm = RNNLM(args.lm_conf)
                load_checkpoint(external_lm, args.external_lm)
                # freeze LM parameters
                for n, p in external_lm.named_parameters():
                    p.requires_grad = False

            # Decoder
            special_symbols = {
                'blank': self.blank,
                'unk': self.unk,
                'eos': self.eos,
                'pad': self.pad,
            }
            dec = build_decoder(
                args, special_symbols, self.enc.output_dim, args.vocab,
                self.ctc_weight, args.ctc_fc_list, self.main_weight -
                self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                external_lm)
            setattr(self, 'dec_' + dir, dec)

        # sub task
        for sub in ['sub1', 'sub2']:
            if getattr(self, sub + '_weight') > 0:
                dec_sub = build_decoder(args, special_symbols,
                                        self.enc.output_dim,
                                        getattr(self, 'vocab_' + sub),
                                        getattr(self, 'ctc_weight_' + sub),
                                        getattr(args, 'ctc_fc_list_' + sub),
                                        getattr(self,
                                                sub + '_weight'), external_lm)
                setattr(self, 'dec_fwd_' + sub, dec_sub)

        if args.input_type == 'text':
            if args.vocab == args.vocab_sub1:
                # Share the embedding layer between input and output
                self.embed = dec.embed
            else:
                self.embed = nn.Embedding(args.vocab_sub1,
                                          args.emb_dim,
                                          padding_idx=self.pad)
                self.dropout_emb = nn.Dropout(p=args.dropout_emb)

        # Recurrent weights are orthogonalized
        if args.rec_weight_orthogonal:
            self.reset_parameters(args.param_init,
                                  dist='orthogonal',
                                  keys=['rnn', 'weight'])

        # Initialize bias in forget gate with 1
        # self.init_forget_gate_bias_with_one()

        # Fix all parameters except for the gating parts in deep fusion
        if args.lm_fusion == 'deep' and external_lm is not None:
            for n, p in self.named_parameters():
                if 'output' in n or 'output_bn' in n or 'linear' in n:
                    p.requires_grad = True
                else:
                    p.requires_grad = False
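
This constructor freezes parameters in two places: the whole encoder when `args.freeze_encoder` is set, and everything except the output/gating layers under deep fusion. A small, generic audit helper (a sketch for any torch module, not part of the original code) makes it easy to confirm what is still trainable:

import torch.nn as nn


def count_parameters(model: nn.Module):
    """Return (trainable, frozen) parameter counts for a module."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable, frozen
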
Example 6
    def __init__(self, args, save_path=None):

        super(ModelBase, self).__init__()

        self.save_path = save_path

        # for encoder, decoder
        self.input_type = args.input_type
        self.input_dim = args.input_dim
        self.enc_type = args.enc_type
        self.enc_n_units = args.enc_n_units
        if args.enc_type in ['blstm', 'bgru', 'conv_blstm', 'conv_bgru']:
            self.enc_n_units *= 2
        self.dec_type = args.dec_type

        # for OOV resolution
        self.enc_n_layers = args.enc_n_layers
        self.enc_n_layers_sub1 = args.enc_n_layers_sub1
        self.subsample = [int(s) for s in args.subsample.split('_')]

        # for decoder
        self.vocab = args.vocab
        self.vocab_sub1 = args.vocab_sub1
        self.vocab_sub2 = args.vocab_sub2
        self.blank = 0
        self.unk = 1
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for the sub tasks
        self.main_weight = 1 - args.sub1_weight - args.sub2_weight
        self.sub1_weight = args.sub1_weight
        self.sub2_weight = args.sub2_weight
        self.mtl_per_batch = args.mtl_per_batch
        self.task_specific_layer = args.task_specific_layer

        # for CTC
        self.ctc_weight = min(args.ctc_weight, self.main_weight)
        self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight)
        self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight)

        # for backward decoder
        self.bwd_weight = min(args.bwd_weight, self.main_weight)
        self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight
        self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1
        self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2

        # Feature extraction
        self.gaussian_noise = args.gaussian_noise
        self.n_stacks = args.n_stacks
        self.n_skips = args.n_skips
        self.n_splices = args.n_splices
        self.is_specaug = args.n_freq_masks > 0 or args.n_time_masks > 0
        self.specaug = None
        if self.is_specaug:
            assert args.n_stacks == 1 and args.n_skips == 1
            assert args.n_splices == 1
            self.specaug = SpecAugment(F=args.freq_width,
                                       T=args.time_width,
                                       n_freq_masks=args.n_freq_masks,
                                       n_time_masks=args.n_time_masks,
                                       p=args.time_width_upper)

        # Frontend
        self.ssn = None
        if args.sequence_summary_network:
            assert args.input_type == 'speech'
            self.ssn = SequenceSummaryNetwork(args.input_dim,
                                              n_units=512,
                                              n_layers=3,
                                              bottleneck_dim=100,
                                              dropout=0,
                                              param_init=args.param_init)

        # Encoder
        self.enc = select_encoder(args)
        if args.freeze_encoder:
            for p in self.enc.parameters():
                p.requires_grad = False

        # main task
        directions = []
        if self.fwd_weight > 0 or self.ctc_weight > 0:
            directions.append('fwd')
        if self.bwd_weight > 0:
            directions.append('bwd')
        for dir in directions:
            # Load the LM for LM fusion
            if args.lm_fusion and dir == 'fwd':
                lm_fusion = RNNLM(args.lm_conf)
                lm_fusion, _ = load_checkpoint(lm_fusion, args.lm_fusion)
            else:
                lm_fusion = None
                # TODO(hirofumi): for backward RNNLM

            # Load the LM for LM initialization
            if args.lm_init and dir == 'fwd':
                lm_init = RNNLM(args.lm_conf)
                lm_init, _ = load_checkpoint(lm_init, args.lm_init)
            else:
                lm_init = None
                # TODO(hirofumi): for backward RNNLM

            # Decoder
            if args.dec_type == 'transformer':
                dec = TransformerDecoder(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    attn_type=args.transformer_attn_type,
                    attn_n_heads=args.transformer_attn_n_heads,
                    n_layers=args.dec_n_layers,
                    d_model=args.d_model,
                    d_ff=args.d_ff,
                    vocab=self.vocab,
                    tie_embedding=args.tie_embedding,
                    pe_type=args.pe_type,
                    layer_norm_eps=args.layer_norm_eps,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    lsm_prob=args.lsm_prob,
                    focal_loss_weight=args.focal_loss_weight,
                    focal_loss_gamma=args.focal_loss_gamma,
                    ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                    ctc_lsm_prob=args.ctc_lsm_prob,
                    ctc_fc_list=[
                        int(fc) for fc in args.ctc_fc_list.split('_')
                    ] if args.ctc_fc_list is not None
                    and len(args.ctc_fc_list) > 0 else [],
                    backward=(dir == 'bwd'),
                    global_weight=self.main_weight -
                    self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                    mtl_per_batch=args.mtl_per_batch)
            elif 'transducer' in args.dec_type:
                dec = RNNTransducer(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    rnn_type=args.dec_type,
                    n_units=args.dec_n_units,
                    n_projs=args.dec_n_projs,
                    n_layers=args.dec_n_layers,
                    residual=args.dec_residual,
                    bottleneck_dim=args.dec_bottleneck_dim,
                    emb_dim=args.emb_dim,
                    vocab=self.vocab,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    lsm_prob=args.lsm_prob,
                    ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                    ctc_lsm_prob=args.ctc_lsm_prob,
                    ctc_fc_list=[
                        int(fc) for fc in args.ctc_fc_list.split('_')
                    ] if args.ctc_fc_list is not None
                    and len(args.ctc_fc_list) > 0 else [],
                    lm_init=lm_init,
                    lmobj_weight=args.lmobj_weight,
                    share_lm_softmax=args.share_lm_softmax,
                    global_weight=self.main_weight -
                    self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                    mtl_per_batch=args.mtl_per_batch,
                    param_init=args.param_init)
            else:
                dec = RNNDecoder(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    attn_type=args.attn_type,
                    attn_dim=args.attn_dim,
                    attn_sharpening_factor=args.attn_sharpening,
                    attn_sigmoid_smoothing=args.attn_sigmoid,
                    attn_conv_out_channels=args.attn_conv_n_channels,
                    attn_conv_kernel_size=args.attn_conv_width,
                    attn_n_heads=args.attn_n_heads,
                    rnn_type=args.dec_type,
                    n_units=args.dec_n_units,
                    n_projs=args.dec_n_projs,
                    n_layers=args.dec_n_layers,
                    loop_type=args.dec_loop_type,
                    residual=args.dec_residual,
                    bottleneck_dim=args.dec_bottleneck_dim,
                    emb_dim=args.emb_dim,
                    vocab=self.vocab,
                    tie_embedding=args.tie_embedding,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    zoneout=args.zoneout,
                    ss_prob=args.ss_prob,
                    ss_type=args.ss_type,
                    lsm_prob=args.lsm_prob,
                    focal_loss_weight=args.focal_loss_weight,
                    focal_loss_gamma=args.focal_loss_gamma,
                    ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                    ctc_lsm_prob=args.ctc_lsm_prob,
                    ctc_fc_list=[
                        int(fc) for fc in args.ctc_fc_list.split('_')
                    ] if args.ctc_fc_list is not None
                    and len(args.ctc_fc_list) > 0 else [],
                    input_feeding=args.input_feeding,
                    backward=(dir == 'bwd'),
                    lm_fusion=lm_fusion,
                    lm_fusion_type=args.lm_fusion_type,
                    discourse_aware=args.discourse_aware,
                    lm_init=lm_init,
                    lmobj_weight=args.lmobj_weight,
                    share_lm_softmax=args.share_lm_softmax,
                    global_weight=self.main_weight -
                    self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                    mtl_per_batch=args.mtl_per_batch,
                    adaptive_softmax=args.adaptive_softmax,
                    param_init=args.param_init,
                    replace_sos=args.replace_sos)
            setattr(self, 'dec_' + dir, dec)

        # sub task
        for sub in ['sub1', 'sub2']:
            if getattr(self, sub + '_weight') > 0:
                if args.dec_type == 'transformer':
                    raise NotImplementedError
                else:
                    dec_sub = RNNDecoder(
                        eos=self.eos,
                        unk=self.unk,
                        pad=self.pad,
                        blank=self.blank,
                        enc_n_units=self.enc_n_units,
                        attn_type=args.attn_type,
                        attn_dim=args.attn_dim,
                        attn_sharpening_factor=args.attn_sharpening,
                        attn_sigmoid_smoothing=args.attn_sigmoid,
                        attn_conv_out_channels=args.attn_conv_n_channels,
                        attn_conv_kernel_size=args.attn_conv_width,
                        attn_n_heads=1,
                        rnn_type=args.dec_type,
                        n_units=args.dec_n_units,
                        n_projs=args.dec_n_projs,
                        n_layers=args.dec_n_layers,
                        loop_type=args.dec_loop_type,
                        residual=args.dec_residual,
                        bottleneck_dim=args.dec_bottleneck_dim,
                        emb_dim=args.emb_dim,
                        tie_embedding=args.tie_embedding,
                        vocab=getattr(self, 'vocab_' + sub),
                        dropout=args.dropout_dec,
                        dropout_emb=args.dropout_emb,
                        dropout_att=args.dropout_att,
                        ss_prob=args.ss_prob,
                        ss_type=args.ss_type,
                        lsm_prob=args.lsm_prob,
                        focal_loss_weight=args.focal_loss_weight,
                        focal_loss_gamma=args.focal_loss_gamma,
                        ctc_weight=getattr(self, 'ctc_weight_' + sub),
                        ctc_lsm_prob=args.ctc_lsm_prob,
                        ctc_fc_list=[
                            int(fc) for fc in getattr(args, 'ctc_fc_list_' +
                                                      sub).split('_')
                        ] if getattr(args, 'ctc_fc_list_' + sub) is not None
                        and len(getattr(args, 'ctc_fc_list_' + sub)) > 0 else
                        [],
                        input_feeding=args.input_feeding,
                        global_weight=getattr(self, sub + '_weight'),
                        mtl_per_batch=args.mtl_per_batch,
                        param_init=args.param_init)
                setattr(self, 'dec_fwd_' + sub, dec_sub)

        if args.input_type == 'text':
            if args.vocab == args.vocab_sub1:
                # Share the embedding layer between input and output
                self.embed = dec.embed
            else:
                self.embed = Embedding(vocab=args.vocab_sub1,
                                       emb_dim=args.emb_dim,
                                       dropout=args.dropout_emb,
                                       ignore_index=self.pad)

        # Recurrent weights are orthogonalized
        if args.rec_weight_orthogonal:
            self.reset_parameters(args.param_init,
                                  dist='orthogonal',
                                  keys=['rnn', 'weight'])

        # Initialize bias in forget gate with 1
        # self.init_forget_gate_bias_with_one()

        # Fix all parameters except for the gating parts in deep fusion
        if args.lm_fusion_type == 'deep' and args.lm_fusion:
            for n, p in self.named_parameters():
                if 'output' in n or 'output_bn' in n or 'linear' in n:
                    p.requires_grad = True
                else:
                    p.requires_grad = False
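
The expression that turns `args.ctc_fc_list` (an underscore-joined string such as '512_512') into a list of ints is repeated verbatim for every decoder above. A small helper, shown only as a possible refactoring, states the intent once:

def parse_fc_list(fc_list_str):
    """Parse '512_512' into [512, 512]; None or '' yields []."""
    if not fc_list_str:
        return []
    return [int(fc) for fc in fc_list_str.split('_')]
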
Example 7
def main():

    args = parse()

    # Load a conf file
    dir_name = os.path.dirname(args.recog_model[0])
    conf = load_config(os.path.join(dir_name, 'conf.yml'))

    # Overwrite conf
    for k, v in conf.items():
        if 'recog' not in k:
            setattr(args, k, v)

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'decode.log')):
        os.remove(os.path.join(args.recog_dir, 'decode.log'))
    logger = set_logger(os.path.join(args.recog_dir, 'decode.log'),
                        key='decoding')

    ppl_avg = 0
    for i, s in enumerate(args.recog_sets):
        # Load dataset
        dataset = Dataset(corpus=args.corpus,
                          tsv_path=s,
                          dict_path=os.path.join(dir_name, 'dict.txt'),
                          wp_model=os.path.join(dir_name, 'wp.model'),
                          unit=args.unit,
                          batch_size=args.recog_batch_size,
                          bptt=args.bptt,
                          serialize=args.serialize,
                          is_test=True)

        if i == 0:
            # Load the LM
            if args.lm_type == 'gated_cnn':
                model = GatedConvLM(args)
            else:
                model = RNNLM(args)
            model, checkpoint = load_checkpoint(model, args.recog_model[0])
            epoch = checkpoint['epoch']
            model.save_path = dir_name

            logger.info('epoch: %d' % (epoch - 1))
            logger.info('batch size: %d' % args.recog_batch_size)
            # logger.info('recog unit: %s' % args.recog_unit)
            # logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('BPTT: %d' % (args.bptt))
            logger.info('cache size: %d' % (args.recog_n_caches))
            logger.info('cache theta: %.3f' % (args.recog_cache_theta))
            logger.info('cache lambda: %.3f' % (args.recog_cache_lambda))
            model.cache_theta = args.recog_cache_theta
            model.cache_lambda = args.recog_cache_lambda

            # GPU setting
            model.cuda()

        start_time = time.time()

        # TODO(hirofumi): ensemble
        ppl, _ = eval_ppl([model],
                          dataset,
                          batch_size=1,
                          bptt=args.bptt,
                          n_caches=args.recog_n_caches,
                          progressbar=True)
        ppl_avg += ppl
        print('PPL (%s): %.2f' % (dataset.set, ppl))
        logger.info('Elapsed time: %.2f [sec]' % (time.time() - start_time))

    logger.info('PPL (avg.): %.2f\n' % (ppl_avg / len(args.recog_sets)))
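
The perplexity reported by `eval_ppl` is related to the mean per-token cross-entropy loss by PPL = exp(loss), the same conversion the training script below applies with `np.exp` when logging. A quick check of the relation:

import math

avg_nll = 4.0            # mean negative log-likelihood per token, in nats
ppl = math.exp(avg_nll)  # perplexity, roughly 54.6
assert abs(ppl - math.e ** 4.0) < 1e-9
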
Example 8
def main():

    args = parse()

    # Load a conf file
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=args.batch_size * args.n_gpus,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=args.batch_size * args.n_gpus,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [
            Dataset(corpus=args.corpus,
                    tsv_path=s,
                    dict_path=args.dict,
                    nlsyms=args.nlsyms,
                    unit=args.unit,
                    wp_model=args.wp_model,
                    batch_size=1,
                    bptt=args.bptt,
                    backward=args.backward,
                    serialize=args.serialize)
        ]

    args.vocab = train_set.vocab

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = make_model_name(args)
        save_path = mkdir_join(
            args.model,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    logger = set_logger(os.path.join(save_path, 'train.log'), key='training')

    # Model setting
    if 'gated_conv' in args.lm_type:
        model = GatedConvLM(args)
    else:
        model = RNNLM(args)
    model.save_path = save_path

    if args.resume:
        # Set optimizer
        epoch = int(args.resume.split('-')[-1])
        model.set_optimizer(
            optimizer='sgd'
            if epoch > conf['convert_to_sgd_epoch'] + 1 else conf['optimizer'],
            learning_rate=float(conf['learning_rate']),  # on-the-fly
            weight_decay=float(conf['weight_decay']))

        # Restore the last saved model
        model, checkpoint = load_checkpoint(model, args.resume, resume=True)
        lr_controller = checkpoint['lr_controller']
        epoch = checkpoint['epoch']
        step = checkpoint['step']
        ppl_dev_best = checkpoint['metric_dev_best']

        # Resume between convert_to_sgd_epoch and convert_to_sgd_epoch + 1
        if epoch == conf['convert_to_sgd_epoch'] + 1:
            model.set_optimizer(optimizer='sgd',
                                learning_rate=args.learning_rate,
                                weight_decay=float(conf['weight_decay']))
            logger.info('========== Convert to SGD ==========')
    else:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(model.save_path, 'conf.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(model.save_path,
                                                  'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(model.save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(model.save_path,
                                                    'wp.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            nparams = model.num_params_dict[n]
            logger.info("%s %d" % (n, nparams))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

        # Set optimizer
        model.set_optimizer(optimizer=args.optimizer,
                            learning_rate=float(args.learning_rate),
                            weight_decay=float(args.weight_decay))

        epoch, step = 1, 1
        ppl_dev_best = 10000

        # Set learning rate controller
        lr_controller = Controller(
            learning_rate=float(args.learning_rate),
            decay_type=args.decay_type,
            decay_start_epoch=args.decay_start_epoch,
            decay_rate=args.decay_rate,
            decay_patient_n_epochs=args.decay_patient_n_epochs,
            lower_better=True,
            best_value=ppl_dev_best)

    train_set.epoch = epoch - 1  # start from index:0

    # GPU setting
    if args.n_gpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus, 1)),
                                   deterministic=False,
                                   benchmark=True)
        model.cuda()

    logger.info('PID: %s' % os.getpid())
    logger.info('HOSTNAME: %s' % os.uname()[1])

    # Set process name
    if args.job_name:
        setproctitle(args.job_name)
    else:
        setproctitle(dir_name)

    # Set reporter
    reporter = Reporter(model.module.save_path, tensorboard=True)

    hidden = None
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0
    pbar_epoch = tqdm(total=len(train_set))
    while True:
        # Compute loss in the training set
        ys_train, is_new_epoch = train_set.next()

        model.module.optimizer.zero_grad()
        loss, hidden, reporter = model(ys_train, hidden, reporter)
        if len(model.device_ids) > 1:
            loss.backward(torch.ones(len(model.device_ids)))
        else:
            loss.backward()
        loss.detach()  # Truncate the graph
        if args.clip_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.module.parameters(),
                                           args.clip_grad_norm)
        model.module.optimizer.step()
        loss_train = loss.item()
        del loss
        if 'gated_conv' not in args.lm_type:
            hidden = model.module.repackage_hidden(hidden)
        reporter.step(is_eval=False)

        if step % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next()[0]
            loss, _, reporter = model(ys_dev, None, reporter, is_eval=True)
            loss_dev = loss.item()
            del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/ppl:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)"
                % (step, train_set.epoch_detail, loss_train, loss_dev,
                   np.exp(loss_train), np.exp(loss_dev), lr_controller.lr,
                   ys_train.shape[0], duration_step / 60))
            start_time_step = time.time()
        step += args.n_gpus
        pbar_epoch.update(ys_train.shape[0] * (ys_train.shape[1] - 1))

        # Save figures of loss and accuracy
        if step % (args.print_step * 10) == 0:
            reporter.snapshot()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (epoch, duration_epoch / 60))

            if epoch < args.eval_start_epoch:
                # Save the model
                save_checkpoint(model.module,
                                model.module.save_path,
                                lr_controller,
                                epoch,
                                step - 1,
                                ppl_dev_best,
                                remove_old_checkpoints=True)
            else:
                start_time_eval = time.time()
                # dev
                ppl_dev, _ = eval_ppl([model.module],
                                      dev_set,
                                      batch_size=1,
                                      bptt=args.bptt)
                logger.info('PPL (%s): %.2f' % (dev_set.set, ppl_dev))

                # Update learning rate
                model.module.optimizer = lr_controller.decay(
                    model.module.optimizer, epoch=epoch, value=ppl_dev)

                if ppl_dev < ppl_dev_best:
                    ppl_dev_best = ppl_dev
                    not_improved_epoch = 0
                    logger.info('||||| Best Score |||||')

                    # Save the model
                    save_checkpoint(model.module,
                                    model.module.save_path,
                                    lr_controller,
                                    epoch,
                                    step - 1,
                                    ppl_dev_best,
                                    remove_old_checkpoints=True)

                    # test
                    ppl_test_avg = 0.
                    for eval_set in eval_sets:
                        ppl_test, _ = eval_ppl([model.module],
                                               eval_set,
                                               batch_size=1,
                                               bptt=args.bptt)
                        logger.info('PPL (%s): %.2f' %
                                    (eval_set.set, ppl_test))
                        ppl_test_avg += ppl_test
                    if len(eval_sets) > 0:
                        logger.info('PPL (avg.): %.2f' %
                                    (ppl_test_avg / len(eval_sets)))
                else:
                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == args.not_improved_patient_n_epochs:
                    break

                # Convert to fine-tuning stage
                if epoch == args.convert_to_sgd_epoch:
                    model.module.set_optimizer(
                        'sgd',
                        learning_rate=args.learning_rate,
                        weight_decay=float(args.weight_decay))
                    lr_controller = Controller(
                        learning_rate=args.learning_rate,
                        decay_type='epoch',
                        decay_start_epoch=epoch,
                        decay_rate=0.5,
                        lower_better=True)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))

            if epoch == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    if reporter.tensorboard:
        reporter.tf_writer.close()
    pbar_epoch.close()

    return model.module.save_path
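
The call to `model.module.repackage_hidden(hidden)` in the training loop implements truncated BPTT: the RNN state is carried across batches by value but cut off from the previous computation graph, so gradients stop at the batch boundary. A typical implementation of such a method looks like the following (a common PyTorch pattern, not necessarily the exact code in this repository):

import torch


def repackage_hidden(h):
    """Detach hidden states from their history (truncated BPTT)."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    # LSTM states arrive as (h, c) tuples; recurse over nested containers.
    return tuple(repackage_hidden(v) for v in h)
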
Example 9
def main():

    args = parse()

    # Load a conf file
    dir_name = os.path.dirname(args.recog_model[0])
    conf = load_config(os.path.join(dir_name, 'conf.yml'))

    # Overwrite conf
    for k, v in conf.items():
        if 'recog' not in k:
            setattr(args, k, v)
    recog_params = vars(args)

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'decode.log')):
        os.remove(os.path.join(args.recog_dir, 'decode.log'))
    logger = set_logger(os.path.join(args.recog_dir, 'decode.log'),
                        key='decoding')

    skip_thought = 'skip' in args.enc_type

    wer_avg, cer_avg, per_avg = 0, 0, 0
    ppl_avg, loss_avg = 0, 0
    for i, s in enumerate(args.recog_sets):
        # Load dataset
        dataset = Dataset(
            corpus=args.corpus,
            tsv_path=s,
            dict_path=os.path.join(dir_name, 'dict.txt'),
            dict_path_sub1=os.path.join(dir_name, 'dict_sub1.txt') if
            os.path.isfile(os.path.join(dir_name, 'dict_sub1.txt')) else False,
            dict_path_sub2=os.path.join(dir_name, 'dict_sub2.txt') if
            os.path.isfile(os.path.join(dir_name, 'dict_sub2.txt')) else False,
            nlsyms=os.path.join(dir_name, 'nlsyms.txt'),
            wp_model=os.path.join(dir_name, 'wp.model'),
            wp_model_sub1=os.path.join(dir_name, 'wp_sub1.model'),
            wp_model_sub2=os.path.join(dir_name, 'wp_sub2.model'),
            unit=args.unit,
            unit_sub1=args.unit_sub1,
            unit_sub2=args.unit_sub2,
            batch_size=args.recog_batch_size,
            skip_thought=skip_thought,
            is_test=True)

        if i == 0:
            # Load the ASR model
            if skip_thought:
                model = SkipThought(args)
            else:
                model = Seq2seq(args)
            model, checkpoint = load_checkpoint(model, args.recog_model[0])
            epoch = checkpoint['epoch']
            model.save_path = dir_name

            # ensemble (different models)
            ensemble_models = [model]
            if len(args.recog_model) > 1:
                for recog_model_e in args.recog_model[1:]:
                    # Load a conf file
                    conf_e = load_config(
                        os.path.join(os.path.dirname(recog_model_e),
                                     'conf.yml'))

                    # Overwrite conf
                    args_e = copy.deepcopy(args)
                    for k, v in conf_e.items():
                        if 'recog' not in k:
                            setattr(args_e, k, v)

                    model_e = Seq2seq(args_e)
                    model_e, _ = load_checkpoint(model_e, recog_model_e)
                    model_e.cuda()
                    ensemble_models += [model_e]

            # For shallow fusion
            if not args.lm_fusion:
                if args.recog_lm is not None and args.recog_lm_weight > 0:
                    # Load a LM conf file
                    conf_lm = load_config(
                        os.path.join(os.path.dirname(args.recog_lm),
                                     'conf.yml'))

                    # Merge conf with args
                    args_lm = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm, k, v)

                    # Load the pre-trained LM
                    if args_lm.lm_type == 'gated_cnn':
                        lm = GatedConvLM(args_lm)
                    else:
                        lm = RNNLM(args_lm)
                    lm, _ = load_checkpoint(lm, args.recog_lm)
                    if args_lm.backward:
                        model.lm_bwd = lm
                    else:
                        model.lm_fwd = lm

                if args.recog_lm_bwd is not None and args.recog_lm_weight > 0 \
                        and (args.recog_fwd_bwd_attention or args.recog_reverse_lm_rescoring):
                    # Load a LM conf file
                    conf_lm = load_config(
                        os.path.join(args.recog_lm_bwd, 'conf.yml'))

                    # Merge conf with args
                    args_lm_bwd = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm_bwd, k, v)

                    # Load the pre-trained LM
                    if args_lm_bwd.lm_type == 'gated_cnn':
                        lm_bwd = GatedConvLM(args_lm_bwd)
                    else:
                        lm_bwd = RNNLM(args_lm_bwd)
                    lm_bwd, _ = load_checkpoint(lm_bwd, args.recog_lm_bwd)
                    model.lm_bwd = lm_bwd

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('recog metric: %s' % args.recog_metric)
            logger.info('recog oracle: %s' % args.recog_oracle)
            logger.info('epoch: %d' % (epoch - 1))
            logger.info('batch size: %d' % args.recog_batch_size)
            logger.info('beam width: %d' % args.recog_beam_width)
            logger.info('min length ratio: %.3f' % args.recog_min_len_ratio)
            logger.info('max length ratio: %.3f' % args.recog_max_len_ratio)
            logger.info('length penalty: %.3f' % args.recog_length_penalty)
            logger.info('coverage penalty: %.3f' % args.recog_coverage_penalty)
            logger.info('coverage threshold: %.3f' %
                        args.recog_coverage_threshold)
            logger.info('CTC weight: %.3f' % args.recog_ctc_weight)
            logger.info('LM path: %s' % args.recog_lm)
            logger.info('LM path (bwd): %s' % args.recog_lm_bwd)
            logger.info('LM weight: %.3f' % args.recog_lm_weight)
            logger.info('GNMT: %s' % args.recog_gnmt_decoding)
            logger.info('forward-backward attention: %s' %
                        args.recog_fwd_bwd_attention)
            logger.info('reverse LM rescoring: %s' %
                        args.recog_reverse_lm_rescoring)
            logger.info('resolving UNK: %s' % args.recog_resolving_unk)
            logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('ASR decoder state carry over: %s' %
                        (args.recog_asr_state_carry_over))
            logger.info('LM state carry over: %s' %
                        (args.recog_lm_state_carry_over))
            logger.info('cache size: %d' % (args.recog_n_caches))
            logger.info('cache type: %s' % (args.recog_cache_type))
            logger.info('cache word frequency threshold: %s' %
                        (args.recog_cache_word_freq))
            logger.info('cache theta (speech): %.3f' %
                        (args.recog_cache_theta_speech))
            logger.info('cache lambda (speech): %.3f' %
                        (args.recog_cache_lambda_speech))
            logger.info('cache theta (lm): %.3f' % (args.recog_cache_theta_lm))
            logger.info('cache lambda (lm): %.3f' %
                        (args.recog_cache_lambda_lm))

            # GPU setting
            model.cuda()

        start_time = time.time()

        if args.recog_metric == 'edit_distance':
            if args.recog_unit in ['word', 'word_char']:
                wer, cer, _ = eval_word(ensemble_models,
                                        dataset,
                                        recog_params,
                                        epoch=epoch - 1,
                                        recog_dir=args.recog_dir,
                                        progressbar=True)
                wer_avg += wer
                cer_avg += cer
            elif args.recog_unit == 'wp':
                wer, cer = eval_wordpiece(ensemble_models,
                                          dataset,
                                          recog_params,
                                          epoch=epoch - 1,
                                          recog_dir=args.recog_dir,
                                          progressbar=True)
                wer_avg += wer
                cer_avg += cer
            elif 'char' in args.recog_unit:
                wer, cer = eval_char(ensemble_models,
                                     dataset,
                                     recog_params,
                                     epoch=epoch - 1,
                                     recog_dir=args.recog_dir,
                                     progressbar=True,
                                     task_idx=0)
                #  task_idx=1 if args.recog_unit and 'char' in args.recog_unit else 0)
                wer_avg += wer
                cer_avg += cer
            elif 'phone' in args.recog_unit:
                per = eval_phone(ensemble_models,
                                 dataset,
                                 recog_params,
                                 epoch=epoch - 1,
                                 recog_dir=args.recog_dir,
                                 progressbar=True)[0]
                per_avg += per
            else:
                raise ValueError(args.recog_unit)
        elif args.recog_metric == 'acc':
            raise NotImplementedError
        elif args.recog_metric in ['ppl', 'loss']:
            ppl, loss = eval_ppl(ensemble_models,
                                 dataset,
                                 recog_params=recog_params,
                                 progressbar=True)
            ppl_avg += ppl
            loss_avg += loss
        elif args.recog_metric == 'bleu':
            raise NotImplementedError
        else:
            raise NotImplementedError
        logger.info('Elapsed time: %.2f [sec]' % (time.time() - start_time))

    if args.recog_metric == 'edit_distance':
        if 'phone' in args.recog_unit:
            logger.info('PER (avg.): %.2f %%\n' %
                        (per_avg / len(args.recog_sets)))
        else:
            logger.info('WER / CER (avg.): %.2f / %.2f %%\n' %
                        (wer_avg / len(args.recog_sets),
                         cer_avg / len(args.recog_sets)))
    elif args.recog_metric in ['ppl', 'loss']:
        logger.info('PPL (avg.): %.2f\n' % (ppl_avg / len(args.recog_sets)))
        print('PPL (avg.): %.2f' % (ppl_avg / len(args.recog_sets)))
        logger.info('Loss (avg.): %.2f\n' % (loss_avg / len(args.recog_sets)))
        print('Loss (avg.): %.2f' % (loss_avg / len(args.recog_sets)))
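
Shallow fusion, which this script wires up through `recog_lm` and `recog_lm_weight`, combines the two models at decode time as a log-linear interpolation, score(y) = log p_asr(y | x) + lm_weight * log p_lm(y). A minimal sketch under the assumption that both inputs are per-token log-probability vectors over the same vocabulary:

import numpy as np


def shallow_fusion_scores(log_p_asr, log_p_lm, lm_weight):
    """Combine per-token scores: log p_asr + lm_weight * log p_lm."""
    return np.asarray(log_p_asr) + lm_weight * np.asarray(log_p_lm)
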
Example 10
    def __init__(self, args, save_path=None, idx2token=None):

        super(ModelBase, self).__init__()

        self.save_path = save_path

        # for encoder, decoder
        self.input_type = args.input_type
        self.input_dim = args.input_dim
        self.enc_type = args.enc_type
        self.dec_type = args.dec_type

        # for OOV resolution
        self.enc_n_layers = args.enc_n_layers
        self.enc_n_layers_sub1 = args.enc_n_layers_sub1
        self.subsample = [int(s) for s in args.subsample.split('_')]

        # for decoder
        self.vocab = args.vocab
        self.vocab_sub1 = args.vocab_sub1
        self.vocab_sub2 = args.vocab_sub2
        self.blank = 0
        self.unk = 1
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for the sub tasks
        self.main_weight = args.total_weight - args.sub1_weight - args.sub2_weight
        self.sub1_weight = args.sub1_weight
        self.sub2_weight = args.sub2_weight
        self.mtl_per_batch = args.mtl_per_batch
        self.task_specific_layer = args.task_specific_layer

        # for CTC
        self.ctc_weight = min(args.ctc_weight, self.main_weight)
        self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight)
        self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight)

        # for backward decoder
        self.bwd_weight = min(args.bwd_weight, self.main_weight)
        self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight
        self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1
        self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2

        # for MBR
        self.mbr_training = args.mbr_training
        self.recog_params = vars(args)
        self.idx2token = idx2token

        # for discourse-aware model
        self.utt_id_prev = None

        # Feature extraction
        self.input_noise_std = args.input_noise_std
        self.n_stacks = args.n_stacks
        self.n_skips = args.n_skips
        self.n_splices = args.n_splices
        self.weight_noise_std = args.weight_noise_std
        self.specaug = None
        if args.n_freq_masks > 0 or args.n_time_masks > 0:
            assert args.n_stacks == 1 and args.n_skips == 1
            assert args.n_splices == 1
            self.specaug = SpecAugment(
                F=args.freq_width,
                T=args.time_width,
                n_freq_masks=args.n_freq_masks,
                n_time_masks=args.n_time_masks,
                p=args.time_width_upper,
                adaptive_number_ratio=args.adaptive_number_ratio,
                adaptive_size_ratio=args.adaptive_size_ratio,
                max_n_time_masks=args.max_n_time_masks)

        # Frontend
        self.ssn = None
        if args.sequence_summary_network:
            assert args.input_type == 'speech'
            self.ssn = SequenceSummaryNetwork(args.input_dim,
                                              n_units=512,
                                              n_layers=3,
                                              bottleneck_dim=100,
                                              dropout=0,
                                              param_init=args.param_init)

        # Encoder
        self.enc = build_encoder(args)
        if args.freeze_encoder:
            for n, p in self.enc.named_parameters():
                if 'bridge' in n or 'sub1' in n:
                    continue
                p.requires_grad = False
                logger.info('freeze %s' % n)

        special_symbols = {
            'blank': self.blank,
            'unk': self.unk,
            'eos': self.eos,
            'pad': self.pad,
        }

        # main task
        external_lm = None
        directions = []
        if self.fwd_weight > 0 or (self.bwd_weight == 0
                                   and self.ctc_weight > 0):
            directions.append('fwd')
        if self.bwd_weight > 0:
            directions.append('bwd')

        for dir in directions:
            # Load the LM for LM fusion and decoder initialization
            if args.external_lm and dir == 'fwd':
                external_lm = RNNLM(args.lm_conf)
                load_checkpoint(args.external_lm, external_lm)
                # freeze LM parameters
                for n, p in external_lm.named_parameters():
                    p.requires_grad = False

            # Decoder
            dec = build_decoder(
                args, special_symbols, self.enc.output_dim, args.vocab,
                self.ctc_weight, self.main_weight -
                self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                external_lm)
            setattr(self, 'dec_' + dir, dec)

        # sub task
        for sub in ['sub1', 'sub2']:
            if getattr(self, sub + '_weight') > 0:
                args_sub = copy.deepcopy(args)
                if hasattr(args, 'dec_config_' + sub):
                    for k, v in getattr(args, 'dec_config_' + sub).items():
                        setattr(args_sub, k, v)
                # NOTE: Other parameters are the same as the main decoder
                dec_sub = build_decoder(args_sub, special_symbols,
                                        getattr(self.enc, 'output_dim_' + sub),
                                        getattr(self, 'vocab_' + sub),
                                        getattr(self, 'ctc_weight_' + sub),
                                        getattr(self, sub + '_weight'),
                                        external_lm)
                setattr(self, 'dec_fwd_' + sub, dec_sub)

        if args.input_type == 'text':
            if args.vocab == args.vocab_sub1:
                # Share the embedding layer between input and output
                self.embed = dec.embed
            else:
                self.embed = nn.Embedding(args.vocab_sub1,
                                          args.emb_dim,
                                          padding_idx=self.pad)
                self.dropout_emb = nn.Dropout(p=args.dropout_emb)

        # Initialize bias in forget gate with 1
        # self.init_forget_gate_bias_with_one()

        # Fix all parameters except for the gating parts in deep fusion
        if args.lm_fusion == 'deep' and external_lm is not None:
            for n, p in self.named_parameters():
                if 'output' in n or 'output_bn' in n or 'linear' in n:
                    p.requires_grad = True
                else:
                    p.requires_grad = False
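
Every constructor in these examples reserves the first four vocabulary indices for special symbols (blank=0, unk=1, eos=2, pad=3; see the "reserved in advance" notes). Any token-to-index mapping therefore has to start real tokens at index 4. A sketch, with the symbol spellings chosen here purely for illustration:

SPECIAL_SYMBOLS = ['<blank>', '<unk>', '<eos>', '<pad>']  # indices 0-3, reserved


def build_token2idx(tokens):
    """Map tokens to indices, keeping the first four slots for special symbols."""
    vocab = SPECIAL_SYMBOLS + [t for t in tokens if t not in SPECIAL_SYMBOLS]
    return {token: idx for idx, token in enumerate(vocab)}
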
Example 11
    def __init__(self, args):

        super(ModelBase, self).__init__()

        # for encoder
        self.input_type = args.input_type
        self.input_dim = args.input_dim
        self.n_stacks = args.n_stacks
        self.n_skips = args.n_skips
        self.n_splices = args.n_splices
        self.enc_type = args.enc_type
        self.enc_n_units = args.enc_n_units
        if args.enc_type in ['blstm', 'bgru']:
            self.enc_n_units *= 2
        self.bridge_layer = args.bridge_layer

        # for OOV resolution
        self.enc_n_layers = args.enc_n_layers
        self.enc_n_layers_sub1 = args.enc_n_layers_sub1
        self.subsample = [int(s) for s in args.subsample.split('_')]

        # for attention layer
        self.attn_n_heads = args.attn_n_heads

        # for decoder
        self.vocab = args.vocab
        self.vocab_sub1 = args.vocab_sub1
        self.vocab_sub2 = args.vocab_sub2
        self.blank = 0
        self.unk = 1
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for the sub tasks
        self.main_weight = 1 - args.sub1_weight - args.sub2_weight
        self.sub1_weight = args.sub1_weight
        self.sub2_weight = args.sub2_weight
        self.mtl_per_batch = args.mtl_per_batch
        self.task_specific_layer = args.task_specific_layer

        # for CTC
        self.ctc_weight = min(args.ctc_weight, self.main_weight)
        self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight)
        self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight)

        # for backward decoder
        self.bwd_weight = min(args.bwd_weight, self.main_weight)
        self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight
        self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1
        self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2

        # Feature extraction
        self.ssn = None
        if args.sequence_summary_network:
            assert args.input_type == 'speech'
            self.ssn = SequenceSummaryNetwork(args.input_dim,
                                              n_units=512,
                                              n_layers=3,
                                              bottleneck_dim=100,
                                              dropout=0)

        # Encoder
        if args.enc_type == 'transformer':
            self.enc = TransformerEncoder(
                input_dim=args.input_dim
                if args.input_type == 'speech' else args.emb_dim,
                attn_type=args.transformer_attn_type,
                attn_n_heads=args.transformer_attn_n_heads,
                n_layers=args.transformer_enc_n_layers,
                d_model=args.d_model,
                d_ff=args.d_ff,
                # pe_type=args.pe_type,
                pe_type=False,
                dropout_in=args.dropout_in,
                dropout=args.dropout_enc,
                dropout_att=args.dropout_att,
                layer_norm_eps=args.layer_norm_eps,
                n_stacks=args.n_stacks,
                n_splices=args.n_splices,
                conv_in_channel=args.conv_in_channel,
                conv_channels=args.conv_channels,
                conv_kernel_sizes=args.conv_kernel_sizes,
                conv_strides=args.conv_strides,
                conv_poolings=args.conv_poolings,
                conv_batch_norm=args.conv_batch_norm,
                conv_residual=args.conv_residual,
                conv_bottleneck_dim=args.conv_bottleneck_dim)
        else:
            self.enc = RNNEncoder(
                input_dim=args.input_dim
                if args.input_type == 'speech' else args.emb_dim,
                rnn_type=args.enc_type,
                n_units=args.enc_n_units,
                n_projs=args.enc_n_projs,
                n_layers=args.enc_n_layers,
                n_layers_sub1=args.enc_n_layers_sub1,
                n_layers_sub2=args.enc_n_layers_sub2,
                dropout_in=args.dropout_in,
                dropout=args.dropout_enc,
                subsample=list(map(int, args.subsample.split('_'))) + [1] *
                (args.enc_n_layers - len(args.subsample.split('_'))),
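                # e.g., subsample='1_2_2' with enc_n_layers=5
                # yields [1, 2, 2, 1, 1] (worked example)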
                subsample_type=args.subsample_type,
                n_stacks=args.n_stacks,
                n_splices=args.n_splices,
                conv_in_channel=args.conv_in_channel,
                conv_channels=args.conv_channels,
                conv_kernel_sizes=args.conv_kernel_sizes,
                conv_strides=args.conv_strides,
                conv_poolings=args.conv_poolings,
                conv_batch_norm=args.conv_batch_norm,
                conv_residual=args.conv_residual,
                conv_bottleneck_dim=args.conv_bottleneck_dim,
                residual=args.enc_residual,
                nin=args.enc_nin,
                task_specific_layer=args.task_specific_layer)
            # NOTE: pure CNN/TDS encoders are also included

        if args.freeze_encoder:
            for p in self.enc.parameters():
                p.requires_grad = False

        # Bridge layer between the encoder and decoder
        self.is_bridge = False
        if (args.enc_type in ['conv', 'tds', 'gated_conv', 'transformer']
                and args.ctc_weight < 1
            ) or args.dec_type == 'transformer' or args.bridge_layer:
            self.bridge = LinearND(self.enc.output_dim,
                                   args.d_model if args.dec_type == 'transformer'
                                   else args.dec_n_units,
                                   dropout=args.dropout_enc)
            self.is_bridge = True
            if self.sub1_weight > 0:
                self.bridge_sub1 = LinearND(self.enc.output_dim,
                                            args.dec_n_units,
                                            dropout=args.dropout_enc)
            if self.sub2_weight > 0:
                self.bridge_sub2 = LinearND(self.enc.output_dim,
                                            args.dec_n_units,
                                            dropout=args.dropout_enc)
            self.enc_n_units = args.dec_n_units
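            # NOTE: sub-task decoders below read self.enc_n_units, which now
            # reflects the bridged (projected) dimensionality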

        # main task
        directions = []
        if self.fwd_weight > 0 or self.ctc_weight > 0:
            directions.append('fwd')
        if self.bwd_weight > 0:
            directions.append('bwd')
        for dir in directions:
            # Cold fusion
            if args.lm_fusion and dir == 'fwd':
                lm = RNNLM(args.lm_conf)
                lm, _ = load_checkpoint(lm, args.lm_fusion)
            else:
                args.lm_conf = False
                lm = None
            # TODO(hirofumi): cold fusion for backward RNNLM

            # Decoder
            if args.dec_type == 'transformer':
                dec = TransformerDecoder(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    attn_type=args.transformer_attn_type,
                    attn_n_heads=args.transformer_attn_n_heads,
                    n_layers=args.transformer_dec_n_layers,
                    d_model=args.d_model,
                    d_ff=args.d_ff,
                    pe_type=args.pe_type,
                    tie_embedding=args.tie_embedding,
                    vocab=self.vocab,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    lsm_prob=args.lsm_prob,
                    layer_norm_eps=args.layer_norm_eps,
                    ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                    ctc_fc_list=[
                        int(fc) for fc in args.ctc_fc_list.split('_')
                    ] if args.ctc_fc_list is not None
                    and len(args.ctc_fc_list) > 0 else [],
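                    # e.g., ctc_fc_list='512_512' -> [512, 512];
                    # None or '' -> [] (worked example)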
                    backward=(dir == 'bwd'),
                    global_weight=(self.main_weight - self.bwd_weight
                                   if dir == 'fwd' else self.bwd_weight),
                    mtl_per_batch=args.mtl_per_batch)
            else:
                dec = RNNDecoder(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    attn_type=args.attn_type,
                    attn_dim=args.attn_dim,
                    attn_sharpening_factor=args.attn_sharpening,
                    attn_sigmoid_smoothing=args.attn_sigmoid,
                    attn_conv_out_channels=args.attn_conv_n_channels,
                    attn_conv_kernel_size=args.attn_conv_width,
                    attn_n_heads=args.attn_n_heads,
                    rnn_type=args.dec_type,
                    n_units=args.dec_n_units,
                    n_projs=args.dec_n_projs,
                    n_layers=args.dec_n_layers,
                    loop_type=args.dec_loop_type,
                    residual=args.dec_residual,
                    bottleneck_dim=args.dec_bottleneck_dim,
                    emb_dim=args.emb_dim,
                    tie_embedding=args.tie_embedding,
                    vocab=self.vocab,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    ss_prob=args.ss_prob,
                    ss_type=args.ss_type,
                    lsm_prob=args.lsm_prob,
                    fl_weight=args.focal_loss_weight,
                    fl_gamma=args.focal_loss_gamma,
                    ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                    ctc_fc_list=[
                        int(fc) for fc in args.ctc_fc_list.split('_')
                    ] if args.ctc_fc_list is not None
                    and len(args.ctc_fc_list) > 0 else [],
                    input_feeding=args.input_feeding,
                    backward=(dir == 'bwd'),
                    # lm=args.lm_conf,
                    lm=lm,  # TODO(hirofumi): load RNNLM in the model init.
                    lm_fusion_type=args.lm_fusion_type,
                    contextualize=args.contextualize,
                    lm_init=args.lm_init,
                    lmobj_weight=args.lmobj_weight,
                    share_lm_softmax=args.share_lm_softmax,
                    global_weight=(self.main_weight - self.bwd_weight
                                   if dir == 'fwd' else self.bwd_weight),
                    mtl_per_batch=args.mtl_per_batch,
                    adaptive_softmax=args.adaptive_softmax)
            setattr(self, 'dec_' + dir, dec)

        # sub task
        for sub in ['sub1', 'sub2']:
            if getattr(self, sub + '_weight') > 0:
                if args.dec_type == 'transformer':
                    raise NotImplementedError
                else:
                    dec_sub = RNNDecoder(
                        eos=self.eos,
                        unk=self.unk,
                        pad=self.pad,
                        blank=self.blank,
                        enc_n_units=self.enc_n_units,
                        attn_type=args.attn_type,
                        attn_dim=args.attn_dim,
                        attn_sharpening_factor=args.attn_sharpening,
                        attn_sigmoid_smoothing=args.attn_sigmoid,
                        attn_conv_out_channels=args.attn_conv_n_channels,
                        attn_conv_kernel_size=args.attn_conv_width,
                        attn_n_heads=1,
                        rnn_type=args.dec_type,
                        n_units=args.dec_n_units,
                        n_projs=args.dec_n_projs,
                        n_layers=args.dec_n_layers,
                        loop_type=args.dec_loop_type,
                        residual=args.dec_residual,
                        bottleneck_dim=args.dec_bottleneck_dim,
                        emb_dim=args.emb_dim,
                        tie_embedding=args.tie_embedding,
                        vocab=getattr(self, 'vocab_' + sub),
                        dropout=args.dropout_dec,
                        dropout_emb=args.dropout_emb,
                        dropout_att=args.dropout_att,
                        ss_prob=args.ss_prob,
                        ss_type=args.ss_type,
                        lsm_prob=args.lsm_prob,
                        fl_weight=args.focal_loss_weight,
                        fl_gamma=args.focal_loss_gamma,
                        ctc_weight=getattr(self, 'ctc_weight_' + sub),
                        ctc_fc_list=[
                            int(fc) for fc in getattr(args, 'ctc_fc_list_' +
                                                      sub).split('_')
                        ] if getattr(args, 'ctc_fc_list_' + sub) is not None
                        and len(getattr(args, 'ctc_fc_list_' + sub)) > 0 else
                        [],
                        input_feeding=args.input_feeding,
                        global_weight=getattr(self, sub + '_weight'),
                        mtl_per_batch=args.mtl_per_batch)
                setattr(self, 'dec_fwd_' + sub, dec_sub)

        if args.input_type == 'text':
            if args.vocab == args.vocab_sub1:
                # Share the embedding layer between input and output
                self.embed_in = dec.embed
            else:
                self.embed_in = Embedding(vocab=args.vocab_sub1,
                                          emb_dim=args.emb_dim,
                                          dropout=args.dropout_emb,
                                          ignore_index=self.pad)

        # Initialize parameters in CNN layers
        self.reset_parameters(
            args.param_init,
            #   dist='xavier_uniform',
            #   dist='kaiming_uniform',
            dist='lecun',
            keys=['conv'],
            ignore_keys=['score'])

        # Initialize parameters in the encoder
        if args.enc_type == 'transformer':
            self.reset_parameters(args.param_init,
                                  dist='xavier_uniform',
                                  keys=['enc'],
                                  ignore_keys=['embed_in'])
            self.reset_parameters(args.d_model**-0.5,
                                  dist='normal',
                                  keys=['embed_in'])
        else:
            self.reset_parameters(args.param_init,
                                  dist=args.param_init_dist,
                                  keys=['enc'],
                                  ignore_keys=['conv'])

        # Initialize parameters in the decoder
        if args.dec_type == 'transformer':
            self.reset_parameters(args.param_init,
                                  dist='xavier_uniform',
                                  keys=['dec'],
                                  ignore_keys=['embed'])
            self.reset_parameters(args.d_model**-0.5,
                                  dist='normal',
                                  keys=['embed'])
        else:
            self.reset_parameters(args.param_init,
                                  dist=args.param_init_dist,
                                  keys=['dec'])

        # Initialize bias vectors with zero
        self.reset_parameters(0, dist='constant', keys=['bias'])

        # Recurrent weights are orthogonalized
        if args.rec_weight_orthogonal:
            self.reset_parameters(args.param_init,
                                  dist='orthogonal',
                                  keys=['rnn', 'weight'])

        # Initialize bias in forget gate with 1
        # self.init_forget_gate_bias_with_one()

        # Initialize bias in gating with -1 for cold fusion
        if args.lm_fusion:
            self.reset_parameters(-1,
                                  dist='constant',
                                  keys=['linear_lm_gate.fc.bias'])
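            # NOTE: sigmoid(-1) ~= 0.27, so the LM gate starts mostly closed
            # and the fused LM contribution is damped early in training
            # (interpretation, not stated in the original code)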

        if args.lm_fusion_type == 'deep' and args.lm_fusion:
            for n, p in self.named_parameters():
                if 'output' in n or 'output_bn' in n or 'linear' in n:
                    p.requires_grad = True
                else:
                    p.requires_grad = False
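
Each direction's decoder is attached with setattr(self, 'dec_' + dir, dec); on an nn.Module, setattr registers the child just like a normal attribute assignment. A toy sketch of that registration pattern (TinyModel is illustrative, not part of neural_sp):

import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self, directions=('fwd', 'bwd')):
        super().__init__()
        for d in directions:
            # registered as the submodules 'dec_fwd' and 'dec_bwd'
            setattr(self, 'dec_' + d, nn.Linear(4, 4))

print([n for n, _ in TinyModel().named_children()])  # ['dec_fwd', 'dec_bwd']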
Example No. 12
def main():

    args = parse()

    # Load a conf file
    dir_name = os.path.dirname(args.recog_model[0])
    conf = load_config(os.path.join(dir_name, 'conf.yml'))

    # Overwrite conf
    for k, v in conf.items():
        if 'recog' not in k:
            setattr(args, k, v)
    recog_params = vars(args)

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    logger = set_logger(os.path.join(args.recog_dir, 'plot.log'), key='decoding')

    for i, s in enumerate(args.recog_sets):
        # Load dataset
        dataset = Dataset(corpus=args.corpus,
                          tsv_path=s,
                          dict_path=os.path.join(dir_name, 'dict.txt'),
                          dict_path_sub1=os.path.join(dir_name, 'dict_sub1.txt') if os.path.isfile(
                              os.path.join(dir_name, 'dict_sub1.txt')) else False,
                          nlsyms=args.nlsyms,
                          wp_model=os.path.join(dir_name, 'wp.model'),
                          unit=args.unit,
                          unit_sub1=args.unit_sub1,
                          batch_size=args.recog_batch_size,
                          is_test=True)

        if i == 0:
            # Load the ASR model
            model = Seq2seq(args)
            model, checkpoint = load_checkpoint(model, args.recog_model[0])
            epoch = checkpoint['epoch']
            model.save_path = dir_name

            # ensemble (different models)
            ensemble_models = [model]
            if len(args.recog_model) > 1:
                for recog_model_e in args.recog_model[1:]:
                    # Load a conf file
                    conf_e = load_config(os.path.join(os.path.dirname(recog_model_e), 'conf.yml'))

                    # Overwrite conf
                    args_e = copy.deepcopy(args)
                    for k, v in conf_e.items():
                        if 'recog' not in k:
                            setattr(args_e, k, v)

                    model_e = Seq2seq(args_e)
                    model_e, _ = load_checkpoint(model_e, recog_model_e)
                    model_e.cuda()
                    ensemble_models += [model_e]

            # For shallow fusion
            if not args.lm_fusion:
                if args.recog_lm is not None and args.recog_lm_weight > 0:
                    # Load an LM conf file
                    conf_lm = load_config(os.path.join(os.path.dirname(args.recog_lm), 'conf.yml'))

                    # Merge conf with args
                    args_lm = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm, k, v)

                    # Load the pre-trained LM
                    if args_lm.lm_type == 'gated_cnn':
                        lm = GatedConvLM(args_lm)
                    else:
                        lm = RNNLM(args_lm)
                    lm, _ = load_checkpoint(lm, args.recog_lm)
                    if args_lm.backward:
                        model.lm_bwd = lm
                    else:
                        model.lm_fwd = lm

                if args.recog_lm_bwd is not None and args.recog_lm_weight > 0 and \
                        (args.recog_fwd_bwd_attention or args.recog_reverse_lm_rescoring):
                    # Load an LM conf file
                    conf_lm = load_config(os.path.join(
                        os.path.dirname(args.recog_lm_bwd), 'conf.yml'))

                    # Merge conf with args
                    args_lm_bwd = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm_bwd, k, v)

                    # Load the pre-trained LM
                    if args_lm_bwd.lm_type == 'gated_cnn':
                        lm_bwd = GatedConvLM(args_lm_bwd)
                    else:
                        lm_bwd = RNNLM(args_lm_bwd)
                    lm_bwd, _ = load_checkpoint(lm_bwd, args.recog_lm_bwd)
                    model.lm_bwd = lm_bwd

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('recog metric: %s' % args.recog_metric)
            logger.info('recog oracle: %s' % args.recog_oracle)
            logger.info('epoch: %d' % (epoch - 1))
            logger.info('batch size: %d' % args.recog_batch_size)
            logger.info('beam width: %d' % args.recog_beam_width)
            logger.info('min length ratio: %.3f' % args.recog_min_len_ratio)
            logger.info('max length ratio: %.3f' % args.recog_max_len_ratio)
            logger.info('length penalty: %.3f' % args.recog_length_penalty)
            logger.info('coverage penalty: %.3f' % args.recog_coverage_penalty)
            logger.info('coverage threshold: %.3f' % args.recog_coverage_threshold)
            logger.info('CTC weight: %.3f' % args.recog_ctc_weight)
            logger.info('LM path: %s' % args.recog_lm)
            logger.info('LM path (bwd): %s' % args.recog_lm_bwd)
            logger.info('LM weight: %.3f' % args.recog_lm_weight)
            logger.info('GNMT: %s' % args.recog_gnmt_decoding)
            logger.info('forward-backward attention: %s' % args.recog_fwd_bwd_attention)
            logger.info('reverse LM rescoring: %s' % args.recog_reverse_lm_rescoring)
            logger.info('resolving UNK: %s' % args.recog_resolving_unk)
            logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('ASR decoder state carry over: %s' % (args.recog_asr_state_carry_over))
            logger.info('LM state carry over: %s' % (args.recog_lm_state_carry_over))
            logger.info('cache size: %d' % (args.recog_n_caches))
            logger.info('cache type: %s' % (args.recog_cache_type))
            logger.info('cache word frequency threshold: %s' % (args.recog_cache_word_freq))
            logger.info('cache theta (speech): %.3f' % (args.recog_cache_theta_speech))
            logger.info('cache lambda (speech): %.3f' % (args.recog_cache_lambda_speech))
            logger.info('cache theta (lm): %.3f' % (args.recog_cache_theta_lm))
            logger.info('cache lambda (lm): %.3f' % (args.recog_cache_lambda_lm))

            # GPU setting
            model.cuda()
            # TODO(hirofumi): move this

        save_path = mkdir_join(args.recog_dir, 'att_weights')
        if args.recog_n_caches > 0:
            save_path_cache = mkdir_join(args.recog_dir, 'cache')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)
            if args.recog_n_caches > 0:
                shutil.rmtree(save_path_cache)
                os.mkdir(save_path_cache)

        while True:
            batch, is_new_epoch = dataset.next(recog_params['recog_batch_size'])
            best_hyps_id, aws, (cache_attn_hist, cache_id_hist) = model.decode(
                batch['xs'], recog_params, dataset.idx2token[0],
                exclude_eos=False,
                refs_id=batch['ys'],
                ensemble_models=ensemble_models[1:] if len(ensemble_models) > 1 else [],
                speakers=batch['sessions'] if dataset.corpus == 'swbd' else batch['speakers'])

            if model.bwd_weight > 0.5:
                # Reverse the order
                best_hyps_id = [hyp[::-1] for hyp in best_hyps_id]
                aws = [aw[::-1] for aw in aws]

            for b in range(len(batch['xs'])):
                tokens = dataset.idx2token[0](best_hyps_id[b], return_list=True)
                spk = batch['speakers'][b]

                plot_attention_weights(
                    aws[b][:len(tokens)],
                    tokens,
                    spectrogram=batch['xs'][b][:, :dataset.input_dim] if args.input_type == 'speech' else None,
                    save_path=mkdir_join(save_path, spk, batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8))

                if args.recog_n_caches > 0 and cache_id_hist is not None and cache_attn_hist is not None:
                    n_keys, n_queries = cache_attn_hist[0].shape
                    # mask = np.ones((n_keys, n_queries))
                    # for i in range(n_queries):
                    #     mask[:n_keys - i, -(i + 1)] = 0
                    mask = np.zeros((n_keys, n_queries))

                    plot_cache_weights(
                        cache_attn_hist[0],
                        keys=dataset.idx2token[0](cache_id_hist[-1], return_list=True),  # fifo
                        # keys=dataset.idx2token[0](cache_id_hist, return_list=True),  # dict
                        queries=tokens,
                        save_path=mkdir_join(save_path_cache, spk, batch['utt_ids'][b] + '.png'),
                        figsize=(40, 16),
                        mask=mask)

                if model.bwd_weight > 0.5:
                    hyp = ' '.join(tokens[::-1])
                else:
                    hyp = ' '.join(tokens)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % batch['text'][b].lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

            if is_new_epoch:
                break
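
Both recognition scripts start by merging the training-time conf.yml back into the parsed arguments while leaving recognition-time flags untouched. A minimal sketch of that merge, assuming load_config wraps a flat YAML mapping (the file contents shown are hypothetical):

import argparse

import yaml

args = argparse.Namespace(recog_beam_width=10)  # recog flag, must survive
with open('conf.yml') as f:
    conf = yaml.safe_load(f)  # e.g., {'enc_type': 'blstm', 'unit': 'wp'}
for k, v in conf.items():
    if 'recog' not in k:  # only training-time keys overwrite args
        setattr(args, k, v)
assert args.recog_beam_width == 10  # recog-time setting survived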
Example No. 13
def main():

    args = parse()

    # Load a conf file
    dir_name = os.path.dirname(args.recog_model[0])
    conf = load_config(os.path.join(dir_name, 'conf.yml'))

    # Overwrite conf
    for k, v in conf.items():
        if 'recog' not in k:
            setattr(args, k, v)

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    logger = set_logger(os.path.join(args.recog_dir, 'plot.log'),
                        key='decoding')

    for i, s in enumerate(args.recog_sets):
        # Load dataset
        dataset = Dataset(corpus=args.corpus,
                          tsv_path=s,
                          dict_path=os.path.join(dir_name, 'dict.txt'),
                          wp_model=os.path.join(dir_name, 'wp.model'),
                          unit=args.unit,
                          batch_size=args.recog_batch_size,
                          bptt=args.bptt,
                          serialize=args.serialize,
                          is_test=True)

        if i == 0:
            # Load the LM
            if args.lm_type == 'gated_cnn':
                model = GatedConvLM(args)
            else:
                model = RNNLM(args)
            epoch = model.load_checkpoint(args.recog_model[0])['epoch']
            model.save_path = dir_name

            logger.info('epoch: %d' % (epoch - 1))
            logger.info('batch size: %d' % args.recog_batch_size)
            # logger.info('recog unit: %s' % args.recog_unit)
            # logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('BPTT: %d' % (args.bptt))
            logger.info('cache size: %d' % (args.recog_n_caches))
            logger.info('cache theta: %.3f' % (args.recog_cache_theta))
            logger.info('cache lambda: %.3f' % (args.recog_cache_lambda))
            model.cache_theta = args.recog_cache_theta
            model.cache_lambda = args.recog_cache_lambda
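            # NOTE: presumably theta flattens/sharpens the cache scores and
            # lambda interpolates the cache distribution with the model's
            # softmax, as in continuous-cache LMs (Grave et al., 2017);
            # this reading is an assumption, not documented here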

            # GPU setting
            model.cuda()

        assert args.recog_n_caches > 0
        save_path = mkdir_join(args.recog_dir, 'cache')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        if args.unit == 'word':
            idx2token = dataset.idx2word
        elif args.unit == 'wp':
            idx2token = dataset.idx2wp
        elif args.unit == 'char':
            idx2token = dataset.idx2char
        elif args.unit == 'phone':
            idx2token = dataset.idx2phone
        else:
            raise NotImplementedError(args.unit)

        hidden = None
        fig_count = 0
        token_count = 0
        n_tokens = args.recog_n_caches
        while True:
            ys, is_new_epoch = dataset.next()

            for t in range(ys.shape[1] - 1):
                loss, hidden = model(ys[:, t:t + 2],
                                     hidden,
                                     is_eval=True,
                                     n_caches=args.recog_n_caches)[:2]

                if len(model.cache_attn) > 0:
                    if token_count == n_tokens:
                        tokens_keys = idx2token(
                            model.cache_ids[:args.recog_n_caches],
                            return_list=True)
                        tokens_query = idx2token(model.cache_ids[-n_tokens:],
                                                 return_list=True)

                        # Slide attention matrix
                        n_keys = len(tokens_keys)
                        n_queries = len(tokens_query)
                        cache_probs = np.zeros(
                            (n_keys, n_queries))  # `[n_keys, n_queries]`
                        mask = np.zeros((n_keys, n_queries))
                        for i, aw in enumerate(model.cache_attn[-n_tokens:]):
                            cache_probs[:(n_keys - n_queries + i + 1),
                                        i] = aw[0,
                                                -(n_keys - n_queries + i + 1):]
                            mask[(n_keys - n_queries + i + 1):, i] = 1

                        plot_cache_weights(cache_probs,
                                           keys=tokens_keys,
                                           queries=tokens_query,
                                           save_path=mkdir_join(
                                               save_path,
                                               str(fig_count) + '.png'),
                                           figsize=(40, 16),
                                           mask=mask)
                        token_count = 0
                        fig_count += 1
                    else:
                        token_count += 1

            if is_new_epoch:
                break
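
The sliding layout above fills column i of cache_probs with the attention over the first n_keys - n_queries + i + 1 cache slots and masks the remainder. A toy run of the same mask construction (sizes are illustrative):

import numpy as np

n_keys, n_queries = 5, 3
mask = np.zeros((n_keys, n_queries))
for i in range(n_queries):
    mask[(n_keys - n_queries + i + 1):, i] = 1  # cache slots not yet filled
print(mask.T)
# [[0. 0. 0. 1. 1.]
#  [0. 0. 0. 0. 1.]
#  [0. 0. 0. 0. 0.]]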